def __init__(self):
        """Open a 256x256 pygame window and build a CycleGAN model from test options.

        NOTE(review): this is a module-level, truncated copy of
        ``DrawingCanvas.__init__`` (the full version appears further down);
        ``self`` is expected to be a canvas-like instance.
        """
        opt = TestOptions().parse()  # get test options
        # init pygame
        pygame.init()
        self.size = (256, 256)
        self.screen = pygame.display.set_mode(self.size)
        # Uses the first font pygame reports; depends on the fonts installed on the host.
        self.font = pygame.font.SysFont(pygame.font.get_fonts()[0], 64)
        # Millisecond tick count, used as the previous-frame timestamp.
        self.time = pygame.time.get_ticks()
        #self.surface_test = pygame.surfarray.make_surface()
        # Start from a white canvas and show it immediately.
        self.screen.fill(pygame.Color(255, 255, 255))
        pygame.display.flip()

        self.model = CycleGANModel(opt)
        self.model.setup(opt)
        #norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
        #net = ResnetGenerator(256, 256, 64, norm_layer=norm_layer, use_dropout=False, n_blocks=9)
        #self.net = init_net(net, 'normal', 0.02, [])
        # Path is relative to the current working directory, not to this file.
        impath = os.getcwd() + "/datasets/bird/testA/514.png"

        # NOTE(review): the loaded surface is not used in this fragment; the load
        # still raises if the file is missing.
        image = pygame.image.load(impath)
Ejemplo n.º 2
0
def main():
    """Train a CycleGAN whose generator uses a fixed, pre-selected architecture.

    Options come from ``ArchTrainOptions``. The architecture is applied with
    ``set_arch(opt.arch, opt.n_resnet - 1)``; training is then delegated to
    ``cyclgan_train`` together with a TensorBoard writer and a step counter.
    """
    opt = ArchTrainOptions().parse()
    torch.cuda.manual_seed(12345)
    opt.path_helper = set_log_dir(opt.checkpoints_dir, opt.name)

    # The dataset type is selected from opt.dataset_mode and related options.
    dataset = create_dataset(opt)
    print(f'The number of training images = {len(dataset)}')

    model = CycleGANModel(opt)
    model.setup(opt)
    model.set_arch(opt.arch, opt.n_resnet - 1)

    log_dir = opt.path_helper['log_path']
    tb_state = dict(writer=SummaryWriter(log_dir), train_steps=0)

    cyclgan_train(opt, model, dataset, tb_state)
Ejemplo n.º 3
0
    def gather_options(self):
        """Build the option parser (only once) and parse the command line.

        On the first call (``self.initialized`` falsy) a new
        ``argparse.ArgumentParser`` is created and populated by
        ``self.initialize``. CycleGAN-specific options are then added through
        ``CycleGANModel.modify_commandline_options(parser, self.isTrain)``, and
        the finished parser is stored in ``self.parser``.

        On later calls the parser already stored in ``self.parser`` is reused
        instead of being rebuilt.

        Returns:
            argparse.Namespace: the parsed command-line options.

        Raises:
            RuntimeError: if ``self.initialized`` is truthy but ``self.parser``
                has not been set.
        """
        if self.initialized:
            # "Only once": reuse the parser built earlier. Without this branch
            # ``parser`` would be unbound below and the call would fail with
            # UnboundLocalError.
            parser = getattr(self, 'parser', None)
            if parser is None:
                raise RuntimeError(
                    'options are marked as initialized but self.parser is not set')
            return parser.parse_args()

        parser = argparse.ArgumentParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser = self.initialize(parser)

        # get the basic options (result discarded; kept from the two-stage parse)
        parser.parse_known_args()

        # modify model-related parser options, then parse again with new defaults
        parser = CycleGANModel.modify_commandline_options(parser, self.isTrain)
        parser.parse_known_args()

        # save and return the parser
        self.parser = parser
        return parser.parse_args()
