import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pygame
import torch
from skimage import color
from torchvision import transforms

# Project-specific imports; module paths assume the pytorch-CycleGAN-and-pix2pix layout.
from options.test_options import TestOptions
from models.cycle_gan_model import CycleGANModel


class DrawingCanvas:
    def __init__(self):
        opt = TestOptions().parse()  # get test options
        # init pygame
        pygame.init()
        self.size = (256, 256)
        self.screen = pygame.display.set_mode(self.size)
        self.font = pygame.font.SysFont(pygame.font.get_fonts()[0], 64)
        self.time = pygame.time.get_ticks()
        #self.surface_test = pygame.surfarray.make_surface()
        self.screen.fill(pygame.Color(255, 255, 255))
        pygame.display.flip()

        self.model = CycleGANModel(opt)
        self.model.setup(opt)
        #norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
        #net = ResnetGenerator(256, 256, 64, norm_layer=norm_layer, use_dropout=False, n_blocks=9)
        #self.net = init_net(net, 'normal', 0.02, [])
        impath = os.getcwd() + "/datasets/bird/testA/514.png"

        image = pygame.image.load(impath)
        #self.screen.blit(image, (0, 0))

    """
    Method 'game_loop' will be executed every frame to drive
    the display and handling of events in the background. 
    In Processing this is done behind the screen. Don't 
    change this, unless you know what you are doing.
    """

    def game_loop(self):
        current_time = pygame.time.get_ticks()
        delta_time = current_time - self.time
        self.time = current_time
        self.handle_events()
        self.update_game(delta_time)
        self.draw_components()

    """
    Method 'update_game' is there to update the state of variables 
    and objects from frame to frame.
    """

    def update_game(self, dt):
        pass

    """
    Method 'draw_components' is similar is meant to contain 
    everything that draws one frame. It is similar to method
    void draw() in Processing. Put all draw calls here. Leave all
    updates in method 'update'
    """

    def draw_components(self):
        #self.screen.fill([255, 255, 255])
        #pygame.display.flip()
        pass

    def reset(self):
        pass

    """
    Method 'handle_event' loop over all the event types and 
    handles them accordingly. 
    In Processing this is done behind the screen. Don't 
    change this, unless you know what you are doing.
    """

    def handle_events(self):
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                sys.exit()
            if event.type == pygame.KEYDOWN:
                self.handle_key_down(event)
            if event.type == pygame.KEYUP:
                self.handle_key_up(event)
            if event.type == pygame.MOUSEMOTION:
                self.handle_mouse_motion(event)
            if event.type == pygame.MOUSEBUTTONDOWN:
                self.handle_mouse_pressed(event)
            if event.type == pygame.MOUSEBUTTONUP:
                self.handle_mouse_released(event)

    """
    This method will store a currently pressed buttons 
    in list 'keyboard_handler.pressed'.
    """

    def handle_key_down(self, event):
        pass

    """
    This method will remove a released button 
    from list 'keyboard_handler.pressed'.
    """

    def handle_key_up(self, event):
        pass

    """
    Similar to void mouseMoved() in Processing
    """

    def handle_mouse_motion(self, event):
        # While the left mouse button is held, draw a small black dot at the
        # cursor position and update only that region of the display.
        if pygame.mouse.get_pressed()[0]:
            pos = pygame.mouse.get_pos()
            dirty_rect = pygame.draw.ellipse(self.screen, (0, 0, 0), [pos, [5, 5]])
            pygame.display.update(dirty_rect)

    """
    Similar to void mousePressed() in Processing
    """

    def handle_mouse_pressed(self, event):
        # Draw a small black square where the mouse was pressed and update
        # only that region of the display.
        pos = pygame.mouse.get_pos()
        dirty_rect = pygame.draw.rect(self.screen, (0, 0, 0), [pos, [5, 5]])
        pygame.display.update(dirty_rect)

    """
    Similar to void mouseReleased() in Processing
    """

    def handle_mouse_released(self, event):
        # Grab the current canvas as a (width, height, 3) array and transpose
        # it to (height, width, 3) so it matches the usual image layout.
        test = pygame.surfarray.array3d(self.screen)
        print(test.shape)
        test = test.transpose(1, 0, 2)
        print(test.shape)

        # Convert the array to a normalized tensor and add a batch dimension.
        compose = transforms.Compose([
            transforms.ToPILImage(),
            #transforms.Resize(256, interpolation=Image.CUBIC),
            transforms.ToTensor()
        ])
        test_tensor = compose(test).unsqueeze(0)
        print(test_tensor.size())

        # Run the generator on the drawing and fetch the translated image.
        self.model.set_input(test_tensor)
        self.model.forward()
        result = self.model.get_generated()
        print("Result", result)

        # Clip negative values and convert the output tensor to a PIL image.
        resultT = result.squeeze(0)
        resultT[resultT < 0] = 0
        im = transforms.ToPILImage()(resultT).convert("RGB")

        # Keep an (H, W, C) numpy copy of the generator output as well.
        result = result.detach().numpy()
        result = np.squeeze(result, axis=0)
        result = result.transpose(1, 2, 0)
        print(result.shape)

        print(im.size)
        plt.imshow(im)
        plt.show()

    def lab2rgb(self, L, AB):
        """Convert an Lab tensor image to a RGB numpy output
        Parameters:
            L  (1-channel tensor array): L channel images (range: [-1, 1], torch tensor array)
            AB (2-channel tensor array):  ab channel images (range: [-1, 1], torch tensor array)
        Returns:
            rgb (RGB numpy image): rgb output images  (range: [0, 255], numpy array)
        """
        AB2 = AB * 110.0
        L2 = (L + 1.0) * 50.0
        Lab = torch.cat([L2, AB2], dim=1)
        Lab = Lab[0].data.cpu().float().numpy()
        Lab = np.transpose(Lab.astype(np.float64), (1, 2, 0))
        rgb = color.lab2rgb(Lab) * 255
        return rgb
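
A driver loop is not included in the snippet above; a minimal sketch, assuming the DrawingCanvas class as defined and a standard pygame clock, could look like this (the 60 FPS cap and the __main__ guard are illustrative, not part of the original code):

if __name__ == "__main__":
    # Illustrative driver: create the canvas and call game_loop() once per
    # frame, capped at roughly 60 FPS (an arbitrary choice).
    canvas = DrawingCanvas()
    clock = pygame.time.Clock()
    while True:
        canvas.game_loop()
        clock.tick(60)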
Example 2
def cyclgan_train(opt, cycle_gan: CycleGANModel, train_loader, writer_dict):
    cycle_gan.train()

    writer = writer_dict['writer']
    total_iters = 0
    t_data = 0.0

    for epoch in trange(opt.epoch_count,
                        opt.n_epochs + opt.n_epochs_decay + 1):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0

        train_steps = writer_dict['train_steps']
        for i, data in enumerate(train_loader):
            iter_start_time = time.time()

            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time

            total_iters += opt.batch_size
            epoch_iter += opt.batch_size

            cycle_gan.set_input(data)
            cycle_gan.optimize_parameters()

            if (i + 1) % opt.print_freq == 0:
                losses = cycle_gan.get_current_losses()
                t_comp = (time.time() - iter_start_time)
                message = "GAN: [Ep: %d/%d]" % (epoch, opt.n_epochs +
                                                opt.n_epochs_decay)
                message += "[Batch: %d/%d][time: %.3f][data: %.3f]" % (
                    epoch_iter, len(train_loader), t_comp, t_data)
                for k, v in losses.items():
                    message += '[%s: %.3f]' % (k, v)
                tqdm.write(message)

            if (total_iters + 1) % opt.display_freq == 0:
                cycle_gan.compute_visuals()
                save_current_results(opt, cycle_gan.get_current_visuals(),
                                     train_steps)

            if (total_iters + 1) % opt.save_latest_freq == 0:
                tqdm.write(
                    'saving the latest model (epoch %d, total_iters %d)' %
                    (epoch, total_iters))
                save_suffix = 'latest'
                cycle_gan.save_networks(save_suffix)

            iter_data_time = time.time()

        if (epoch + 1) % opt.save_epoch_freq == 0:
            cycle_gan.save_networks('latest')
            cycle_gan.save_networks(epoch)

        tqdm.write('End of epoch %d / %d \t Time Taken: %d sec' %
                   (epoch, opt.n_epochs + opt.n_epochs_decay,
                    time.time() - epoch_start_time))

        writer.add_scalars('Train/discriminator', {
            "A": float(cycle_gan.loss_D_A),
            "B": float(cycle_gan.loss_D_B),
        }, train_steps)
        writer.add_scalars('Train/generator', {
            "A": float(cycle_gan.loss_G_A),
            "B": float(cycle_gan.loss_G_B),
        }, train_steps)
        writer.add_scalars(
            'Train/cycle', {
                "A": float(cycle_gan.loss_cycle_A),
                "B": float(cycle_gan.loss_cycle_B),
            }, train_steps)
        writer.add_scalars('Train/idt', {
            "A": float(cycle_gan.loss_idt_A),
            "B": float(cycle_gan.loss_idt_B),
        }, train_steps)

        writer_dict['train_steps'] += 1
        cycle_gan.update_learning_rate()
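
cyclgan_train only touches two keys of writer_dict: 'writer' (a TensorBoard SummaryWriter) and 'train_steps' (an integer step counter). A minimal sketch of the call, assuming torch.utils.tensorboard and already constructed opt, cycle_gan and train_loader objects (the log directory name is arbitrary):

from torch.utils.tensorboard import SummaryWriter

# Illustrative setup, not from the original example: only the 'writer' and
# 'train_steps' keys are read and updated by cyclgan_train above.
writer_dict = {
    'writer': SummaryWriter(log_dir='logs/cyclegan'),
    'train_steps': 0,
}
cyclgan_train(opt, cycle_gan, train_loader, writer_dict)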
Example 3
        web_dir = '{:s}_iter{:d}'.format(web_dir, opt.load_iter)
    print('creating web directory', web_dir)
    webpage = html.HTML(
        web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
        (opt.name, opt.phase, opt.epoch))
    #if opt.eval:
    # model.eval()
    for i, data in enumerate(dataset):
        if i > 10:
            exit()
        #if i >= opt.num_test:  # only apply our model to opt.num_test images.
        # break
        #model.set_input(data)  # unpack data from data loader
        #model.test()         # run inference

        teacher.set_input(data)
        teacher.test()
        model.set_input(data)  # unpack data from dataset and apply preprocessing
        model.optimize_parameters()

        model.compute_visuals()
        visuals = model.get_current_visuals()  # get image results
        img_path = model.get_image_paths()  # get image paths
        print(img_path)
        #save_images(webpage, visuals, img_path)

        #visuals = model.get_current_visuals()  # get image results
        #img_path = model.get_image_paths()     # get image paths
        #if i % 5 == 0:  # save images to an HTML file
        #  print('processing (%04d)-th image... %s' % (i, img_path))
Example 4
    win_sparse_C = viz.line(X=torch.zeros((1, )),
                            Y=torch.zeros((1, )),
                            name="win_sparse_C")

    for epoch in range(1, opt.epoch + 1):
        epoch_start_time = time.time()
        epoch_iter = 0

        for i, data in enumerate(dataset):

            visualizer.reset()
            total_steps += 1
            epoch_iter += 1

            model.set_input(data)
            model.optimize_parameters()

            img_path = model.get_image_paths()
            short_path = ntpath.basename(''.join(img_path))
            # Decide which image to display at the end of each epoch.
            if epoch_iter == opt.display_num:
                visualizer.display_current_results(model.get_current_visuals(),
                                                   epoch, short_path, True)
            # Compute the loss values.
            errors, sparse_c_loss = model.get_current_errors()
            visualizer.print_current_errors(epoch, epoch_iter, short_path,
Example 5
def cyclgan_train(opt, cycle_gan: CycleGANModel,
                  cycle_controller: CycleControllerModel, train_loader,
                  g_loss_history: RunningStats, d_loss_history: RunningStats,
                  writer_dict):
    cycle_gan.train()
    cycle_controller.eval()

    dynamic_reset = False
    writer = writer_dict['writer']
    total_iters = 0
    t_data = 0.0

    for epoch in range(opt.shared_epoch):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0

        train_steps = writer_dict['train_steps']
        for i, data in enumerate(train_loader):
            iter_start_time = time.time()

            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time

            total_iters += opt.batch_size
            epoch_iter += opt.batch_size

            cycle_controller.forward()

            cycle_gan.set_input(data)
            cycle_gan.optimize_parameters()

            g_loss_history.push(cycle_gan.loss_G.item())
            d_loss_history.push(cycle_gan.loss_D_A.item() +
                                cycle_gan.loss_D_B.item())

            if (i + 1) % opt.print_freq == 0:
                losses = cycle_gan.get_current_losses()
                t_comp = (time.time() - iter_start_time)
                message = "GAN: [Ep: %d/%d]" % (epoch, opt.shared_epoch)
                message += "[Batch: %d/%d][time: %.3f][data: %.3f]" % (
                    epoch_iter, len(train_loader), t_comp, t_data)
                for k, v in losses.items():
                    message += '[%s: %.3f]' % (k, v)
                tqdm.write(message)

            if (total_iters + 1) % opt.display_freq == 0:
                cycle_gan.compute_visuals()
                save_current_results(opt, cycle_gan.get_current_visuals(),
                                     train_steps)

            if g_loss_history.is_full():
                if g_loss_history.get_var() < opt.dynamic_reset_threshold \
                        or d_loss_history.get_var() < opt.dynamic_reset_threshold:
                    dynamic_reset = True
                    tqdm.write("=> dynamic resetting triggered")
                    g_loss_history.clear()
                    d_loss_history.clear()
                    return dynamic_reset

            # cache our latest model every <save_latest_freq> iterations
            if (total_iters + 1) % opt.save_latest_freq == 0:
                tqdm.write(
                    'saving the latest model (epoch %d, total_iters %d)' %
                    (epoch, total_iters))
                save_suffix = 'latest'
                # cycle_gan.save_networks(train_steps)

            iter_data_time = time.time()

        if (epoch + 1) % opt.save_epoch_freq == 0:
            cycle_gan.save_networks('latest')
            # cycle_gan.save_networks(train_steps)

        tqdm.write('End of epoch %d / %d \t Time Taken: %d sec' %
                   (epoch, opt.shared_epoch, time.time() - epoch_start_time))

        writer.add_scalars('Train/discriminator', {
            "A": float(cycle_gan.loss_D_A),
            "B": float(cycle_gan.loss_D_B),
        }, train_steps)
        writer.add_scalars('Train/generator', {
            "A": float(cycle_gan.loss_G_A),
            "B": float(cycle_gan.loss_G_B),
        }, train_steps)
        writer.add_scalars(
            'Train/cycle', {
                "A": float(cycle_gan.loss_cycle_A),
                "B": float(cycle_gan.loss_cycle_B),
            }, train_steps)
        writer.add_scalars('Train/idt', {
            "A": float(cycle_gan.loss_idt_A),
            "B": float(cycle_gan.loss_idt_B),
        }, train_steps)

        writer_dict['train_steps'] += 1

    return dynamic_reset
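
The controller variant above relies on four RunningStats methods only: push, is_full, get_var and clear. A minimal windowed implementation satisfying that interface might look like the sketch below; the deque-based window and its default size are assumptions, not the project's actual class:

from collections import deque
import statistics

class RunningStats:
    """Sketch of the interface used above: variance over a fixed-size loss window."""

    def __init__(self, window_size=50):  # window size is an arbitrary choice
        self.window_size = window_size
        self.values = deque(maxlen=window_size)

    def push(self, value):
        self.values.append(value)

    def is_full(self):
        return len(self.values) == self.window_size

    def get_var(self):
        return statistics.pvariance(self.values)  # population variance of the window

    def clear(self):
        self.values.clear()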
Example 6
    # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
    for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()  # timer for data loading per iteration
        epoch_iter = 0  # number of training iterations in the current epoch, reset to 0 every epoch
        model.update_learning_rate()  # update learning rates at the beginning of every epoch
        for i, data in enumerate(dataset):  # inner loop within one epoch
            iter_start_time = time.time()  # timer for computation per iteration
            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time

            total_iters += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(data)  # unpack data from dataset and apply preprocessing
            model.optimize_parameters()  # calculate loss functions, get gradients, update network weights

            if total_iters % opt.print_freq == 0:  # print training losses and save logging information to the disk
                losses = model.get_current_losses()
                t_comp = (time.time() - iter_start_time) / opt.batch_size
                print_current_losses(log_file, epoch, epoch_iter, losses,
                                     t_comp, t_data)

            if total_iters % opt.save_latest_freq == 0:  # cache our latest model every <save_latest_freq> iterations
                print('saving the latest model (epoch %d, total_iters %d)' %
                      (epoch, total_iters))
                save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
                model.save_networks(save_suffix)
Example 7
    opt.batch_size = 1  # test code only supports batch_size = 1
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.no_flip = True  # no flip; comment this line if results on flipped images are needed.
    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    model = CycleGANModel(opt)  # create a model given opt.model and other options
    model.setup(opt)  # regular setup: load and print networks; create schedulers
    # create results dir
    image_dir = create_results_dir(opt)
    # test with eval mode. This only affects layers like batchnorm and dropout.
    # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout.
    if opt.eval:
        model.eval()
    for i, data in enumerate(dataset):
        if i >= opt.num_test:  # only apply our model to opt.num_test images.
            break
        model.set_input(data)  # unpack data from data loader
        model.test()  # run inference
        visuals = model.get_current_visuals()  # get image results
        img_path = model.get_image_paths()  # get image paths
        if i % 5 == 0:  # print progress every 5 images
            print('processing (%04d)-th image... %s' % (i, img_path))
        save_images(opt,
                    image_dir,
                    visuals,
                    img_path,
                    aspect_ratio=opt.aspect_ratio,
                    width=opt.display_winsize)