Ejemplo n.º 1
0
def main():
    """Build the GAN pair, optionally restore a checkpoint, then train or sample.

    Relies on module-level ``args`` (CLI options), ``use_cuda``, ``train`` and
    ``gen`` helpers defined elsewhere in the file.
    """
    generator = Generator(args.dim_disc + args.dim_cont)
    discriminator = Discriminator()

    # Resume from a prior checkpoint when one exists at args.model.
    # The checkpoint is a two-element list: [G state dict, D state dict].
    if os.path.isfile(args.model):
        checkpoint = torch.load(args.model)
        generator.load_state_dict(checkpoint[0])
        discriminator.load_state_dict(checkpoint[1])

    if use_cuda:
        generator.cuda()
        discriminator.cuda()

    if args.mode == "train":
        generator, discriminator = train(generator, discriminator)
        # Persist the trained weights in the same [G, D] layout we load above.
        if args.model:
            torch.save([generator.state_dict(), discriminator.state_dict()],
                       args.model,
                       pickle_protocol=4)
    elif args.mode == "gen":
        gen(generator)
Ejemplo n.º 2
0
    # NOTE(review): these indented lines belong to an `if` whose header is
    # above this chunk (presumably a "not dry-run" check, given the `else`
    # below) — confirm against the full file.
    os.makedirs(args.save_image_dir)
    os.makedirs(args.tensorboard_dir)
    os.makedirs(args.save_model_dir)
    WRITER = SummaryWriter(args.tensorboard_dir)  # Set up TensorBoard.
else:
    print('Dry run! Just for testing, data is not saved')

# Prefer the GPU when one is available.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Set up the GAN.
discriminator_model = Discriminator().to(DEVICE)
generator_model = Generator().to(DEVICE)

# Load pre-trained models if they are provided.
if args.load_discriminator_model_path:
    discriminator_model.load_state_dict(
        torch.load(args.load_discriminator_model_path))

if args.load_generator_model_path:
    generator_model.load_state_dict(torch.load(args.load_generator_model_path))

# Set up Adam optimizers for both models.
# betas=(0, 0.9) disables the first-moment momentum term; both models share
# the same learning rate from the CLI.
discriminator_optimizer = optim.Adam(discriminator_model.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0, 0.9))
generator_optimizer = optim.Adam(generator_model.parameters(),
                                 lr=args.learning_rate,
                                 betas=(0, 0.9))

# Create a random batch of latent space vectors that will be used to visualize the progression of the generator.
# Assumes the generator takes latent input shaped (N, 512, 1, 1) — TODO confirm.
fixed_latent_space_vectors = torch.randn(64, 512, 1, 1, device=DEVICE)
Ejemplo n.º 3
0
# Build the VAE-GAN components: decoder (acts as generator), discriminator,
# and encoder, all moved to the target device.
NetG = Decoder(nc, ngf, nz).to(device)
NetD = Discriminator(imageSize, nc, ndf, nz).to(device)
NetE = Encoder(imageSize, nc, ngf, nz).to(device)
# NOTE(review): this rebinds the name `Sampler` from the class to an instance,
# shadowing the class for the rest of the module — confirm the class itself is
# never needed again after this point.
Sampler = Sampler().to(device)

# Apply custom weight initialisation to every submodule.
NetE.apply(weights_init)
NetG.apply(weights_init)
NetD.apply(weights_init)

# load weights (resume from checkpoints when paths were supplied on the CLI)
if opt.netE != '':
    NetE.load_state_dict(torch.load(opt.netE))
if opt.netG != '':
    NetG.load_state_dict(torch.load(opt.netG))
if opt.netD != '':
    NetD.load_state_dict(torch.load(opt.netD))

# RMSprop optimizers with every hyper-parameter spelled out explicitly
# (these are the torch defaults except for lr).
# NOTE(review): "encorder" is a typo for "encoder"; the name is kept as-is
# because code below this chunk may reference it.
optimizer_encorder = optim.RMSprop(params=NetE.parameters(),
                                   lr=lr,
                                   alpha=0.9,
                                   eps=1e-8,
                                   weight_decay=0,
                                   momentum=0,
                                   centered=False)
optimizer_decoder = optim.RMSprop(params=NetG.parameters(),
                                  lr=lr,
                                  alpha=0.9,
                                  eps=1e-8,
                                  weight_decay=0,
                                  momentum=0,
                                  centered=False)
Ejemplo n.º 4
0
# training loop
print('training...')
writer = SummaryWriter()  # TensorBoard writer with the default log directory
global_step = 0
start_epoch = 0
timestamp = time.time()
# check resume

if args.resume:
	# Template path: the '{}' placeholder is filled with each component's
	# name below (G, Ds, Dc, their optimizers, and the trainer state).
	save_path = os.path.join(
		args.save_path,
		'latest_{}.pth'
	)

	G.load_state_dict(torch.load(save_path.format('G')))
	Ds.load_state_dict(torch.load(save_path.format('Ds')))
	Dc.load_state_dict(torch.load(save_path.format('Dc')))
	optimG.load_state_dict(torch.load(save_path.format('optimG')))
	optimDs.load_state_dict(torch.load(save_path.format('optimDs')))
	optimDc.load_state_dict(torch.load(save_path.format('optimDc')))
	# The 'state' checkpoint stores the (global_step, start_epoch) pair.
	global_step, start_epoch = torch.load(save_path.format('state'))
	print('resumed from Epoch: {:04d} Step: {:07d}'.format(start_epoch, global_step))
else:
	# Fresh run: wipe any stale samples and checkpoints first.
	clear_dir(args.sample_path)
	clear_dir(args.save_path)

# NOTE(review): the loop body continues past this chunk in the full file.
for epoch in range(start_epoch, args.epochs):
	for s1, s2, s3, contour in dataloader:
		# s1 s2 are in same cluster in lab space
		# s3 contour are paired icon and it's contour
		global_step += 1