Ejemplo n.º 4
0
def cyclgan_train(opt, cycle_gan: CycleGANModel, train_loader, writer_dict):
    """Train ``cycle_gan`` over the epoch schedule described by ``opt``.

    Epochs ``opt.epoch_count`` through ``opt.n_epochs + opt.n_epochs_decay``
    (inclusive) are run. Each epoch iterates ``train_loader`` once, calling
    ``set_input`` / ``optimize_parameters`` per batch, and periodically:

    * prints losses every ``opt.print_freq`` batches,
    * saves visuals whenever ``total_iters`` passes a multiple of
      ``opt.display_freq``,
    * saves the ``'latest'`` networks whenever ``total_iters`` passes a
      multiple of ``opt.save_latest_freq``.

    At the end of each epoch, networks are saved when ``epoch`` is a multiple
    of ``opt.save_epoch_freq``. The last batch's losses are written to
    TensorBoard, ``writer_dict['train_steps']`` is incremented, and
    ``update_learning_rate`` is called.

    Args:
        opt: parsed training options.
        cycle_gan: the model to train.
        train_loader: iterable of training batches that supports ``len()``.
        writer_dict: dict holding the ``SummaryWriter`` under ``'writer'`` and
            an integer step counter under ``'train_steps'`` (mutated).
    """
    cycle_gan.train()

    writer = writer_dict['writer']
    total_iters = 0
    t_data = 0.0
    n_total_epochs = opt.n_epochs + opt.n_epochs_decay

    def _crossed(prev_iters, cur_iters, freq):
        """Return True when a multiple of ``freq`` lies in ``(prev_iters, cur_iters]``."""
        return cur_iters // freq > prev_iters // freq

    for epoch in trange(opt.epoch_count, n_total_epochs + 1):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0

        # TensorBoard step for this epoch (the counter advances once per epoch).
        train_steps = writer_dict['train_steps']
        for i, data in enumerate(train_loader):
            iter_start_time = time.time()

            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time

            # total_iters counts samples, so it jumps by batch_size. Periodic
            # actions fire whenever it passes a multiple of their frequency. A
            # plain `(total_iters + 1) % freq == 0` test would never fire when,
            # e.g., batch_size and freq are both even.
            prev_iters = total_iters
            total_iters += opt.batch_size
            epoch_iter += opt.batch_size

            cycle_gan.set_input(data)
            cycle_gan.optimize_parameters()

            if (i + 1) % opt.print_freq == 0:
                losses = cycle_gan.get_current_losses()
                t_comp = time.time() - iter_start_time
                message = "GAN: [Ep: %d/%d]" % (epoch, n_total_epochs)
                message += "[Batch: %d/%d][time: %.3f][data: %.3f]" % (
                    epoch_iter, len(train_loader), t_comp, t_data)
                for k, v in losses.items():
                    message += '[%s: %.3f]' % (k, v)
                tqdm.write(message)

            if _crossed(prev_iters, total_iters, opt.display_freq):
                cycle_gan.compute_visuals()
                save_current_results(opt, cycle_gan.get_current_visuals(),
                                     train_steps)

            if _crossed(prev_iters, total_iters, opt.save_latest_freq):
                tqdm.write(
                    'saving the latest model (epoch %d, total_iters %d)' %
                    (epoch, total_iters))
                cycle_gan.save_networks('latest')

            iter_data_time = time.time()

        # `epoch` is already the absolute epoch number from the trange above.
        # Testing it directly (not epoch + 1) means the final epoch is saved
        # whenever the schedule length is a multiple of save_epoch_freq.
        if epoch % opt.save_epoch_freq == 0:
            cycle_gan.save_networks('latest')
            cycle_gan.save_networks(epoch)

        tqdm.write('End of epoch %d / %d \t Time Taken: %d sec' %
                   (epoch, n_total_epochs, time.time() - epoch_start_time))

        # Log the losses of the last batch of this epoch.
        writer.add_scalars('Train/discriminator', {
            "A": float(cycle_gan.loss_D_A),
            "B": float(cycle_gan.loss_D_B),
        }, train_steps)
        writer.add_scalars('Train/generator', {
            "A": float(cycle_gan.loss_G_A),
            "B": float(cycle_gan.loss_G_B),
        }, train_steps)
        writer.add_scalars('Train/cycle', {
            "A": float(cycle_gan.loss_cycle_A),
            "B": float(cycle_gan.loss_cycle_B),
        }, train_steps)
        writer.add_scalars('Train/idt', {
            "A": float(cycle_gan.loss_idt_A),
            "B": float(cycle_gan.loss_idt_B),
        }, train_steps)

        writer_dict['train_steps'] += 1
        cycle_gan.update_learning_rate()
