Exemplo n.º 1
0
def main():
    """Build a CycleGAN model, apply ``opt.arch`` to it and run ``cyclgan_train``.

    The parsed ``ArchTrainOptions`` drive everything: the log directory, the
    dataset, the model setup and the architecture passed to ``set_arch``.
    """
    options = ArchTrainOptions().parse()
    torch.cuda.manual_seed(12345)

    # Log/checkpoint paths are stored on the options object for later use.
    options.path_helper = set_log_dir(options.checkpoints_dir, options.name)

    # The dataset is created from the options (dataset mode etc.).
    train_data = create_dataset(options)
    n_images = len(train_data)
    print('The number of training images = %d' % n_images)

    model = CycleGANModel(options)
    model.setup(options)
    model.set_arch(options.arch, options.n_resnet - 1)

    # Shared logging state: the summary writer plus a running step counter.
    summary_writer = SummaryWriter(options.path_helper['log_path'])
    tb_state = dict(writer=summary_writer, train_steps=0)

    # Disabled: dump the visuals of a forward pass over the dataset.
    # for i, data in tqdm(enumerate(train_data)):
    #     model.set_input(data)
    #     model.forward()
    #     model.compute_visuals()
    #     save_current_results(options, model.get_current_visuals(), i)

    cyclgan_train(options, model, train_data, tb_state)
    def __init__(self):
        """Set up the pygame window and build the CycleGAN model from test options."""
        opt = TestOptions().parse()  # get test options
        # init pygame
        pygame.init()
        self.size = (256, 256)  # passed to set_mode() as the window size
        self.screen = pygame.display.set_mode(self.size)
        # Use the first font name pygame reports, at size 64.
        # NOTE(review): get_fonts()[0] raises IndexError if no fonts are found.
        self.font = pygame.font.SysFont(pygame.font.get_fonts()[0], 64)
        self.time = pygame.time.get_ticks()  # tick count recorded at construction
        #self.surface_test = pygame.surfarray.make_surface()
        # Fill the window with white and present it.
        self.screen.fill(pygame.Color(255, 255, 255))
        pygame.display.flip()

        # Build the model and run its setup with the parsed options.
        self.model = CycleGANModel(opt)
        self.model.setup(opt)
        #norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
        #net = ResnetGenerator(256, 256, 64, norm_layer=norm_layer, use_dropout=False, n_blocks=9)
        #self.net = init_net(net, 'normal', 0.02, [])
        # NOTE(review): hard-coded sample image, resolved against the current
        # working directory; it is only found when run from the project root.
        impath = os.getcwd() + "/datasets/bird/testA/514.png"

        # NOTE(review): the remainder of this method is not visible here, so
        # how `image` is used cannot be confirmed.
        image = pygame.image.load(impath)
Exemplo n.º 3
0
from models.cycle_gan_with_distillation import CycleGANModelWithDistillation
from models.cycle_gan_model import CycleGANModel
# Script entry point: builds a teacher CycleGANModel and a
# CycleGANModelWithDistillation that receives it, then derives the results
# directory used for the web page.
if __name__ == '__main__':
    opt = TrainOptions().parse()  # parse TrainOptions (not test options)
    # hard-code some parameters, as a test script would
    opt.num_threads = 0  # force num_threads to 0
    opt.batch_size = 1  # force batch_size to 1
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.no_flip = True  # no flip; comment this line if results on flipped images are needed.
    opt.display_id = -1  # no visdom display
    dataset = create_dataset(
        opt)  # create a dataset given opt.dataset_mode and other options
    # Shallow copy of the options so the teacher can get its own name/isTrain.
    opt2 = Namespace(**vars(opt))
    opt2.name = 'monet2photo_pretrained'
    opt2.isTrain = False
    teacher = CycleGANModel(opt2)
    teacher.isTeacher = True
    # continue_train is switched on for both option sets before either setup()
    # call. NOTE(review): presumably this makes setup() load saved networks;
    # confirm in the model's setup().
    opt.continue_train = True
    opt2.continue_train = True
    teacher.setup(
        opt2)  # regular setup: load and print networks; create schedulers
    # Overrides applied only to `opt` (opt2 was copied before this point).
    opt.netG = 'resnet_3blocks'
    opt.results_dir = 'results'
    model = CycleGANModelWithDistillation(
        opt, teacher)  # distillation model built from opt and the teacher
    model.setup(
        opt)  # regular setup: load and print networks; create schedulers
    # create a website
    web_dir = os.path.join(opt.results_dir, opt.name, '{}_{}'.format(
        opt.phase, opt.epoch))  # <results_dir>/<name>/<phase>_<epoch>
    if opt.load_iter > 0:  # load_iter is 0 by default
Exemplo n.º 4
0
import numpy as np

from visdom import Visdom
# NOTE(review): everything below runs at import time, outside the
# `if __name__ == '__main__'` guard further down.
viz = Visdom()