Ejemplo n.º 5
0
class Solver(object):
    """Trainer for a progressively-grown GAN with WGAN-GP style losses.

    Grows the generator/discriminator block by block; each resolution is
    trained in a "fade" phase (new block blended in via ``alpha``) followed
    by a "stabilize" phase.
    """

    def __init__(self, configuration):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # retrieve configuration variables
        self.data_path = configuration.data_path
        self.crop_size = configuration.crop_size
        self.final_size = configuration.final_size
        self.batch_size = configuration.batch_size
        self.alternating_step = configuration.alternating_step
        self.ncritic = configuration.ncritic
        self.lambda_gp = configuration.lambda_gp
        self.debug_step = configuration.debug_step
        self.save_step = configuration.save_step
        self.max_checkpoints = configuration.max_checkpoints
        self.log_step = configuration.log_step
        # self.tflogger = Logger(configuration.log_dir)
        ## directories
        self.train_dir = configuration.train_dir
        self.img_dir = configuration.img_dir
        self.models_dir = configuration.models_dir
        ## variables
        # Weight of the drift penalty keeping D's outputs near zero.
        self.eps_drift = 0.001

        self.resume_training = configuration.resume_training

        self._initialise_networks()

    def _initialise_networks(self):
        """Build G/D, their Adam optimizers, and per-resolution upsamplers."""
        self.generator = Generator(final_size=self.final_size)
        self.generator.generate_network()
        self.g_optimizer = Adam(self.generator.parameters(), lr=0.001, betas=(0, 0.99))

        self.discriminator = Discriminator(final_size=self.final_size)
        self.discriminator.generate_network()
        self.d_optimizer = Adam(self.discriminator.parameters(), lr=0.001, betas=(0, 0.99))

        self.num_channels = min(self.generator.num_channels,
                                self.generator.max_channels)
        # One upsampler per block so debug images of every resolution can be
        # blown up to the final size for side-by-side comparison.
        self.upsample = [Upsample(scale_factor=2**i)
                for i in reversed(range(self.generator.num_blocks))]

    def print_debugging_images(self, generator, latent_vectors, shape, index,
                               alpha, iteration):
        """Render a shape[0] x shape[1] grid of generator samples to img_dir."""
        with torch.no_grad():
            columns = []
            for i in range(shape[0]):
                row = []
                for j in range(shape[1]):
                    # NOTE: unsqueeze_ mutates the stored latent vector view,
                    # adding the batch dim in place (as the original code did).
                    img_ij = generator(latent_vectors[i * shape[1] +
                                                      j].unsqueeze_(0),
                                       index, alpha)
                    img_ij = self.upsample[index](img_ij)
                    row.append(img_ij)
                columns.append(torch.cat(row, dim=3))
            debugging_image = torch.cat(columns, dim=2)
        # denorm from [-1, 1] to [0, 1]
        debugging_image = (debugging_image + 1) / 2
        debugging_image.clamp_(0, 1)
        save_image(debugging_image.data,
                   os.path.join(self.img_dir, "debug_{}_{}.png".format(index,
                                                                      iteration)))

    def save_trained_networks(self, block_index, phase, step):
        """Checkpoint G/D and append the entry to models.json, pruning old ones."""
        models_file = os.path.join(self.models_dir, "models.json")
        if os.path.isfile(models_file):
            with open(models_file, 'r') as file:
                models_config = json.load(file)
        else:
            models_config = json.loads('{ "checkpoints": [] }')

        generator_save_name = "generator_{}_{}_{}.pth".format(
                                    block_index, phase, step
                                )
        torch.save(self.generator.state_dict(),
                   os.path.join(self.models_dir, generator_save_name))

        discriminator_save_name = "discriminator_{}_{}_{}.pth".format(
                                    block_index, phase, step
                                )
        torch.save(self.discriminator.state_dict(),
                   os.path.join(self.models_dir, discriminator_save_name))

        models_config["checkpoints"].append(OrderedDict({
            "block_index": block_index,
            "phase": phase,
            "step": step,
            "generator": generator_save_name,
            "discriminator": discriminator_save_name
        }))
        # Keep at most max_checkpoints on disk; drop the oldest first.
        if len(models_config["checkpoints"]) > self.max_checkpoints:
            old_save = models_config["checkpoints"][0]
            os.remove(os.path.join(self.models_dir, old_save["generator"]))
            os.remove(os.path.join(self.models_dir, old_save["discriminator"]))
            models_config["checkpoints"] = models_config["checkpoints"][1:]
        with open(os.path.join(self.models_dir, "models.json"), 'w') as file:
            json.dump(models_config, file, indent=4)

    def load_trained_networks(self):
        """Restore G/D from the newest models.json checkpoint.

        Returns:
            (block_index, phase, step) of the restored checkpoint.

        Raises:
            FileNotFoundError: when models.json does not exist.
        """
        models_file = os.path.join(self.models_dir, "models.json")
        if os.path.isfile(models_file):
            with open(models_file, 'r') as file:
                models_config = json.load(file)
        else:
            raise FileNotFoundError("File 'models.json' not found in {"
                                    "}".format(self.models_dir))

        last_checkpoint = models_config["checkpoints"][-1]
        block_index = last_checkpoint["block_index"]
        phase = last_checkpoint["phase"]
        step = last_checkpoint["step"]
        generator_save_name = os.path.join(
            self.models_dir, last_checkpoint["generator"])
        discriminator_save_name = os.path.join(
            self.models_dir, last_checkpoint["discriminator"])

        self.generator.load_state_dict(torch.load(generator_save_name))
        self.discriminator.load_state_dict(torch.load(discriminator_save_name))

        return block_index, phase, step

    def train(self):
        """Run the progressive training loop over all blocks and phases."""
        # get debugging vectors (fixed grid of latents for progress images)
        N = (5, 10)
        debug_vectors = torch.randn(N[0] * N[1], self.num_channels, 1,
                                    1).to(self.device)

        # get loader
        loader = get_loader(self.data_path, self.crop_size, self.batch_size)

        losses = {
            "d_loss_real": None,
            "d_loss_fake": None,
            "g_loss": None
        }

        # resume training if needed
        if self.resume_training:
            start_index, start_phase, start_step = self.load_trained_networks()
        else:
            start_index, start_phase, start_step = (0, "fade", 0)

        # training loop
        start_time = time.time()
        absolute_step = -1
        for index in range(start_index, self.generator.num_blocks):
            loader.dataset.set_transform_by_index(index)
            data_iterator = iter(loader)
            for phase in ('fade', 'stabilize'):
                # The very first block has nothing to fade in.
                if index == 0 and phase == 'fade': continue
                # BUG FIX: was `phase is not start_phase` — identity comparison
                # on strings is unreliable (start_phase may come from JSON);
                # compare by value instead.
                if self.resume_training and \
                        index == start_index and \
                        phase != start_phase:
                    continue
                print("index: {}, size: {}x{}, phase: {}".format(
                    index, 2 ** (index + 2), 2 ** (index + 2), phase))
                if self.resume_training and \
                        phase == start_phase and \
                        index == start_index:
                    step_range = range(start_step, self.alternating_step)
                else:
                    step_range = range(self.alternating_step)
                for i in step_range:
                    absolute_step += 1
                    # EAFP: restart the iterator when the epoch is exhausted.
                    try:
                        batch = next(data_iterator)
                    except StopIteration:  # BUG FIX: was a bare except
                        data_iterator = iter(loader)
                        batch = next(data_iterator)

                    # Blend factor for the newly added block (1.0 once stable).
                    alpha = i / self.alternating_step if phase == "fade" else 1.0

                    batch = batch.to(self.device)

                    # --- discriminator step (WGAN critic losses) ---
                    d_loss_real = - torch.mean(
                        self.discriminator(batch, index, alpha))
                    # BUG FIX: `.data[0]` raises on 0-dim tensors in modern
                    # PyTorch; `.item()` is the supported scalar accessor.
                    losses["d_loss_real"] = d_loss_real.item()

                    latent = torch.randn(
                        batch.size(0), self.num_channels, 1, 1).to(self.device)
                    fake_batch = self.generator(latent, index, alpha).detach()
                    d_loss_fake = torch.mean(
                        self.discriminator(fake_batch, index, alpha))
                    losses["d_loss_fake"] = d_loss_fake.item()

                    # drift factor keeps critic outputs from drifting far from 0
                    drift = d_loss_real.pow(2) + d_loss_fake.pow(2)

                    d_loss = d_loss_real + d_loss_fake + self.eps_drift * drift
                    self.d_optimizer.zero_grad()
                    d_loss.backward()
                    self.d_optimizer.step()

                    # Compute gradient penalty on interpolates between real
                    # and fake samples (WGAN-GP).
                    alpha_gp = torch.rand(batch.size(0), 1, 1, 1).to(self.device)
                    # x_hat must be detached from the previous graph (via
                    # fake_batch) and have requires_grad=True so the gradient
                    # w.r.t. the input can be computed.
                    x_hat = (alpha_gp * batch + (1 - alpha_gp) *
                             fake_batch).requires_grad_(True)
                    out = self.discriminator(x_hat, index, alpha)
                    grad = torch.autograd.grad(
                        outputs=out,
                        inputs=x_hat,
                        grad_outputs=torch.ones_like(out).to(self.device),
                        retain_graph=True,
                        create_graph=True,
                        only_inputs=True
                    )[0]
                    grad = grad.view(grad.size(0), -1)  # flatten per sample
                    l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
                    d_loss_gp = torch.mean((l2norm - 1) ** 2)

                    d_loss_gp *= self.lambda_gp
                    self.d_optimizer.zero_grad()
                    d_loss_gp.backward()
                    self.d_optimizer.step()

                    # train generator once every ncritic discriminator steps
                    if (i + 1) % self.ncritic == 0:
                        latent = torch.randn(
                            self.batch_size, self.num_channels, 1, 1).to(self.device)
                        fake_batch = self.generator(latent, index, alpha)
                        g_loss = - torch.mean(self.discriminator(
                                                    fake_batch, index, alpha))
                        losses["g_loss"] = g_loss.item()
                        self.g_optimizer.zero_grad()
                        g_loss.backward()
                        self.g_optimizer.step()

                    # tensorboard logging
                    if (i + 1) % self.log_step == 0:
                        elapsed = time.time() - start_time
                        elapsed = str(datetime.timedelta(seconds=elapsed))
                        print("{}:{}:{}/{} time {}, d_loss_real {}, "
                              "d_loss_fake {}, "
                              "g_loss {}, alpha {}".format(index, phase, i,
                                                           self.alternating_step,
                                                           elapsed,
                                                           d_loss_real,
                                              d_loss_fake,
                                              g_loss, alpha))
                        # BUG FIX: self.tflogger is never assigned (its setup
                        # is commented out in __init__), so guard the access
                        # instead of crashing on the first log step.
                        if hasattr(self, 'tflogger'):
                            for name, value in losses.items():
                                self.tflogger.scalar_summary(name, value, absolute_step)

                    # print debugging images
                    if (i + 1) % self.debug_step == 0:
                        self.print_debugging_images(
                            self.generator, debug_vectors, N, index, alpha, i)

                    # save trained networks
                    if (i + 1) % self.save_step == 0:
                        self.save_trained_networks(index, phase, i)