class DrawingCanvas:
    """Pygame sketch pad whose drawing is fed through a CycleGAN model.

    Opens a 256x256 white window.
    * Moving the mouse while the left button is held paints small black
      ellipses.
    * Pressing a mouse button paints a small black square.
    * Releasing a button converts the canvas to a tensor, runs it through the
      model, and shows the generated image with matplotlib.
    """

    def __init__(self):
        """Parse test options, open the pygame window and set up the model."""
        opt = TestOptions().parse()  # get test options
        # init pygame
        pygame.init()
        self.size = (256, 256)
        self.screen = pygame.display.set_mode(self.size)
        # Uses the first font pygame reports; depends on the fonts installed on the host.
        self.font = pygame.font.SysFont(pygame.font.get_fonts()[0], 64)
        # Previous-frame tick count (ms); game_loop derives delta_time from it.
        self.time = pygame.time.get_ticks()
        # Start from a white canvas.
        self.screen.fill(pygame.Color(255, 255, 255))
        pygame.display.flip()

        self.model = CycleGANModel(opt)
        self.model.setup(opt)
        impath = os.getcwd() + "/datasets/bird/testA/514.png"

        # NOTE(review): the loaded surface is never used (drawing it onto the
        # screen is disabled); the call still raises if the file is missing.
        image = pygame.image.load(impath)
        # self.screen.blit(image, (0, 0))

    def game_loop(self):
        """Advance one frame: process events, update state, then draw.

        Meant to be called once per frame by the caller's main loop.
        ``delta_time`` is the number of milliseconds since the previous call.
        """
        current_time = pygame.time.get_ticks()
        delta_time = current_time - self.time
        self.time = current_time
        self.handle_events()
        self.update_game(delta_time)
        self.draw_components()

    def update_game(self, dt):
        """Hook for per-frame state updates; ``dt`` is the frame time in ms. Currently a no-op."""
        pass

    def draw_components(self):
        """Hook for drawing one frame (like ``draw()`` in Processing). Currently a no-op."""
        pass

    def reset(self):
        """Hook for resetting the canvas. Currently a no-op."""
        pass

    def handle_events(self):
        """Drain the pygame event queue and dispatch each event to its handler.

        A QUIT event terminates the process via ``sys.exit()``.
        """
        for event in pygame.event.get():
            # An event has exactly one type, so at most one branch applies.
            if event.type == pygame.QUIT:
                sys.exit()
            elif event.type == pygame.KEYDOWN:
                self.handle_key_down(event)
            elif event.type == pygame.KEYUP:
                self.handle_key_up(event)
            elif event.type == pygame.MOUSEMOTION:
                self.handle_mouse_motion(event)
            elif event.type == pygame.MOUSEBUTTONDOWN:
                self.handle_mouse_pressed(event)
            elif event.type == pygame.MOUSEBUTTONUP:
                self.handle_mouse_released(event)

    def handle_key_down(self, event):
        """Hook for key-press events. Currently a no-op."""
        pass

    def handle_key_up(self, event):
        """Hook for key-release events. Currently a no-op."""
        pass

    def handle_mouse_motion(self, event):
        """Paint a 5x5 black ellipse at the cursor while the left button is held."""
        if pygame.mouse.get_pressed()[0]:
            pos = pygame.mouse.get_pos()
            # draw.ellipse returns the affected rect; refresh only that region.
            pygame.display.update(
                pygame.draw.ellipse(self.screen, (0, 0, 0), [pos, [5, 5]]))

    def handle_mouse_pressed(self, event):
        """Paint a 5x5 black square at the cursor (like mousePressed() in Processing)."""
        pos = pygame.mouse.get_pos()
        pygame.display.update(
            pygame.draw.rect(self.screen, (0, 0, 0), [pos, [5, 5]]))

    def handle_mouse_released(self, event):
        """Run the current canvas through the model and display the result."""
        # surfarray.array3d is indexed (x, y, channel); swap the first two axes
        # to get the row-major (height, width, channel) layout PIL expects.
        canvas = pygame.surfarray.array3d(self.screen)
        print(canvas.shape)
        canvas = canvas.transpose(1, 0, 2)
        print(canvas.shape)

        compose = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor()
        ])
        # NOTE(review): ToTensor yields values in [0, 1]; confirm this matches
        # the input range the model was trained on.
        test_tensor = compose(canvas).unsqueeze(0)
        print(test_tensor.size())

        self.model.set_input(test_tensor)
        self.model.forward()
        result = self.model.get_generated()
        print("Result", result)

        # Clamp a detached copy. The previous in-place masked assignment wrote
        # through a squeeze() view into the model's own output tensor.
        # NOTE(review): if the generator's output range is [-1, 1], clamping
        # negatives to 0 discards half of it; (x + 1) / 2 may be intended.
        image_tensor = result.detach().squeeze(0).clamp(min=0)
        # ToPILImage moves the data to the CPU itself, so this also works for
        # CUDA tensors. A direct .numpy() call on a CUDA tensor would raise.
        im = transforms.ToPILImage()(image_tensor).convert("RGB")
        print(image_tensor.size())

        print(im)
        print(im.size)
        plt.imshow(im)
        plt.show()

    def lab2rgb(self, L, AB):
        """Convert an Lab tensor image to a RGB numpy output

        Only the first image of the batch (index 0) is converted.

        Parameters:
            L  (1-channel tensor array): L channel images (range: [-1, 1], torch tensor array)
            AB (2-channel tensor array):  ab channel images (range: [-1, 1], torch tensor array)
        Returns:
            rgb (RGB numpy image): rgb output images  (range: [0, 255], numpy array)
        """
        # Undo the [-1, 1] normalisation: ab -> [-110, 110], L -> [0, 100].
        AB2 = AB * 110.0
        L2 = (L + 1.0) * 50.0
        Lab = torch.cat([L2, AB2], dim=1)
        Lab = Lab[0].data.cpu().float().numpy()
        Lab = np.transpose(Lab.astype(np.float64), (1, 2, 0))
        # NOTE(review): `color` is presumably skimage.color; confirm the import.
        rgb = color.lab2rgb(Lab) * 255
        return rgb
