Exemplo n.º 1
0
def main():
    """Train a CycleGAN whose generator architecture is fixed to ``opt.arch``.

    Parses ``ArchTrainOptions``, creates the log directory and the dataset,
    builds and sets up the CycleGAN model, applies the architecture with
    ``set_arch`` and hands everything to ``cyclgan_train``. TensorBoard events
    are written under ``opt.path_helper['log_path']``; the writer is closed
    (flushing pending events) even if training raises.
    """
    opt = ArchTrainOptions().parse()
    # torch.manual_seed seeds the CPU generator and every CUDA device;
    # torch.cuda.manual_seed alone left CPU-side randomness (e.g. DataLoader
    # shuffling, if used) unseeded.
    torch.manual_seed(12345)

    opt.path_helper = set_log_dir(opt.checkpoints_dir, opt.name)

    # create a dataset given opt.dataset_mode and other options
    dataset = create_dataset(opt)
    print('The number of training images = %d' % len(dataset))

    cycle_gan = CycleGANModel(opt)
    cycle_gan.setup(opt)
    # NOTE(review): second argument is opt.n_resnet - 1, presumably the index
    # of the last resnet block -- confirm against CycleGANModel.set_arch.
    cycle_gan.set_arch(opt.arch, opt.n_resnet - 1)

    writer = SummaryWriter(opt.path_helper['log_path'])
    writer_dict = {
        "writer": writer,
        'train_steps': 0
    }

    try:
        cyclgan_train(opt, cycle_gan, dataset, writer_dict)
    finally:
        # Flush buffered events and release the event file even on failure.
        writer.close()