Ejemplo n.º 6
0
# Initialize weights when starting from scratch; otherwise resume all four
# networks from the checkpoints saved at epoch `opt.start_epoch`.
if opt.start_epoch == 0:
    netG_A2B.apply(weights_init)
    netG_B2A.apply(weights_init)
    netD_A.apply(weights_init)
    netD_B.apply(weights_init)
    print("training start from begining")
else:  #read trained network param
    # Checkpoint naming: <model_name>_<net>_ep<epoch>.pth under result_model_path.
    netG_A2B.load_state_dict(
        torch.load('%s/%s_netG_A2B_ep%s.pth' %
                   (result_model_path, opt.model_name, opt.start_epoch)))
    netG_B2A.load_state_dict(
        torch.load('%s/%s_netG_B2A_ep%s.pth' %
                   (result_model_path, opt.model_name, opt.start_epoch)))
    netD_A.load_state_dict(
        torch.load('%s/%s_netD_A_ep%s.pth' %
                   (result_model_path, opt.model_name, opt.start_epoch)))
    netD_B.load_state_dict(
        torch.load('%s/%s_netD_B_ep%s.pth' %
                   (result_model_path, opt.model_name, opt.start_epoch)))
    print("training start from epoch %s" % (opt.start_epoch))

# Print per-layer summaries for all four networks (assumes 3x256x256 inputs).
#print(netG_A2B,netG_B2A,netD_A,netD_B)
summary(netG_A2B, input_size=(3, 256, 256))
summary(netG_B2A, input_size=(3, 256, 256))
summary(netD_A, input_size=(3, 256, 256))
summary(netD_B, input_size=(3, 256, 256))