Ejemplo n.º 6
0
from models.cycle_gan_with_distillation import CycleGANModelWithDistillation
from models.cycle_gan_model import CycleGANModel
if __name__ == '__main__':
    opt = TrainOptions().parse()  # get test options
    # hard-code some parameters for test
    opt.num_threads = 0  # test code only supports num_threads = 1
    opt.batch_size = 1  # test code only supports batch_size = 1
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.no_flip = True  # no flip; comment this line if results on flipped images are needed.
    opt.display_id = -1  # no visdom display; the test code saves the results to a HTML file.
    dataset = create_dataset(
        opt)  # create a dataset given opt.dataset_mode and other options
    opt2 = Namespace(**vars(opt))
    opt2.name = 'monet2photo_pretrained'
    opt2.isTrain = False
    teacher = CycleGANModel(opt2)
    teacher.isTeacher = True
    opt.continue_train = True
    opt2.continue_train = True
    teacher.setup(
        opt2)  # regular setup: load and print networks; create schedulers
    opt.netG = 'resnet_3blocks'
    opt.results_dir = 'results'
    model = CycleGANModelWithDistillation(
        opt, teacher)  # create a model given opt.model and other options
    model.setup(
        opt)  # regular setup: load and print networks; create schedulers
    # create a website
    web_dir = os.path.join(opt.results_dir, opt.name, '{}_{}'.format(
        opt.phase, opt.epoch))  # define the website directory
    if opt.load_iter > 0:  # load_iter is 0 by default
Ejemplo n.º 7
0
import numpy as np

from visdom import Visdom
# Connect to a Visdom server and fail fast if it is unreachable.
# NOTE(review): `assert` is skipped under `python -O`, so this check disappears there.
viz = Visdom()

assert viz.check_connection()
# NOTE(review): close() with no window argument presumably clears existing windows — confirm.
viz.close()

# Parse training options; save_opt presumably persists them for reproducibility.
opt = TrainOptions().parse()
save_opt(opt)

data_loader = DataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)

# This CycleGANModel variant is configured via initialize(opt), not the constructor.
model = CycleGANModel()
model.initialize(opt)
visualizer = Visualizer(opt)

# NOTE(review): torch, time, TrainOptions, save_opt, DataLoader, CycleGANModel and
# Visualizer are not imported in this excerpt; the script appears truncated.
if __name__ == '__main__':

    total_steps = 0
    sparse_c_loss_points, sparse_c_loss_avr_points = [], []

    # Create a Visdom line plot seeded with a single (0, 0) point.
    win_sparse_C = viz.line(X=torch.zeros((1, )),
                            Y=torch.zeros((1, )),
                            name="win_sparse_C")

    # Epochs are 1-based here: 1 .. opt.epoch inclusive.
    for epoch in range(1, opt.epoch + 1):
        epoch_start_time = time.time()
        epoch_iter = 0
Ejemplo n.º 8
0
from utilSet import html
from models.cycle_gan_model import CycleGANModel
from utilSet.visualizer import Visualizer
from config import TestOptions
from data.dataset import DataLoader
import ntpath

# Parse test options and force the settings this test loop relies on.
opt = TestOptions().parse()
opt.nThreads = 1  # test code only supports nThreads = 1
opt.batchSize = 1  # test code only supports batchSize = 1
opt.serial_batches = True  # no shuffle
opt.no_flip = True  # no flip

data_loader = DataLoader(opt)
dataset = data_loader.load_data()
# This CycleGANModel variant is configured via initialize(opt), not the constructor.
model = CycleGANModel()
model.initialize(opt)
visualizer = Visualizer(opt)

if __name__ == '__main__':
    # NOTE(review): `os` is used here but is not among this snippet's imports.
    # Output layout: <result_root_dir>/<variable>/<variable_value>/<phase>
    root_dir = os.path.join(opt.result_root_dir, opt.variable)
    web_dir = os.path.join(root_dir, opt.variable_value, opt.phase)
    webpage = html.HTML(web_dir,
                        'Experiment = GAN2C, Phase = test, Epoch = latest')
    # test
    for i, data in enumerate(dataset):
        model.set_input(data)
        model.test()
        visuals = model.get_current_visuals()

        # NOTE(review): visuals / img_path are not used in the visible part;
        # the snippet appears truncated before they are saved to `webpage`.
        img_path = model.get_image_paths()