class DrawingCanvas:
    """Interactive 256x256 pygame drawing canvas backed by a CycleGAN model.

    The user sketches in black with the left mouse button; when a mouse
    button is released, the current canvas is converted to a tensor, run
    through the model, and the generated image is shown with matplotlib.
    The event-loop structure mirrors Processing's setup/draw/handler style.
    """

    def __init__(self):
        opt = TestOptions().parse()  # get test options

        # init pygame
        pygame.init()
        self.size = (256, 256)
        self.screen = pygame.display.set_mode(self.size)
        # get_fonts() may return an empty list on systems without
        # discoverable fonts; SysFont(None, ...) falls back to pygame's default.
        fonts = pygame.font.get_fonts()
        self.font = pygame.font.SysFont(fonts[0] if fonts else None, 64)
        self.time = pygame.time.get_ticks()

        # Start from a blank white canvas.
        self.screen.fill(pygame.Color(255, 255, 255))
        pygame.display.flip()

        self.model = CycleGANModel(opt)
        self.model.setup(opt)

    def game_loop(self):
        """Run one frame: handle pending events, update state, then draw.

        ``update_game`` receives the milliseconds elapsed since the previous
        frame (``pygame.time.get_ticks`` is in milliseconds). Drives the
        display and event handling, like Processing does behind the scenes;
        don't change this unless you know what you are doing.
        """
        current_time = pygame.time.get_ticks()
        delta_time = current_time - self.time
        self.time = current_time
        self.handle_events()
        self.update_game(delta_time)
        self.draw_components()

    def update_game(self, dt):
        """Update per-frame state; ``dt`` is the elapsed time in ms. Currently a no-op."""
        pass

    def draw_components(self):
        """Draw one frame (like Processing's ``void draw()``). Currently a no-op.

        Put all draw calls here and keep state updates in ``update_game``.
        """
        pass

    def reset(self):
        """Reset the canvas state. Currently a no-op."""
        pass

    def handle_events(self):
        """Dispatch every pending pygame event to its handler.

        A QUIT event terminates the process via ``sys.exit()``. Like
        Processing, this is plumbing; don't change it unless you know what
        you are doing.
        """
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                sys.exit()
            elif event.type == pygame.KEYDOWN:
                self.handle_key_down(event)
            elif event.type == pygame.KEYUP:
                self.handle_key_up(event)
            elif event.type == pygame.MOUSEMOTION:
                self.handle_mouse_motion(event)
            elif event.type == pygame.MOUSEBUTTONDOWN:
                self.handle_mouse_pressed(event)
            elif event.type == pygame.MOUSEBUTTONUP:
                self.handle_mouse_released(event)

    def handle_key_down(self, event):
        """Handle a KEYDOWN event. Currently a no-op."""
        pass

    def handle_key_up(self, event):
        """Handle a KEYUP event. Currently a no-op."""
        pass

    def handle_mouse_motion(self, event):
        """Paint a 5x5 black ellipse at the cursor while the left button is held.

        Similar to ``void mouseMoved()`` in Processing. Only the touched
        rectangle of the display is refreshed.
        """
        if pygame.mouse.get_pressed()[0]:
            pos = pygame.mouse.get_pos()
            dirty = pygame.draw.ellipse(self.screen, (0, 0, 0), [pos, [5, 5]])
            pygame.display.update(dirty)

    def handle_mouse_pressed(self, event):
        """Paint a 5x5 black square at the cursor when a button is pressed.

        Similar to ``void mousePressed()`` in Processing.
        """
        pos = pygame.mouse.get_pos()
        dirty = pygame.draw.rect(self.screen, (0, 0, 0), [pos, [5, 5]])
        pygame.display.update(dirty)

    def handle_mouse_released(self, event):
        """Run the current drawing through the model and show the output.

        Similar to ``void mouseReleased()`` in Processing. Blocks until the
        matplotlib window is closed (``plt.show()``).
        """
        # surfarray.array3d is indexed (x, y, channel); swap the first two
        # axes to get the (row, col, channel) layout ToPILImage expects.
        pixels = pygame.surfarray.array3d(self.screen).transpose(1, 0, 2)

        to_tensor = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor(),
        ])
        input_tensor = to_tensor(pixels).unsqueeze(0)  # add a batch dimension

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            self.model.set_input(input_tensor)
            self.model.forward()
            generated = self.model.get_generated()

        # Drop the batch dimension and move to CPU. Clamp out of place so the
        # model's output tensor is not mutated. The upper bound keeps
        # ToPILImage's x*255 -> uint8 conversion from overflowing.
        # NOTE(review): if the generator ends in tanh ([-1, 1]), rescaling
        # with (x + 1) / 2 would keep the lower half of the range instead of
        # clipping it to black -- confirm.
        image = generated.squeeze(0).detach().cpu().clamp(0, 1)
        im = transforms.ToPILImage()(image).convert("RGB")

        plt.imshow(im)
        plt.show()

    def lab2rgb(self, L, AB):
        """Convert an Lab tensor image to a RGB numpy output
        Parameters:
            L  (1-channel tensor array): L channel images (range: [-1, 1], torch tensor array)
            AB (2-channel tensor array):  ab channel images (range: [-1, 1], torch tensor array)
        Returns:
            rgb (RGB numpy image): rgb output images  (range: [0, 255], numpy array)
        """
        # Undo the [-1, 1] normalization: L -> [0, 100], ab -> [-110, 110].
        AB2 = AB * 110.0
        L2 = (L + 1.0) * 50.0
        Lab = torch.cat([L2, AB2], dim=1)
        # Only the first image of the batch is converted.
        Lab = Lab[0].data.cpu().float().numpy()
        Lab = np.transpose(Lab.astype(np.float64), (1, 2, 0))
        rgb = color.lab2rgb(Lab) * 255
        return rgb