# Losses: MSE adversarial loss (LSGAN) and L1 cycle-consistency loss.
gan_loss = torch.nn.MSELoss()  #LSGAN
cycle_consistency_loss = torch.nn.L1Loss()
Ejemplo n.º 7
0
class tag2pix(object):
    def __init__(self, args):
        """Set up the tag2pix trainer: pick the model variant, load tag
        dictionaries and data, build G/D and their optimizers, and move
        everything to the GPU when available.

        Args:
            args: parsed CLI namespace (model, epoch, batch_size, lrG/lrD, ...).

        Raises:
            Exception: when args.model names an unknown architecture.
        """
        # Select the Generator implementation by model name.
        if args.model == 'tag2pix':
            from network import Generator
        elif args.model == 'senet':
            from model.GD_senet import Generator
        elif args.model == 'resnext':
            from model.GD_resnext import Generator
        elif args.model == 'catconv':
            from model.GD_cat_conv import Generator
        elif args.model == 'catall':
            from model.GD_cat_all import Generator
        elif args.model == 'adain':
            from model.GD_adain import Generator
        elif args.model == 'seadain':
            from model.GD_seadain import Generator
        else:
            raise Exception('invalid model name: {}'.format(args.model))

        self.args = args
        self.epoch = args.epoch
        self.batch_size = args.batch_size

        self.gpu_mode = not args.cpu
        self.input_size = args.input_size
        self.color_revert = ColorSpace2RGB(args.color_space)
        self.layers = args.layers
        [self.cit_weight, self.cvt_weight] = args.cit_cvt_weight

        # BUG FIX: was `args.load is not ""` — identity comparison with a
        # string literal is undefined behaviour (and a SyntaxWarning on
        # Python 3.8+); compare by value.
        self.load_dump = (args.load != "")

        self.load_path = Path(args.load)

        self.l1_lambda = args.l1_lambda
        self.guide_beta = args.guide_beta
        self.adv_lambda = args.adv_lambda
        self.save_freq = args.save_freq

        self.two_step_epoch = args.two_step_epoch
        self.brightness_epoch = args.brightness_epoch
        self.save_all_epoch = args.save_all_epoch

        self.iv_dict, self.cv_dict, self.id_to_name = get_tag_dict(
            args.tag_dump)

        cvt_class_num = len(self.cv_dict.keys())
        cit_class_num = len(self.iv_dict.keys())
        self.class_num = cvt_class_num + cit_class_num

        self.start_epoch = 1

        #### load dataset
        if not args.test:
            self.train_data_loader, self.test_data_loader = get_dataset(args)
            # Each training run gets a timestamped result directory.
            self.result_path = Path(args.result_dir) / time.strftime(
                '%y%m%d-%H%M%S', time.localtime())

            if not self.result_path.exists():
                self.result_path.mkdir()

            self.test_images = self.get_test_data(self.test_data_loader,
                                                  args.test_image_count)
        else:
            self.test_data_loader = get_dataset(args)
            self.result_path = Path(args.result_dir)

        ##### initialize network
        self.net_opt = {
            'guide': not args.no_guide,
            'relu': args.use_relu,
            'bn': not args.no_bn,
            'cit': not args.no_cit
        }

        # CIT feature extractor: a pretrained SE-ResNeXt when enabled,
        # otherwise an empty pass-through module.
        if self.net_opt['cit']:
            self.Pretrain_ResNeXT = se_resnext_half(
                dump_path=args.pretrain_dump,
                num_classes=cit_class_num,
                input_channels=1)
        else:
            self.Pretrain_ResNeXT = nn.Sequential()

        self.G = Generator(input_size=args.input_size,
                           layers=args.layers,
                           cv_class_num=cvt_class_num,
                           iv_class_num=cit_class_num,
                           net_opt=self.net_opt)
        self.D = Discriminator(input_dim=3,
                               output_dim=1,
                               input_size=self.input_size,
                               cv_class_num=cvt_class_num,
                               iv_class_num=cit_class_num)

        # The feature extractor is always frozen; G/D are frozen only in
        # test mode.
        for param in self.Pretrain_ResNeXT.parameters():
            param.requires_grad = False
        if args.test:
            for param in self.G.parameters():
                param.requires_grad = False
            for param in self.D.parameters():
                param.requires_grad = False

        self.Pretrain_ResNeXT = nn.DataParallel(self.Pretrain_ResNeXT)
        self.G = nn.DataParallel(self.G)
        self.D = nn.DataParallel(self.D)

        self.G_optimizer = optim.Adam(self.G.parameters(),
                                      lr=args.lrG,
                                      betas=(args.beta1, args.beta2))
        self.D_optimizer = optim.Adam(self.D.parameters(),
                                      lr=args.lrD,
                                      betas=(args.beta1, args.beta2))

        self.BCE_loss = nn.BCELoss()
        self.CE_loss = nn.CrossEntropyLoss()
        self.L1Loss = nn.L1Loss()

        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        print("gpu mode: ", self.gpu_mode)
        print("device: ", self.device)
        print(torch.cuda.device_count(), "GPUS!")

        if self.gpu_mode:
            self.Pretrain_ResNeXT.to(self.device)
            self.G.to(self.device)
            self.D.to(self.device)
            self.BCE_loss.to(self.device)
            self.CE_loss.to(self.device)
            self.L1Loss.to(self.device)

    def train(self):
        """Main training loop: alternate D and G updates over the train
        loader for epochs [start_epoch, end_epoch], tracking loss history
        and periodically saving results.
        """
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        # Real/fake target labels for the BCE adversarial loss.
        self.y_real_, self.y_fake_ = torch.ones(self.batch_size,
                                                1), torch.zeros(
                                                    self.batch_size, 1)

        if self.gpu_mode:
            self.y_real_, self.y_fake_ = self.y_real_.to(
                self.device), self.y_fake_.to(self.device)

        # NOTE(review): self.end_epoch is only assigned in the else branch;
        # the resume path relies on self.load() establishing it — confirm,
        # otherwise the range() below raises AttributeError after a resume.
        if self.load_dump:
            self.load(self.load_path)
            print("continue training!!!!")
        else:
            self.end_epoch = self.epoch

        self.print_params()

        self.D.train()
        print('training start!!')
        start_time = time.time()

        for epoch in range(self.start_epoch, self.end_epoch + 1):
            print("EPOCH: {}".format(epoch))

            self.G.train()
            epoch_start_time = time.time()

            # One-time brightness augmentation switch at a configured epoch.
            if epoch == self.brightness_epoch:
                print('changing brightness ...')
                self.train_data_loader.dataset.enhance_brightness(
                    self.input_size)

            max_iter = self.train_data_loader.dataset.__len__(
            ) // self.batch_size

            # NOTE: `iter` shadows the builtin within this loop body.
            for iter, (original_, sketch_, iv_tag_, cv_tag_) in enumerate(
                    tqdm(self.train_data_loader, ncols=80)):
                if iter >= max_iter:
                    break

                if self.gpu_mode:
                    sketch_, original_, iv_tag_, cv_tag_ = sketch_.to(
                        self.device), original_.to(self.device), iv_tag_.to(
                            self.device), cv_tag_.to(self.device)

                # update D network
                self.D_optimizer.zero_grad()

                # Frozen feature extractor: no gradients needed.
                with torch.no_grad():
                    feature_tensor = self.Pretrain_ResNeXT(sketch_)
                if self.gpu_mode:
                    feature_tensor = feature_tensor.to(self.device)

                # D on real images (adversarial + tag classification heads).
                D_real, CIT_real, CVT_real = self.D(original_)
                D_real_loss = self.BCE_loss(D_real, self.y_real_)

                # D on generated images (G output treated as fake).
                G_f, _ = self.G(sketch_, feature_tensor, cv_tag_)
                if self.gpu_mode:
                    G_f = G_f.to(self.device)

                D_f_fake, CIT_f_fake, CVT_f_fake = self.D(G_f)
                D_f_fake_loss = self.BCE_loss(D_f_fake, self.y_fake_)

                # Tag-classification losses are enabled only after the
                # two-step warmup epoch (or always when it is 0).
                if self.two_step_epoch == 0 or epoch >= self.two_step_epoch:
                    CIT_real_loss = self.BCE_loss(
                        CIT_real, iv_tag_) if self.net_opt['cit'] else 0
                    CVT_real_loss = self.BCE_loss(CVT_real, cv_tag_)

                    C_real_loss = self.cvt_weight * CVT_real_loss + self.cit_weight * CIT_real_loss

                    CIT_f_fake_loss = self.BCE_loss(
                        CIT_f_fake, iv_tag_) if self.net_opt['cit'] else 0
                    CVT_f_fake_loss = self.BCE_loss(CVT_f_fake, cv_tag_)

                    C_f_fake_loss = self.cvt_weight * CVT_f_fake_loss + self.cit_weight * CIT_f_fake_loss
                else:
                    C_real_loss = 0
                    C_f_fake_loss = 0

                D_loss = self.adv_lambda * (D_real_loss + D_f_fake_loss) + (
                    C_real_loss + C_f_fake_loss)

                self.train_hist['D_loss'].append(D_loss.item())

                D_loss.backward()
                self.D_optimizer.step()

                # update G network
                self.G_optimizer.zero_grad()

                # Fresh forward pass; G_g is the guide decoder output.
                G_f, G_g = self.G(sketch_, feature_tensor, cv_tag_)

                if self.gpu_mode:
                    G_f, G_g = G_f.to(self.device), G_g.to(self.device)

                D_f_fake, CIT_f_fake, CVT_f_fake = self.D(G_f)

                # Generator wants D to label its output as real.
                D_f_fake_loss = self.BCE_loss(D_f_fake, self.y_real_)

                if self.two_step_epoch == 0 or epoch >= self.two_step_epoch:
                    CIT_f_fake_loss = self.BCE_loss(
                        CIT_f_fake, iv_tag_) if self.net_opt['cit'] else 0
                    CVT_f_fake_loss = self.BCE_loss(CVT_f_fake, cv_tag_)

                    C_f_fake_loss = self.cvt_weight * CVT_f_fake_loss + self.cit_weight * CIT_f_fake_loss
                else:
                    C_f_fake_loss = 0

                # L1 reconstruction losses against the ground-truth image;
                # the guide decoder term is down-weighted by guide_beta.
                L1_D_f_fake_loss = self.L1Loss(G_f, original_)
                L1_D_g_fake_loss = self.L1Loss(
                    G_g, original_) if self.net_opt['guide'] else 0

                G_loss = (D_f_fake_loss + C_f_fake_loss) + \
                         (L1_D_f_fake_loss + L1_D_g_fake_loss * self.guide_beta) * self.l1_lambda

                self.train_hist['G_loss'].append(G_loss.item())

                G_loss.backward()
                self.G_optimizer.step()

                if ((iter + 1) % 100) == 0:
                    print(
                        "Epoch: [{:2d}] [{:4d}/{:4d}] D_loss: {:.8f}, G_loss: {:.8f}"
                        .format(epoch, (iter + 1), max_iter, D_loss.item(),
                                G_loss.item()))

            self.train_hist['per_epoch_time'].append(time.time() -
                                                     epoch_start_time)

            with torch.no_grad():
                self.visualize_results(epoch)
                utils.loss_plot(self.train_hist, self.result_path, epoch)

            # Save every epoch once save_all_epoch is reached, otherwise
            # only every save_freq epochs.
            if epoch >= self.save_all_epoch > 0:
                self.save(epoch)
            elif self.save_freq > 0 and epoch % self.save_freq == 0:
                self.save(epoch)

        print("Training finish!... save training results")

        # Final save if the last epoch was not already checkpointed above.
        if self.save_freq == 0 or epoch % self.save_freq != 0:
            if self.save_all_epoch <= 0 or epoch < self.save_all_epoch:
                self.save(epoch)

        self.train_hist['total_time'].append(time.time() - start_time)
        print(
            "Avg one epoch time: {:.2f}, total {} epochs time: {:.2f}".format(
                np.mean(self.train_hist['per_epoch_time']), self.epoch,
                self.train_hist['total_time'][0]))

    def test(self):
        """Run inference over the test loader and write one PNG per sample.

        Loads generator weights from args.load, then colorizes each sketch
        and saves it under result_path/<checkpoint-stem>/<index>.png,
        suffixing _0.._99 on filename collisions.
        """
        self.load_test(self.args.load)

        self.D.eval()
        self.G.eval()

        load_path = self.load_path
        result_path = self.result_path / load_path.stem

        if not result_path.exists():
            result_path.mkdir()

        with torch.no_grad():
            for sketch_, index_, _, cv_tag_ in tqdm(self.test_data_loader,
                                                    ncols=80):
                if self.gpu_mode:
                    sketch_, cv_tag_ = sketch_.to(self.device), cv_tag_.to(
                        self.device)

                with torch.no_grad():
                    feature_tensor = self.Pretrain_ResNeXT(sketch_)

                if self.gpu_mode:
                    feature_tensor = feature_tensor.to(self.device)

                # D_real, CIT_real, CVT_real = self.D(original_)
                G_f, _ = self.G(sketch_, feature_tensor, cv_tag_)
                # Convert generator output back to RGB on the CPU.
                G_f = self.color_revert(G_f.cpu())

                for ind, result in zip(index_.cpu().numpy(), G_f):
                    save_path = result_path / f'{ind}.png'
                    # Avoid overwriting: probe suffixed names until free.
                    if save_path.exists():
                        for i in range(100):
                            save_path = result_path / f'{ind}_{i}.png'
                            if not save_path.exists():
                                break
                    img = Image.fromarray(result)
                    img.save(save_path)

    def visualize_results(self, epoch, fix=True):
        """Render the fixed test batch through G and save grid images.

        Writes tag2pix_epoch{epoch}_G_f.png (main output) and ..._G_g.png
        (guide decoder output) into result_path.

        Args:
            epoch: epoch number used in the output filenames.
            fix: unused here; kept for interface compatibility.
        """
        if not self.result_path.exists():
            self.result_path.mkdir()

        self.G.eval()

        # test_data_loader
        original_, sketch_, iv_tag_, cv_tag_ = self.test_images
        # Square grid side length that fits all test images.
        image_frame_dim = int(np.ceil(np.sqrt(len(original_))))

        # iv_tag_ to feature tensor 16 * 16 * 256 by pre-reained Sketch.
        with torch.no_grad():
            feature_tensor = self.Pretrain_ResNeXT(sketch_)

            if self.gpu_mode:
                original_, sketch_, iv_tag_, cv_tag_, feature_tensor = original_.to(
                    self.device), sketch_.to(self.device), iv_tag_.to(
                        self.device), cv_tag_.to(
                            self.device), feature_tensor.to(self.device)

            G_f, G_g = self.G(sketch_, feature_tensor, cv_tag_)

            if self.gpu_mode:
                G_f = G_f.cpu()
                G_g = G_g.cpu()

            # Convert both outputs back to RGB before saving.
            G_f = self.color_revert(G_f)
            G_g = self.color_revert(G_g)

        utils.save_images(
            G_f[:image_frame_dim * image_frame_dim, :, :, :],
            [image_frame_dim, image_frame_dim],
            self.result_path / 'tag2pix_epoch{:03d}_G_f.png'.format(epoch))
        utils.save_images(
            G_g[:image_frame_dim * image_frame_dim, :, :, :],
            [image_frame_dim, image_frame_dim],
            self.result_path / 'tag2pix_epoch{:03d}_G_g.png'.format(epoch))

    def save(self, save_epoch):
        """Checkpoint the full training state plus the loss history.

        Writes three artifacts into ``self.result_path``:
          * ``arguments.txt`` — pretty-printed CLI arguments,
          * ``tag2pix_{epoch}_epoch.pkl`` — G/D and optimizer state dicts,
            the finished epoch, and the result path (consumed by ``load``),
          * ``tag2pix_{epoch}_history.pkl`` — pickled ``self.train_hist``.

        Args:
            save_epoch: epoch number embedded in the checkpoint filenames
                and stored as ``finish_epoch`` for later resumption.
        """
        # parents/exist_ok make this robust: the original bare mkdir()
        # raised if the parent directory was missing or if another process
        # created the directory between exists() and mkdir().
        self.result_path.mkdir(parents=True, exist_ok=True)

        with (self.result_path / 'arguments.txt').open('w') as f:
            f.write(pprint.pformat(self.args.__dict__))

        save_dir = self.result_path

        torch.save(
            {
                'G': self.G.state_dict(),
                'D': self.D.state_dict(),
                'G_optimizer': self.G_optimizer.state_dict(),
                'D_optimizer': self.D_optimizer.state_dict(),
                'finish_epoch': save_epoch,
                'result_path': str(save_dir)
            }, str(save_dir / 'tag2pix_{}_epoch.pkl'.format(save_epoch)))

        with (save_dir /
              'tag2pix_{}_history.pkl'.format(save_epoch)).open('wb') as f:
            pickle.dump(self.train_hist, f)

        print("============= save success =============")
        print("epoch from {} to {}".format(self.start_epoch, save_epoch))
        print("save result path is {}".format(str(self.result_path)))

    def load_test(self, checkpoint_path):
        """Restore only the generator weights (enough for inference/eval)."""
        state = torch.load(str(checkpoint_path))
        self.G.load_state_dict(state['G'])

    def load(self, checkpoint_path):
        """Restore networks, optimizers and the epoch range saved by ``save``."""
        state = torch.load(str(checkpoint_path))

        # Restore every stateful component from its matching checkpoint key.
        for key, component in (('G', self.G),
                               ('D', self.D),
                               ('G_optimizer', self.G_optimizer),
                               ('D_optimizer', self.D_optimizer)):
            component.load_state_dict(state[key])

        # Resume right after the last finished epoch.
        self.start_epoch = state['finish_epoch'] + 1
        self.finish_epoch = self.args.epoch + self.start_epoch - 1

        print("============= load success =============")
        print("epoch start from {} to {}".format(self.start_epoch,
                                                 self.finish_epoch))
        print("previous result path is {}".format(state['result_path']))

    def get_test_data(self, test_data_loader, count):
        """Gather at least ``count`` test samples and dump reference images.

        Concatenates (original, sketch, iv_tag, cv_tag) batches from the
        loader, writes the tag names and original/sketch image grids to
        ``self.result_path``, and returns the four batched tensors.
        """
        accumulators = [[], [], [], []]
        gathered = 0
        for sample in test_data_loader:
            for acc, part in zip(accumulators, sample):
                acc.append(part)
            gathered += len(sample[0])
            if gathered >= count:
                break

        original_, sketch_, iv_tag_, cv_tag_ = (
            torch.cat(parts, 0) for parts in accumulators)

        self.save_tag_tensor_name(iv_tag_, cv_tag_,
                                  self.result_path / "test_image_tags.txt")

        # Side length of the square image grid used for visualization.
        image_frame_dim = int(np.ceil(np.sqrt(len(original_))))

        if self.gpu_mode:
            original_ = original_.cpu()
        # NCHW -> NHWC for image saving.
        sketch_np = sketch_.data.numpy().transpose(0, 2, 3, 1)
        original_np = self.color_revert(original_)

        grid = [image_frame_dim, image_frame_dim]
        n_shown = image_frame_dim * image_frame_dim
        utils.save_images(original_np[:n_shown, :, :, :], grid,
                          self.result_path / 'tag2pix_original.png')
        utils.save_images(sketch_np[:n_shown, :, :, :], grid,
                          self.result_path / 'tag2pix_sketch.png')

        return original_, sketch_, iv_tag_, cv_tag_

    def save_tag_tensor_name(self, iv_tensor, cv_tensor, save_file_path):
        '''iv_tensor, cv_tensor: batched one-hot tag tensors.

        Writes a human-readable listing: one "CIT tags" section for
        ``iv_tensor`` and one "CVT tags" section for ``cv_tensor``, each
        with a line ``<batch index> : name, name, `` per batch element.
        '''
        # Invert tag_id -> one-hot index so a hot column can be mapped back
        # to a tag id (and then to a display name via id_to_name).
        iv_dict_inverse = {
            tag_index: tag_id
            for (tag_id, tag_index) in self.iv_dict.items()
        }
        cv_dict_inverse = {
            tag_index: tag_id
            for (tag_id, tag_index) in self.cv_dict.items()
        }

        def write_section(f, tensor, dict_inverse):
            # One output line per batch element, listing each active tag.
            # (The original duplicated this loop for CIT and CVT and built
            # an unused tag_list; both removed here.)
            for tensor_i, batch_unit in enumerate(tensor):
                f.write(f'{tensor_i} : ')
                for i, is_tag in enumerate(batch_unit):
                    if is_tag:
                        tag_name = self.id_to_name[dict_inverse[i]]
                        f.write(f"{tag_name}, ")
                f.write("\n")

        with open(save_file_path, 'w') as f:
            f.write("CIT tags\n")
            write_section(f, iv_tensor, iv_dict_inverse)
            f.write("\nCVT tags\n")
            write_section(f, cv_tensor, cv_dict_inverse)

    def print_params(self):
        """Print the trainable-parameter counts of G, D and the CIT network."""
        def _total(net):
            # Sum of element counts over every parameter tensor.
            return sum(p.numel() for p in net.parameters())

        g_total = _total(self.G)
        d_total = _total(self.D)
        pretrain_total = _total(self.Pretrain_ResNeXT)
        print(
            f'Parameter #: G - {g_total} / D - {d_total} / Pretrain - {pretrain_total}'
        )