Ejemplo n.º 9
0
def main():
    """Run the CycleGAN architecture search, optionally resuming from a checkpoint.

    The search runs ``opt.max_search_iter`` iterations. The controller grows to
    a new stage at the iterations listed in ``grow_steps``. Each iteration:

    1. trains the shared GAN (``cyclgan_train``),
    2. trains the controller (``controller_train``),
    3. re-creates the shared GAN if a dynamic reset was signalled,
    4. writes a checkpoint.

    Raises:
        FileNotFoundError: if ``opt.load_path`` is given but it, or its
            ``Model/checkpoint.pth``, does not exist.
    """
    from itertools import accumulate

    opt = SearchOptions().parse()
    torch.cuda.manual_seed(12345)

    _init_inception(MODEL_DIR)
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    start_search_iter = 0
    cur_stage = 1

    # Search iterations spent per stage: grow_step**i for the first
    # max_skip_num - 1 stages, then grow_step**3 for each remaining stage.
    delta_grow_steps = [int(opt.grow_step ** i) for i in range(1, opt.max_skip_num)] + \
                       [int(opt.grow_step ** 3) for _ in range(1, opt.n_resnet - opt.max_skip_num + 1)]

    opt.max_search_iter = sum(delta_grow_steps)
    # Iterations at which the controller grows: running totals of
    # delta_grow_steps, excluding the final total (== max_search_iter).
    grow_steps = list(accumulate(delta_grow_steps))[:-1]

    grow_ctrler = GrowCtrler(opt.grow_step, steps=grow_steps)

    checkpoint = None
    if opt.load_path:
        print(f'=> resuming from {opt.load_path}')
        # Explicit checks: `assert` would be stripped under `python -O`.
        if not os.path.exists(opt.load_path):
            raise FileNotFoundError(f'load_path does not exist: {opt.load_path}')
        checkpoint_file = os.path.join(opt.load_path, 'Model',
                                       'checkpoint.pth')
        if not os.path.exists(checkpoint_file):
            raise FileNotFoundError(f'checkpoint not found: {checkpoint_file}')
        # NOTE: torch.load unpickles the file; only resume from trusted checkpoints.
        checkpoint = torch.load(checkpoint_file,
                                map_location={'cuda:0': 'cpu'})
        # restore search progress and log locations
        cur_stage = checkpoint['cur_stage']
        start_search_iter = checkpoint["search_iter"]
        opt.path_helper = checkpoint['path_helper']
    else:
        opt.path_helper = set_log_dir(opt.checkpoints_dir, opt.name)

    # Model construction is identical for fresh and resumed runs.
    cycle_gan = CycleGANModel(opt)
    cycle_gan.setup(opt)

    cycle_controller = CycleControllerModel(opt, cur_stage=cur_stage)
    cycle_controller.setup(opt)
    cycle_controller.set(cycle_gan)

    if checkpoint is not None:
        cycle_gan.load_from_state(checkpoint["cycle_gan"])
        cycle_controller.load_from_state(checkpoint["cycle_controller"])

    dataset = create_dataset(
        opt)  # create a dataset given opt.dataset_mode and other options
    print('The number of training images = %d' % len(dataset))

    writer_dict = {
        "writer": SummaryWriter(opt.path_helper['log_path']),
        'controller_steps': start_search_iter * opt.ctrl_step,
        'train_steps': start_search_iter * opt.shared_epoch
    }

    g_loss_history = RunningStats(opt.dynamic_reset_window)
    d_loss_history = RunningStats(opt.dynamic_reset_window)

    for search_iter in tqdm(
            range(int(start_search_iter), int(opt.max_search_iter))):
        tqdm.write(f"<start search iteration {search_iter}>")
        cycle_controller.reset()

        if search_iter in grow_steps:
            cur_stage = grow_ctrler.cur_stage(search_iter) + 1
            tqdm.write(f'=> grow to stage {cur_stage}')
            # Carry the top-k architectures/hidden states into the next stage's controller.
            prev_archs_A, prev_hiddens_A = cycle_controller.get_topk_arch_hidden_A()
            prev_archs_B, prev_hiddens_B = cycle_controller.get_topk_arch_hidden_B()

            del cycle_controller

            cycle_controller = CycleControllerModel(opt, cur_stage)
            cycle_controller.setup(opt)
            cycle_controller.set(cycle_gan, prev_hiddens_A, prev_hiddens_B,
                                 prev_archs_A, prev_archs_B)

        dynamic_reset = cyclgan_train(opt, cycle_gan, cycle_controller,
                                      dataset, g_loss_history, d_loss_history,
                                      writer_dict)

        controller_train(opt, cycle_gan, cycle_controller, writer_dict)

        if dynamic_reset:
            tqdm.write('re-initialize share GAN')
            del cycle_gan
            cycle_gan = CycleGANModel(opt)
            cycle_gan.setup(opt)
            # NOTE(review): cycle_controller was bound to the previous model via
            # set(); verify whether it must be re-bound to the new cycle_gan here.

        save_checkpoint(
            {
                'cur_stage': cur_stage,
                'search_iter': search_iter + 1,
                'cycle_gan': cycle_gan.save_networks(epoch=search_iter),
                'cycle_controller':
                cycle_controller.save_networks(epoch=search_iter),
                'path_helper': opt.path_helper
            }, False, opt.path_helper['ckpt_path'])

    final_archs_A, _ = cycle_controller.get_topk_arch_hidden_A()
    final_archs_B, _ = cycle_controller.get_topk_arch_hidden_B()
    print(f"discovered archs: {final_archs_A}")
    print(f"discovered archs: {final_archs_B}")