# Abort early when the Visdom server cannot be reached.
# NOTE(review): an assert is skipped under `python -O`; consider an explicit
# check that raises instead.
assert viz.check_connection()
# NOTE(review): close() without arguments presumably clears existing windows;
# confirm against the Visdom API.
viz.close()

# Parse the training options and persist them via save_opt().
opt = TrainOptions().parse()
save_opt(opt)

# Data pipeline: the size is taken from the loader, not from `dataset`.
data_loader = DataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)

# The model is constructed without arguments and configured by initialize().
model = CycleGANModel()
model.initialize(opt)
visualizer = Visualizer(opt)

# Training entry point; relies on the module-level viz/opt objects created above.
if __name__ == '__main__':

    total_steps = 0  # initialised once, before the epoch loop
    # Two lists, both starting empty.
    # NOTE(review): presumably per-step and averaged loss values for plotting;
    # confirm in the loop body.
    sparse_c_loss_points, sparse_c_loss_avr_points = [], []

    # Create the Visdom line plot with a single (0, 0) starting point.
    win_sparse_C = viz.line(X=torch.zeros((1, )),
                            Y=torch.zeros((1, )),
                            name="win_sparse_C")

    # Epochs are numbered from 1 to opt.epoch inclusive.
    for epoch in range(1, opt.epoch + 1):
        epoch_start_time = time.time()  # wall-clock start of this epoch
        epoch_iter = 0  # reset at the start of every epoch
Exemplo n.º 5
0
def main():
    """Run the architecture-search loop for the CycleGAN controller.

    Parses ``SearchOptions``, prepares the Inception graph, derives the grow
    schedule from ``opt.grow_step``, ``opt.max_skip_num`` and ``opt.n_resnet``,
    optionally resumes from ``opt.load_path``, and then alternates
    ``cyclgan_train`` and ``controller_train`` for ``opt.max_search_iter``
    iterations. A checkpoint is saved after every iteration and the top-k
    architectures of both directions are printed at the end.

    Raises:
        FileNotFoundError: if ``opt.load_path`` is set but the directory or
            its ``Model/checkpoint.pth`` file does not exist.
    """
    from itertools import accumulate

    opt = SearchOptions().parse()
    torch.cuda.manual_seed(12345)

    # Inception graph set-up (presumably for an Inception-based metric used
    # during controller training; the consumers are not visible here).
    _init_inception(MODEL_DIR)
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    start_search_iter = 0
    cur_stage = 1

    # Number of search iterations spent in each stage: grow_step ** i for the
    # first stages, then a constant grow_step ** 3 for the remaining ones.
    delta_grow_steps = (
        [int(opt.grow_step ** i) for i in range(1, opt.max_skip_num)]
        + [int(opt.grow_step ** 3)
           for _ in range(1, opt.n_resnet - opt.max_skip_num + 1)]
    )
    opt.max_search_iter = sum(delta_grow_steps)

    # Iterations at which the controller grows: the running totals of the
    # deltas, without the final total (that is the end of the search).
    grow_steps = list(accumulate(delta_grow_steps))[:-1]
    grow_step_lookup = frozenset(grow_steps)

    grow_ctrler = GrowCtrler(opt.grow_step, steps=grow_steps)

    checkpoint = None
    if opt.load_path:
        print(f'=> resuming from {opt.load_path}')
        # Explicit checks instead of asserts so they survive `python -O`.
        if not os.path.exists(opt.load_path):
            raise FileNotFoundError(
                f'load_path does not exist: {opt.load_path}')
        checkpoint_file = os.path.join(opt.load_path, 'Model',
                                       'checkpoint.pth')
        if not os.path.exists(checkpoint_file):
            raise FileNotFoundError(
                f'checkpoint file not found: {checkpoint_file}')
        # NOTE(review): torch.load unpickles the file; only resume from
        # checkpoints you trust.
        checkpoint = torch.load(checkpoint_file,
                                map_location={'cuda:0': 'cpu'})
        # Restore the search position and the original log/checkpoint paths.
        cur_stage = checkpoint['cur_stage']
        start_search_iter = checkpoint["search_iter"]
        opt.path_helper = checkpoint['path_helper']
    else:
        opt.path_helper = set_log_dir(opt.checkpoints_dir, opt.name)

    # Model construction is the same for a fresh run and a resumed one; it
    # must happen after opt.path_helper and cur_stage are settled.
    cycle_gan = CycleGANModel(opt)
    cycle_gan.setup(opt)

    cycle_controller = CycleControllerModel(opt, cur_stage=cur_stage)
    cycle_controller.setup(opt)
    cycle_controller.set(cycle_gan)

    if checkpoint is not None:
        cycle_gan.load_from_state(checkpoint["cycle_gan"])
        cycle_controller.load_from_state(checkpoint["cycle_controller"])

    dataset = create_dataset(
        opt)  # create a dataset given opt.dataset_mode and other options
    print('The number of training images = %d' % len(dataset))

    # Logging state; the step counters continue from the resumed iteration.
    writer_dict = {
        "writer": SummaryWriter(opt.path_helper['log_path']),
        'controller_steps': start_search_iter * opt.ctrl_step,
        'train_steps': start_search_iter * opt.shared_epoch
    }

    g_loss_history = RunningStats(opt.dynamic_reset_window)
    d_loss_history = RunningStats(opt.dynamic_reset_window)

    for search_iter in tqdm(
            range(int(start_search_iter), int(opt.max_search_iter))):
        tqdm.write(f"<start search iteration {search_iter}>")
        cycle_controller.reset()

        if search_iter in grow_step_lookup:
            # Move to the next stage: a new controller is built and seeded
            # with the top-k architectures/hiddens of the previous one.
            cur_stage = grow_ctrler.cur_stage(search_iter) + 1
            tqdm.write(f'=> grow to stage {cur_stage}')
            prev_archs_A, prev_hiddens_A = (
                cycle_controller.get_topk_arch_hidden_A())
            prev_archs_B, prev_hiddens_B = (
                cycle_controller.get_topk_arch_hidden_B())

            del cycle_controller

            cycle_controller = CycleControllerModel(opt, cur_stage)
            cycle_controller.setup(opt)
            cycle_controller.set(cycle_gan, prev_hiddens_A, prev_hiddens_B,
                                 prev_archs_A, prev_archs_B)

        dynamic_reset = cyclgan_train(opt, cycle_gan, cycle_controller,
                                      dataset, g_loss_history, d_loss_history,
                                      writer_dict)

        controller_train(opt, cycle_gan, cycle_controller, writer_dict)

        if dynamic_reset:
            tqdm.write('re-initialize share GAN')
            del cycle_gan
            cycle_gan = CycleGANModel(opt)
            cycle_gan.setup(opt)
            # NOTE(review): cycle_controller.set() is not called again here,
            # so the controller may still reference the discarded model;
            # confirm whether set() stores the GAN.

        # 'search_iter' stores the NEXT iteration so a resume starts after
        # the one that just finished.
        save_checkpoint(
            {
                'cur_stage': cur_stage,
                'search_iter': search_iter + 1,
                'cycle_gan': cycle_gan.save_networks(epoch=search_iter),
                'cycle_controller':
                cycle_controller.save_networks(epoch=search_iter),
                'path_helper': opt.path_helper
            }, False, opt.path_helper['ckpt_path'])

    final_archs_A, _ = cycle_controller.get_topk_arch_hidden_A()
    final_archs_B, _ = cycle_controller.get_topk_arch_hidden_B()
    print(f"discovered archs: {final_archs_A}")
    print(f"discovered archs: {final_archs_B}")