Ejemplo n.º 8
0
# One shuffled sample per step from each (unpaired) image domain.
loader_A = torch.utils.data.DataLoader(dataset=train_A, batch_size=1, shuffle=True)
loader_B = torch.utils.data.DataLoader(dataset=train_B, batch_size=1, shuffle=True)

# Two mapping directions (A->B and B->A), each with 9 residual blocks,
# plus one discriminator per domain.
G_A2B = Generator(9).to(device)
G_B2A = Generator(9).to(device)
D_A = Discriminator().to(device)
D_B = Discriminator().to(device)

def weights_init(m):
    """Initialize convolution weights from N(0.0, 0.02) (DCGAN convention)."""
    if isinstance(m, nn.Conv2d):
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)

# Either resume all four networks from saved weights, or start from a
# fresh Gaussian initialization.
if opt.load_pretrained:
    G_A2B.load_state_dict(torch.load('pretrained/G_A2B.pth'))
    G_B2A.load_state_dict(torch.load('pretrained/G_B2A.pth'))
    D_A.load_state_dict(torch.load('pretrained/D_A.pth'))
    D_B.load_state_dict(torch.load('pretrained/D_B.pth'))
else:
    for net in (G_A2B, G_B2A, D_A, D_B):
        net.apply(weights_init)