Ejemplo n.º 10
0
def cyclgan_train(opt, cycle_gan: CycleGANModel,
                  cycle_controller: CycleControllerModel, train_loader,
                  g_loss_history: RunningStats, d_loss_history: RunningStats,
                  writer_dict):
    """Train the shared CycleGAN for ``opt.shared_epoch`` epochs during the search.

    The GAN is put in train mode and the controller in eval mode. Before every
    batch ``cycle_controller.forward()`` is called. After each optimisation
    step the generator loss and the summed discriminator losses are pushed
    into ``g_loss_history`` / ``d_loss_history``.

    Once the generator history is full, the function checks both variances.
    If either drops below ``opt.dynamic_reset_threshold``, both histories are
    cleared and the function returns ``True`` immediately, so the caller can
    re-initialise the GAN.

    Args:
        opt: parsed search options.
        cycle_gan: the shared GAN being trained.
        cycle_controller: the architecture controller (only run forward here).
        train_loader: iterable of training batches that supports ``len()``.
        g_loss_history / d_loss_history: running loss statistics (mutated).
        writer_dict: dict holding the ``SummaryWriter`` under ``'writer'`` and
            an integer counter under ``'train_steps'`` (mutated).

    Returns:
        bool: True if a dynamic reset was triggered, otherwise False.
    """
    cycle_gan.train()
    cycle_controller.eval()

    writer = writer_dict['writer']
    total_iters = 0
    t_data = 0.0

    def _crossed(prev_iters, cur_iters, freq):
        """Return True when a multiple of ``freq`` lies in ``(prev_iters, cur_iters]``."""
        return cur_iters // freq > prev_iters // freq

    for epoch in range(opt.shared_epoch):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0

        train_steps = writer_dict['train_steps']
        for i, data in enumerate(train_loader):
            iter_start_time = time.time()

            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time

            # total_iters counts samples and jumps by batch_size. A plain
            # `(total_iters + 1) % freq == 0` test would never fire when, e.g.,
            # batch_size and freq are both even, so test for crossing instead.
            prev_iters = total_iters
            total_iters += opt.batch_size
            epoch_iter += opt.batch_size

            cycle_controller.forward()

            cycle_gan.set_input(data)
            cycle_gan.optimize_parameters()

            g_loss_history.push(cycle_gan.loss_G.item())
            d_loss_history.push(cycle_gan.loss_D_A.item() +
                                cycle_gan.loss_D_B.item())

            if (i + 1) % opt.print_freq == 0:
                losses = cycle_gan.get_current_losses()
                t_comp = time.time() - iter_start_time
                message = "GAN: [Ep: %d/%d]" % (epoch, opt.shared_epoch)
                message += "[Batch: %d/%d][time: %.3f][data: %.3f]" % (
                    epoch_iter, len(train_loader), t_comp, t_data)
                for k, v in losses.items():
                    message += '[%s: %.3f]' % (k, v)
                tqdm.write(message)

            if _crossed(prev_iters, total_iters, opt.display_freq):
                cycle_gan.compute_visuals()
                save_current_results(opt, cycle_gan.get_current_visuals(),
                                     train_steps)

            # Both histories are pushed in lock-step, so checking G's fullness suffices.
            if g_loss_history.is_full():
                if (g_loss_history.get_var() < opt.dynamic_reset_threshold
                        or d_loss_history.get_var() < opt.dynamic_reset_threshold):
                    tqdm.write("=> dynamic resetting triggered")
                    g_loss_history.clear()
                    d_loss_history.clear()
                    return True

            # cache our latest model every <save_latest_freq> iterations
            if _crossed(prev_iters, total_iters, opt.save_latest_freq):
                # NOTE(review): the actual save below is disabled, so this
                # branch only logs; re-enable it or drop the message.
                tqdm.write(
                    'saving the latest model (epoch %d, total_iters %d)' %
                    (epoch, total_iters))
                # cycle_gan.save_networks(train_steps)

            iter_data_time = time.time()

        # epoch is 0-based in this loop, hence epoch + 1.
        if (epoch + 1) % opt.save_epoch_freq == 0:
            cycle_gan.save_networks('latest')
            # cycle_gan.save_networks(train_steps)

        # This loop runs opt.shared_epoch epochs, so report that as the total.
        tqdm.write('End of epoch %d / %d \t Time Taken: %d sec' %
                   (epoch, opt.shared_epoch, time.time() - epoch_start_time))

        writer.add_scalars('Train/discriminator', {
            "A": float(cycle_gan.loss_D_A),
            "B": float(cycle_gan.loss_D_B),
        }, train_steps)
        writer.add_scalars('Train/generator', {
            "A": float(cycle_gan.loss_G_A),
            "B": float(cycle_gan.loss_G_B),
        }, train_steps)
        writer.add_scalars('Train/cycle', {
            "A": float(cycle_gan.loss_cycle_A),
            "B": float(cycle_gan.loss_cycle_B),
        }, train_steps)
        writer.add_scalars('Train/idt', {
            "A": float(cycle_gan.loss_idt_A),
            "B": float(cycle_gan.loss_idt_B),
        }, train_steps)

        writer_dict['train_steps'] += 1

    return False