Exemplo n.º 6
0
See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
"""
import time
from options.train_options import TrainOptions
from data import create_dataset
from models.cycle_gan_model import CycleGANModel
from util.util import create_log_txt, print_current_losses

# Training entry point: parse options, build dataset and model, then loop over
# epochs (the inner per-batch loop follows below).
if __name__ == '__main__':
    opt = TrainOptions().parse()  # get training options
    dataset = create_dataset(
        opt)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)  # get the number of images in the dataset.

    model = CycleGANModel(
        opt)  # the model class is fixed to CycleGANModel here
    model.setup(
        opt)  # regular setup: load and print networks; create schedulers
    log_file = create_log_txt(opt)  # create a log file for training progress
    total_iters = 0  # the total number of training iterations

    # Epochs run from opt.epoch_count to n_epochs + n_epochs_decay inclusive.
    for epoch in range(
            opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1
    ):  # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()  # timer for data loading per iteration
        epoch_iter = 0  # the number of training iterations in current epoch, reset to 0 every epoch
        model.update_learning_rate(
        )  # update learning rates in the beginning of every epoch.
        for i, data in enumerate(dataset):  # inner loop within one epoch
            iter_start_time = time.time(
Exemplo n.º 7
0
    ngf=64,
    no_dropout=True,
    no_flip=True,
    norm="instance",
    ntest=float("inf"),
    num_test=100,
    num_threads=0,
    output_nc=3,
    phase="test",
    preprocess="no_preprocessing",
    results_dir="./results/",
    serial_batches=True,
    suffix="",
    verbose=False,
)
# Keep only the netG_A sub-network of the CycleGAN model and load its weights
# from the file at model_fp.
model = CycleGANModel(opt).netG_A
model.load_state_dict(torch.load(model_fp))

# Transform built from the same options; applied to images by the dataset below.
preprocess = get_transform(opt)


class SingleImageDataset(torch.utils.data.Dataset):
    """Dataset that serves one image, transformed once by ``preprocess``.

    The image is supplied through the required ``img`` keyword argument; all
    remaining arguments are forwarded to the base class.
    """

    def __init__(self, *args, **kwargs):
        """Take ``img`` out of ``kwargs``, initialise the base class, store the transformed image."""
        raw_image = kwargs.pop("img")
        super().__init__(*args, **kwargs)
        self.img = preprocess(raw_image)

    def __getitem__(self, i):
        """Return the stored image; the index ``i`` is not used."""
        return self.img