# Objectives: MSE adversarial loss, L1 cycle-consistency and identity losses.
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# Both generators share one optimizer; the discriminators' parameters are
# collected here for their own optimizer.
G_params = list(G_A2B.parameters()) + list(G_B2A.parameters())
D_params = list(D_A.parameters()) + list(D_B.parameters())
optimizer_G = optim.Adam(G_params, lr=0.0002, betas=(0.5, 0.999))
Ejemplo n.º 9
0
# Build a descriptive run-name suffix encoding the active options.
if opt.WGAN:
    desc += '_WGAN'
if opt.LS:
    desc += '_LS'
if bMirror:
    desc += '_mirror'
if opt.textureScale != 1:
    desc += "_scale" + str(opt.textureScale)

# initialise generator and discriminator and load checkpoints if option added

netD = Discriminator(ndf, opt.nDepD, bSigm=False, ncIn=4)
if opt.loadModelD != "":
    # BUG FIX: the original read opt.loadModelDs (a nonexistent attribute,
    # trailing 's' typo), so supplying a discriminator checkpoint always
    # raised AttributeError instead of loading it.
    netD.load_state_dict(torch.load(opt.loadModelD))

netG = NetG(ngf, nDep, nz, 4)
if opt.loadModelG != "":
    netG.load_state_dict(torch.load(opt.loadModelG))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device", device)