Ejemplo n.º 11
0
def controller_train(opt, cycle_gan: CycleGANModel,
                     cycle_controller: CycleControllerModel, writer_dict):
    """Run ``opt.ctrl_step`` controller updates while the shared GAN is in eval mode.

    Each step calls ``step_A`` then ``step_B`` on the controller. Every
    ``opt.print_freq_controller`` steps a progress line is printed. Six scalar
    groups are written to TensorBoard at the step read from
    ``writer_dict['controller_steps']``, and that counter is then incremented.
    """
    # (TensorBoard tag, attribute for "A", attribute for "B", call .item()?)
    scalar_groups = (
        ('Controller/loss', 'loss_A', 'loss_B', True),
        ('Controller/discriminator', 'loss_D_A', 'loss_D_B', True),
        ('Controller/inception_score', 'loss_IS_A', 'loss_IS_B', True),
        ('Controller/adv', 'loss_adv_A', 'loss_adv_B', False),
        ('Controller/entropy', 'loss_entropy_A', 'loss_entropy_B', False),
        ('Controller/reward', 'loss_reward_A', 'loss_reward_B', False),
    )
    tb = writer_dict['writer']

    cycle_controller.train()  # the controller is what gets optimised here
    cycle_gan.eval()          # the shared GAN is only evaluated

    last_report = time.time()
    for step_idx in range(opt.ctrl_step):
        step = writer_dict['controller_steps']

        cycle_controller.step_A()
        cycle_controller.step_B()

        if (step_idx + 1) % opt.print_freq_controller == 0:
            losses = cycle_controller.get_current_losses()
            elapsed = time.time() - last_report
            last_report = time.time()
            pieces = [
                "Cont: [Ep: %d/%d]" % (step_idx, opt.ctrl_step),
                "[{}][{}]".format(cycle_controller.arch_A,
                                  cycle_controller.arch_B),
                "[time: %.3f]" % elapsed,
            ]
            pieces.extend('[%s: %.3f]' % (name, value)
                          for name, value in losses.items())
            tqdm.write(''.join(pieces))

        for tag, attr_a, attr_b, unwrap in scalar_groups:
            pair = {}
            for key, attr in (("A", attr_a), ("B", attr_b)):
                value = getattr(cycle_controller, attr)
                pair[key] = value.item() if unwrap else value
            tb.add_scalars(tag, pair, step)

        writer_dict['controller_steps'] = step + 1
Ejemplo n.º 12
0
See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
"""
import time
from options.train_options import TrainOptions
from data import create_dataset
from models.cycle_gan_model import CycleGANModel
from util.util import create_log_txt, print_current_losses

if __name__ == '__main__':
    opt = TrainOptions().parse()  # get training options
    dataset = create_dataset(
        opt)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)  # get the number of images in the dataset.

    model = CycleGANModel(
        opt)  # create a model given opt.model and other options
    model.setup(
        opt)  # regular setup: load and print networks; create schedulers
    log_file = create_log_txt(opt)  # create a log file for training progress
    total_iters = 0  # the total number of training iterations

    for epoch in range(
            opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1
    ):  # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()  # timer for data loading per iteration
        epoch_iter = 0  # the number of training iterations in current epoch, reset to 0 every epoch
        model.update_learning_rate(
        )  # update learning rates in the beginning of every epoch.
        for i, data in enumerate(dataset):  # inner loop within one epoch
            iter_start_time = time.time(
Ejemplo n.º 13
0
    ngf=64,
    no_dropout=True,
    no_flip=True,
    norm="instance",
    ntest=float("inf"),
    num_test=100,
    num_threads=0,
    output_nc=3,
    phase="test",
    preprocess="no_preprocessing",
    results_dir="./results/",
    serial_batches=True,
    suffix="",
    verbose=False,
)
# Keep only the netG_A generator of a CycleGAN built from `opt`, then load its weights.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (newer PyTorch versions accept weights_only=True for this case).
model = CycleGANModel(opt).netG_A
model.load_state_dict(torch.load(model_fp))

# Input transform derived from the same options; applied by SingleImageDataset.
preprocess = get_transform(opt)


class SingleImageDataset(torch.utils.data.Dataset):
    """Map-style dataset holding exactly one pre-processed image.

    Keyword Args:
        img: the raw image. It is passed through ``transform`` once, at
            construction time.
        transform: optional callable applied to ``img``. Defaults to the
            module-level ``preprocess``.

    Any other positional/keyword arguments are forwarded to the base class.
    """

    def __init__(self, *args, **kwargs):
        img = kwargs.pop("img")
        transform = kwargs.pop("transform", None)
        super().__init__(*args, **kwargs)
        if transform is None:
            # Resolved lazily so callers may supply their own transform instead.
            transform = preprocess
        self.img = transform(img)

    def __len__(self):
        # DataLoader's default (sequential) sampler needs len(); there is one item.
        return 1

    def __getitem__(self, i):
        # Only index 0 (or -1) exists. Raising IndexError otherwise also stops
        # plain iteration (`for x in ds`), which would otherwise loop forever.
        if i not in (0, -1):
            raise IndexError(f"SingleImageDataset index out of range: {i}")
        return self.img
Ejemplo n.º 14
0
# -*- coding:utf-8 -*-
import time
from config import TrainOptions
from models.cycle_gan_model import CycleGANModel
from utilSet.visualizer import Visualizer, save_opt
from data.dataset import DataLoader

# Parse training options; save_opt presumably persists them for reproducibility.
opt = TrainOptions().parse()
save_opt(opt)
data_loader = DataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)

# This CycleGANModel variant is configured via initialize(opt), not the constructor.
model = CycleGANModel()
model.initialize(opt)
visualizer = Visualizer(opt)

if __name__ == '__main__':
    # Counters advance by 1 per batch here (not by batch size).
    total_steps = 0
    # Epochs are 1-based: 1 .. opt.epoch inclusive.
    for epoch in range(1, opt.epoch + 1):
        epoch_start_time = time.time()
        epoch_iter = 0

        for i, data in enumerate(dataset):
            # NOTE(review): iter_start_time is unused in the visible part; the
            # snippet appears truncated before logging/checkpointing.
            iter_start_time = time.time()
            visualizer.reset()
            total_steps += 1
            epoch_iter += 1
            model.set_input(data)
            model.optimize_parameters()
Ejemplo n.º 15
0
from options.test_options import TestOptions
from data import create_dataset
from models.cycle_gan_model import CycleGANModel
from util.util import save_images, create_results_dir
from util import html

if __name__ == '__main__':
    opt = TestOptions().parse()  # get test options
    # hard-code some parameters for test
    opt.num_threads = 0  # test code only supports num_threads = 0
    opt.batch_size = 1  # test code only supports batch_size = 1
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.no_flip = True  # no flip; comment this line if results on flipped images are needed.
    dataset = create_dataset(
        opt)  # create a dataset given opt.dataset_mode and other options
    model = CycleGANModel(
        opt)  # always a CycleGANModel here, configured from opt
    model.setup(
        opt)  # regular setup: load and print networks; create schedulers
    # create results dir
    image_dir = create_results_dir(opt)
    # Test in eval mode. This only affects layers such as batchnorm and dropout;
    # whether it matters depends on opt.norm / dropout settings.
    if opt.eval:
        model.eval()
    for i, data in enumerate(dataset):
        if i >= opt.num_test:  # only apply our model to opt.num_test images.
            break
        model.set_input(data)  # unpack data from data loader
        model.test()  # run inference
        visuals = model.get_current_visuals()  # get image results
        # NOTE(review): visuals / img_path / image_dir are unused in the visible
        # part (save_images is imported); the snippet appears truncated.
        img_path = model.get_image_paths()  # get image paths