Exemplo n.º 1
0
    def initialize_networks(self, opt):
        """Build the G/D/E networks and restore their weights when required.

        Args:
            opt: Option namespace. This method reads ``isTrain``, ``use_vae``,
                ``continue_train`` and ``which_epoch`` from it.

        Returns:
            Tuple ``(netG, netD, netE)``. ``netD`` is None unless
            ``opt.isTrain`` is true; ``netE`` is None unless ``opt.use_vae``
            is true.
        """
        netG = networks.define_G(opt)

        netD = None
        if opt.isTrain:
            netD = networks.define_D(opt)

        netE = None
        if opt.use_vae:
            netE = networks.define_E(opt)

        # Weights are loaded from the `which_epoch` checkpoint when not
        # training, or when training resumes (`continue_train`).
        must_restore = not opt.isTrain or opt.continue_train
        if must_restore:
            epoch_tag = opt.which_epoch
            netG = util.load_network(netG, 'G', epoch_tag, opt)
            if opt.isTrain:
                netD = util.load_network(netD, 'D', epoch_tag, opt)
            if opt.use_vae:
                netE = util.load_network(netE, 'E', epoch_tag, opt)

        return netG, netD, netE
    def initialize_networks(self):
        """Create the secondary discriminator D2, optionally restoring weights.

        Returns:
            None when ``self.opt.isTrain`` is false. Otherwise the network
            built by ``networks.define_D2``; if ``self.opt.which_iter_D2`` is
            positive, ``print_network`` is called on it and its weights are
            loaded from that iteration via ``util.load_network``.
        """
        opt = self.opt
        if not opt.isTrain:
            return None

        discriminator = networks.define_D2(opt)
        if opt.which_iter_D2 > 0:
            print_network(discriminator)
            discriminator = util.load_network(discriminator, 'D2', opt.which_iter_D2, opt)
        return discriminator
Exemplo n.º 3
0
    def run_pretrain(self, load=False):
        """Pretrain the D2 discriminator, or restore it from a saved checkpoint.

        Args:
            load: If true, no training happens; the 'D2' weights saved at
                iteration ``self.opt.which_iter_D2`` are loaded into
                ``self.end2end_model_on_one_gpu.netD2`` instead.

        Returns:
            The global iteration counter: the number of discriminator steps
            run during training, or ``self.opt.which_iter_D2`` when loading.
        """
        if load:
            # NOTE(review): the return value of util.load_network is discarded,
            # which assumes it updates netD2 in place — confirm.
            util.load_network(self.end2end_model_on_one_gpu.netD2, 'D2',
                              self.opt.which_iter_D2, self.opt)
            return self.opt.which_iter_D2

        steps_per_epoch = int(self.n_samples / self.opt.batchSize)
        self.num_semantics = self.progressive_model.num_semantics
        self.set_data_resolution(int(self.opt.crop_size / self.opt.aspect_ratio))
        print(f"Training at resolution {self.progressive_model.generator.res}")

        # Fixed settings for this pretraining pass: a single epoch, dimension
        # index 0, the "stabilize" phase, and an integer scale factor equal to
        # crop_size / (aspect_ratio * current generator resolution).
        stage = "stabilize"
        dim_index = 0
        scale_factor = int(self.opt.crop_size / (self.opt.aspect_ratio * self.progressive_model.generator.res))
        total_epochs = 1

        step_count = 0
        for epoch in range(total_epochs):
            for it in range(steps_per_epoch):
                raw_seg, _, image, _ = self.next_batch()
                seg, seg_mc, image = self.call_next_batch(raw_seg, image)
                # The value returned by step_discriminator is not used here.
                self.step_discriminator(it, step_count, dim_index, seg_mc, seg, image, scale_factor, stage)
                step_count += 1

                done = it + 1
                if done % 10 == 0:
                    progress = f"Res {self.progressive_model.generator.res:03d}, {stage.rjust(9)}: Iteration {done:05d}/{steps_per_epoch:05d}, epoch:{epoch + 1:05d}/{total_epochs:05d}"
                    print(progress)

            # Checkpoint every `save_epoch_freq` epochs and after the last one.
            if epoch % self.opt.save_epoch_freq == 0 or epoch + 1 == total_epochs:
                util.save_network(self.end2end_model_on_one_gpu.netD2, 'D2', step_count, self.opt)

        return step_count