Gnets = [netG]

# Best-effort init: tolerate modules whose apply() rejects weights_init.
for net in [netD] + Gnets:
    try:
        net.apply(weights_init)
    except Exception as e:
        print(e, "weightinit")
# Gaussian weight initialization for both networks.
# (The original used bare triple-quoted strings as section markers; those
# are no-op expression statements, replaced here with real comments.)
G.weight_init(mean=0.0, std=0.02)
D.weight_init(mean=0.0, std=0.02)

# Move models to CUDA (this script assumes a GPU is available).
G.cuda()
D.cuda()

# One Adam optimizer per network.
G_optimizer = optim.Adam(G.parameters(), lr=cur_lrG, betas=(0.9, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr=cur_lrD, betas=(0.9, 0.999))

# Optionally restore a previous run: weights, optimizer state, epoch
# counters and learning rates all come from the checkpoint.
if opt.restore:
    print('==> Restoring from checkpoint..', opt.restore)
    state = torch.load(opt.restore)

    G.load_state_dict(state['G'])
    D.load_state_dict(state['D'])
    G_optimizer.load_state_dict(state["G_optimizer"])
    D_optimizer.load_state_dict(state["D_optimizer"])
    epoch = state["epoch"]
    global_iter += state["iter"]
    cur_lrG = state["lrG"]
    cur_lrD = state["lrD"]
    # Drop the (potentially large) checkpoint dict so it can be freed.
    state = None

# Wrap for multi-GPU data-parallel training across all visible devices.
G = torch.nn.DataParallel(G, device_ids=range(torch.cuda.device_count()))
D = torch.nn.DataParallel(D, device_ids=range(torch.cuda.device_count()))

# Put both networks in training mode.
G.train()
D.train()

# Adversarial loss.
BCE_loss = nn.BCELoss().cuda()
class AtariGan:
    """GAN trainer pairing a Generator and a Discriminator.

    Builds both networks, a BCE loss, one Adam optimizer per network, and
    fixed real/fake label tensors from ``config``; exposes one-step
    training plus checkpoint save/load helpers.
    """

    def __init__(self, config):
        self.config = config

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.generator = Generator(config).to(self.device)
        self.discriminator = Discriminator(config).to(self.device)

        self.loss_func = nn.BCELoss()
        self.generator_optimizer = optim.Adam(
            params=self.generator.parameters(),
            lr=config["generator_learning_rate"],
            betas=config["generator_betas"])
        self.discriminator_optimizer = optim.Adam(
            params=self.discriminator.parameters(),
            lr=config["discriminator_learning_rate"],
            betas=config["discriminator_betas"])

        # Constant label tensors reused for every batch; their length must
        # match the discriminator's per-batch output.
        self.true_labels = torch.ones(config["batch_size"],
                                      dtype=torch.float32,
                                      device=self.device)
        self.fake_labels = torch.zeros(config["batch_size"],
                                       dtype=torch.float32,
                                       device=self.device)

    def generate(self, noise_input):
        """Run the generator on a batch of latent vectors."""
        return self.generator(noise_input)

    def discriminate(self, _input):
        """Run the discriminator on a batch of (real or generated) samples."""
        return self.discriminator(_input)

    def train(self, gen_output_v, batch_v):
        """One optimization step for D then G.

        Returns:
            (generator_loss, discriminator_loss) as Python floats.
        """
        discriminator_loss = self._train_discriminator(batch_v, gen_output_v)
        generator_loss = self._train_generator(gen_output_v)
        return generator_loss.item(), discriminator_loss.item()

    def _train_discriminator(self, batch, generator_output):
        # One gradient step on the discriminator; returns the loss tensor.
        loss = self._calc_discriminator_loss(batch, generator_output)
        self._optimize(loss, self.discriminator_optimizer)
        return loss

    def _calc_discriminator_loss(self, batch, generator_output):
        # detach() stops discriminator gradients from flowing back into
        # the generator during the D update.
        output_true = self.discriminate(batch.to(self.device))
        output_fake = self.discriminate(generator_output.detach())
        return self.loss_func(output_true, self.true_labels) + self.loss_func(
            output_fake, self.fake_labels)

    def _train_generator(self, generator_output):
        # One gradient step on the generator; returns the loss tensor.
        loss = self._calc_generator_loss(generator_output)
        self._optimize(loss, self.generator_optimizer)
        return loss

    def _calc_generator_loss(self, generator_output):
        # The generator is rewarded when D classifies its output as real.
        dis_output_v = self.discriminate(generator_output)
        return self.loss_func(dis_output_v, self.true_labels)

    def _optimize(self, loss, optimizer):
        # Standard zero-grad / backward / step cycle.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    def save(self):
        """Save the network weights"""
        save_dir = os.path.join(".", *self.config["checkpoint_dir"],
                                self.config["project_name"])
        helper.mkdir(save_dir)
        # Timestamp made filename-safe (no spaces, slashes or colons).
        current_date_time = helper.get_current_date_time()
        current_date_time = current_date_time.replace(" ", "__").replace(
            "/", "_").replace(":", "_")

        torch.save(
            self.generator.state_dict(),
            os.path.join(save_dir, "generator_ckpt_" + current_date_time))
        torch.save(
            self.discriminator.state_dict(),
            os.path.join(save_dir, "discriminator_ckpt_" + current_date_time))

    def load(self):
        """Load latest available network weights"""
        load_path = os.path.join(".", *self.config["checkpoint_dir"],
                                 self.config["project_name"], "*")
        list_of_files = glob.glob(load_path)
        list_of_generator_weights = [
            w for w in list_of_files if "generator" in w
        ]
        list_of_discriminator_weights = [
            w for w in list_of_files if "discriminator" in w
        ]
        latest_generator_weights = max(list_of_generator_weights,
                                       key=os.path.getctime)
        self.generator.load_state_dict(torch.load(latest_generator_weights))
        # BUG FIX: the original reused `latest_generator_weights` to hold
        # the discriminator checkpoint path — behavior was correct but the
        # name lied about the value; renamed for clarity.
        latest_discriminator_weights = max(list_of_discriminator_weights,
                                           key=os.path.getctime)
        self.discriminator.load_state_dict(
            torch.load(latest_discriminator_weights))