Exemplo n.º 3
0
def main():
    """Run the progressive architecture search for the CycleGAN generators.

    The search runs for ``opt.max_search_iter`` iterations. Each iteration:

    * grows the controller to the next stage at the iterations listed in
      ``grow_steps``, seeding it with the previous controller's top-k
      architectures and hidden states;
    * trains the shared CycleGAN (``cyclgan_train``), then the controller
      (``controller_train``);
    * re-initializes the shared CycleGAN when ``cyclgan_train`` requests a
      dynamic reset;
    * writes a checkpoint.

    Setting ``opt.load_path`` resumes from ``<load_path>/Model/checkpoint.pth``.

    Raises:
        FileNotFoundError: if ``opt.load_path`` is set but the checkpoint
            file does not exist.
    """
    from itertools import accumulate

    opt = SearchOptions().parse()
    # torch.manual_seed seeds the CPU generator and every CUDA device.
    torch.manual_seed(12345)

    _init_inception(MODEL_DIR)
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    start_search_iter = 0
    cur_stage = 1

    # Length (in search iterations) of each stage.
    # NOTE(review): the second part uses a hard-coded exponent of 3 -- confirm
    # this is intended rather than e.g. opt.max_skip_num - 1.
    delta_grow_steps = [int(opt.grow_step ** i) for i in range(1, opt.max_skip_num)] + \
                       [int(opt.grow_step ** 3) for _ in range(1, opt.n_resnet - opt.max_skip_num + 1)]

    opt.max_search_iter = sum(delta_grow_steps)
    # Stage boundaries: running totals of the stage lengths, excluding the
    # final total (which is the end of the search itself).
    grow_steps = list(accumulate(delta_grow_steps))[:-1]
    grow_step_set = set(grow_steps)  # O(1) membership test inside the loop

    grow_ctrler = GrowCtrler(opt.grow_step, steps=grow_steps)

    checkpoint = None
    if opt.load_path:
        print(f'=> resuming from {opt.load_path}')
        checkpoint_file = os.path.join(opt.load_path, 'Model',
                                       'checkpoint.pth')
        # Explicit check rather than assert: asserts are stripped under -O.
        if not os.path.exists(checkpoint_file):
            raise FileNotFoundError(
                f'no checkpoint found at {checkpoint_file}')
        checkpoint = torch.load(checkpoint_file,
                                map_location={'cuda:0': 'cpu'})
        cur_stage = checkpoint['cur_stage']
        start_search_iter = checkpoint["search_iter"]
        opt.path_helper = checkpoint['path_helper']
    else:
        opt.path_helper = set_log_dir(opt.checkpoints_dir, opt.name)

    cycle_gan = CycleGANModel(opt)
    cycle_gan.setup(opt)

    cycle_controller = CycleControllerModel(opt, cur_stage=cur_stage)
    cycle_controller.setup(opt)
    cycle_controller.set(cycle_gan)

    if checkpoint is not None:
        # Restore the shared GAN and the controller (with its optimizer).
        cycle_gan.load_from_state(checkpoint["cycle_gan"])
        cycle_controller.load_from_state(checkpoint["cycle_controller"])

    # create a dataset given opt.dataset_mode and other options
    dataset = create_dataset(opt)
    print('The number of training images = %d' % len(dataset))

    writer = SummaryWriter(opt.path_helper['log_path'])
    writer_dict = {
        "writer": writer,
        'controller_steps': start_search_iter * opt.ctrl_step,
        'train_steps': start_search_iter * opt.shared_epoch
    }

    g_loss_history = RunningStats(opt.dynamic_reset_window)
    d_loss_history = RunningStats(opt.dynamic_reset_window)

    try:
        for search_iter in tqdm(
                range(int(start_search_iter), int(opt.max_search_iter))):
            tqdm.write(f"<start search iteration {search_iter}>")
            cycle_controller.reset()

            if search_iter in grow_step_set:
                cur_stage = grow_ctrler.cur_stage(search_iter) + 1
                tqdm.write(f'=> grow to stage {cur_stage}')
                prev_archs_A, prev_hiddens_A = cycle_controller.get_topk_arch_hidden_A()
                prev_archs_B, prev_hiddens_B = cycle_controller.get_topk_arch_hidden_B()

                # Drop the old controller before building the next one.
                del cycle_controller

                cycle_controller = CycleControllerModel(opt, cur_stage)
                cycle_controller.setup(opt)
                cycle_controller.set(cycle_gan, prev_hiddens_A, prev_hiddens_B,
                                     prev_archs_A, prev_archs_B)

            dynamic_reset = cyclgan_train(opt, cycle_gan, cycle_controller,
                                          dataset, g_loss_history,
                                          d_loss_history, writer_dict)

            controller_train(opt, cycle_gan, cycle_controller, writer_dict)

            if dynamic_reset:
                tqdm.write('re-initialize share GAN')
                del cycle_gan
                cycle_gan = CycleGANModel(opt)
                cycle_gan.setup(opt)
                # NOTE(review): cycle_controller was .set() with the previous
                # cycle_gan; if set() keeps that reference it is now stale --
                # confirm against CycleControllerModel.set.

            save_checkpoint(
                {
                    'cur_stage': cur_stage,
                    'search_iter': search_iter + 1,
                    'cycle_gan': cycle_gan.save_networks(epoch=search_iter),
                    'cycle_controller':
                    cycle_controller.save_networks(epoch=search_iter),
                    'path_helper': opt.path_helper
                }, False, opt.path_helper['ckpt_path'])
    finally:
        # Flush buffered TensorBoard events even if the search fails.
        writer.close()

    final_archs_A, _ = cycle_controller.get_topk_arch_hidden_A()
    final_archs_B, _ = cycle_controller.get_topk_arch_hidden_B()
    print(f"discovered archs (A): {final_archs_A}")
    print(f"discovered archs (B): {final_archs_B}")
Exemplo n.º 4
0
 # Fragment of a test/inference script: configures `opt` for single-image,
 # deterministic evaluation, builds a pretrained CycleGAN "teacher" and a
 # CycleGANModelWithDistillation "student", and creates the results webpage.
 # hard-code some parameters for test
 # NOTE(review): the test code is meant to use a single data-loading thread; 0
 # presumably means loading in the main process -- confirm against create_dataset.
 opt.num_threads = 0  # test code only supports num_threads = 1
 opt.batch_size = 1  # test code only supports batch_size = 1
 opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
 opt.no_flip = True  # no flip; comment this line if results on flipped images are needed.
 opt.display_id = -1  # no visdom display; the test code saves the results to a HTML file.
 dataset = create_dataset(
     opt)  # create a dataset given opt.dataset_mode and other options
 # opt2 is a separate Namespace copied from opt at this point, so the later
 # changes to opt (continue_train, netG, results_dir) do not reach the
 # teacher's options. The copy is shallow: mutable values are still shared.
 opt2 = Namespace(**vars(opt))
 opt2.name = 'monet2photo_pretrained'
 opt2.isTrain = False
 teacher = CycleGANModel(opt2)
 teacher.isTeacher = True
 # Both models are set up with continue_train=True; presumably this makes
 # setup() load saved weights rather than initialize fresh ones -- confirm.
 opt.continue_train = True
 opt2.continue_train = True
 teacher.setup(
     opt2)  # regular setup: load and print networks; create schedulers
 # The student uses the 'resnet_3blocks' generator; the teacher keeps opt's original netG.
 opt.netG = 'resnet_3blocks'
 opt.results_dir = 'results'
 # Create the distillation student, which receives the teacher model.
 model = CycleGANModelWithDistillation(
     opt, teacher)  # create a model given opt.model and other options
 model.setup(
     opt)  # regular setup: load and print networks; create schedulers
 # create a website
 web_dir = os.path.join(opt.results_dir, opt.name, '{}_{}'.format(
     opt.phase, opt.epoch))  # define the website directory
 if opt.load_iter > 0:  # load_iter is 0 by default
     web_dir = '{:s}_iter{:d}'.format(web_dir, opt.load_iter)
 print('creating web directory', web_dir)
 webpage = html.HTML(
     web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
     (opt.name, opt.phase, opt.epoch))
Exemplo n.º 5
0
"""
import time
from options.train_options import TrainOptions
from data import create_dataset
from models.cycle_gan_model import CycleGANModel
from util.util import create_log_txt, print_current_losses

if __name__ == '__main__':
    opt = TrainOptions().parse()  # get training options
    dataset = create_dataset(
        opt)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)  # get the number of images in the dataset.

    model = CycleGANModel(
        opt)  # create a model given opt.model and other options
    model.setup(
        opt)  # regular setup: load and print networks; create schedulers
    log_file = create_log_txt(opt)  # create a log file for training progress
    total_iters = 0  # the total number of training iterations

    for epoch in range(
            opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1
    ):  # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()  # timer for data loading per iteration
        epoch_iter = 0  # the number of training iterations in current epoch, reset to 0 every epoch
        model.update_learning_rate(
        )  # update learning rates in the beginning of every epoch.
        for i, data in enumerate(dataset):  # inner loop within one epoch
            iter_start_time = time.time(
            )  # timer for computation per iteration
            if total_iters % opt.print_freq == 0: