Example #1
 def resume(self, checkpoint_dir, param):
     # Load generators
     last_model_name = get_model_list(checkpoint_dir, "gen")
     state_dict = torch.load(last_model_name)
     self.gen_i.load_state_dict(state_dict['i'])
     self.gen_r.load_state_dict(state_dict['r'])
     self.gen_s.load_state_dict(state_dict['s'])
     if self.with_mapping:
         self.fea_m.load_state_dict(state_dict['fm'])
         self.fea_s.load_state_dict(state_dict['fs'])
     self.best_result = state_dict['best_result']
     iterations = int(last_model_name[-11:-3])
     # Load discriminators
     last_model_name = get_model_list(checkpoint_dir, "dis")
     state_dict = torch.load(last_model_name)
     self.dis_r.load_state_dict(state_dict['r'])
     self.dis_s.load_state_dict(state_dict['s'])
     # Load optimizers
     state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
     self.dis_opt.load_state_dict(state_dict['dis'])
     self.gen_opt.load_state_dict(state_dict['gen'])
     # Reinitialize schedulers
     self.dis_scheduler = get_scheduler(self.dis_opt, param, iterations)
     self.gen_scheduler = get_scheduler(self.gen_opt, param, iterations)
     print('Resume from iteration %d' % iterations)
     return iterations
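Note: every example on this page relies on a get_model_list helper that scans a checkpoint directory and returns the path of the newest file whose name contains a given key. The helper itself is not shown here; below is a minimal sketch, modeled on the MUNIT-style utils.py these trainers appear to derive from (assuming import os; the exact implementation in each repository may differ):

    def get_model_list(dirname, key):
        # Return the newest checkpoint in dirname whose filename contains `key`.
        if not os.path.exists(dirname):
            return None
        models = [os.path.join(dirname, f) for f in os.listdir(dirname)
                  if os.path.isfile(os.path.join(dirname, f))
                  and key in f and f.endswith('.pt')]
        if not models:
            return None
        models.sort()
        return models[-1]  # e.g. 'checkpoints/gen_00100000.pt'

Because the filenames end in an 8-digit zero-padded iteration count followed by '.pt', the slice last_model_name[-11:-3] used above recovers the iteration number.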
Example #2
 def resume(self, checkpoint_dir, configs):
     # Load generators
     last_model_name = get_model_list(checkpoint_dir, "gen")
     state_dict = torch.load(last_model_name,
                             map_location=lambda storage, loc: storage)
     self.gen.load_state_dict(state_dict['a'])
     iterations = int(
         last_model_name[-15:-7]) if 'avg' in last_model_name else int(
             last_model_name[-11:-3])
     # Load discriminators
     last_model_name = get_model_list(checkpoint_dir, "dis")
     state_dict = torch.load(last_model_name,
                             map_location=lambda storage, loc: storage)
     self.dis.load_state_dict(state_dict['b'])
     # Load optimizers
     #state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'), map_location=lambda storage, loc: storage)
     #self.dis_opt.load_state_dict(state_dict['dis'])
     #self.gen_opt.load_state_dict(state_dict['gen'])
     # Reinitialize schedulers
     self.dis_scheduler = get_scheduler(self.dis_opt, configs, iterations)
     self.gen_scheduler = get_scheduler(self.gen_opt, configs, iterations)
     if torch.__version__ != '0.4.1':
         for _ in range(iterations):
             self.gen_scheduler.step()
             self.dis_scheduler.step()
     print('Resume from iteration %d' % iterations)
     return iterations
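Two details in Example #2 are worth noting. map_location=lambda storage, loc: storage deserializes every tensor onto the CPU, so a checkpoint saved on a GPU machine can be restored anywhere; in current PyTorch the string form is an equivalent, clearer spelling:

    # Equivalent CPU-mapped load in modern PyTorch
    state_dict = torch.load(last_model_name, map_location='cpu')

The version check then advances the freshly constructed schedulers to the resumed iteration by stepping them manually, presumably working around changes in how lr_scheduler handles the last_epoch argument on releases after 0.4.1.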
Example #3
 def resume(self, checkpoint_dir, hyperparameters):
     # Load generators
     last_model_name = get_model_list(checkpoint_dir, "gen")
     state_dict = torch.load(last_model_name)
     self.gen_a.load_state_dict(state_dict['a'])
     self.gen_b.load_state_dict(state_dict['b'])
     iterations = int(last_model_name[-11:-3])
     # Load discriminators
     last_model_name = get_model_list(checkpoint_dir, "dis")
     state_dict = torch.load(last_model_name)
     self.dis_a.load_state_dict(state_dict['a'])
     self.dis_b.load_state_dict(state_dict['b'])
     # Load content classifier
     last_model_name = get_model_list(checkpoint_dir, "con_cla")
     state_dict = torch.load(last_model_name)
     self.content_classifier.load_state_dict(state_dict['con_cla'])
     # Load optimizers
     state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
     self.dis_opt.load_state_dict(state_dict['dis'])
     self.gen_opt.load_state_dict(state_dict['gen'])
     self.cla_opt.load_state_dict(state_dict['con_cla'])
     # Reinitialize schedulers
     self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                        iterations)
     self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                        iterations)
     self.cla_scheduler = get_scheduler(self.cla_opt, hyperparameters,
                                        iterations)
     print('Resume from iteration %d' % iterations)
     return iterations
Example #4
 def resume_prev(self, checkpoint_dir, hyperparameters):
     # Load generators
     last_model_name = get_model_list(checkpoint_dir, "gen")
     print('resume from generator checkpoint named:', last_model_name)
     state_dict = torch.load(last_model_name)
     self.gen_a.load_state_dict(state_dict['a'])
     self.gen_b.load_state_dict(state_dict['b'])
     try:
         iterations = int(float(last_model_name[4:11]))
     except ValueError:
         iterations = int(float(last_model_name.split('/')[-1][4:11]))
     # Load discriminators
     last_model_name = get_model_list(checkpoint_dir, "dis")
     print('resume from discriminator checkpoint named:', last_model_name)
     state_dict = torch.load(last_model_name)
     self.dis_a.load_state_dict(state_dict['a'])
     self.dis_b.load_state_dict(state_dict['b'])
     # Load optimizers
     state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
     self.dis_opt.load_state_dict(state_dict['dis'])
     self.gen_opt.load_state_dict(state_dict['gen'])
     # Reinitialize schedulers
     self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                        iterations)
     self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                        iterations)
     print('Resume from iteration %d' % iterations)
     return iterations
Example #5
    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.models.gen.load_state_dict(state_dict['gen'])
        self.models.gen_test.load_state_dict(state_dict['gen_test'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.models.dis.load_state_dict(state_dict['dis'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])

        for state in self.dis_opt.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.cuda()
        
        for state in self.gen_opt.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.cuda()

        print('Resume from iteration %d' % iterations)
        return iterations
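The two loops at the end of Example #5 (repeated in Example #9) exist because Optimizer.load_state_dict restores state tensors such as Adam's moment estimates on whatever device they were saved from; if the model now lives on the GPU, the next step() would mix CPU and GPU tensors. The pattern generalizes into a small helper; a sketch, with the device argument as an assumption rather than part of the original code:

    def optimizer_to(optimizer, device):
        # Move all optimizer state tensors (e.g. Adam moments) to `device`.
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)

Calling optimizer_to(self.gen_opt, torch.device('cuda')) would then replace each duplicated loop.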
Example #6
 def resume(self, checkpoint_dir, hyperparameters):
     # Load generators
     last_model_name = get_model_list(checkpoint_dir, "gen")
     state_dict = torch.load(last_model_name)
     self.gen_a.load_state_dict(state_dict['a'])
     self.gen_b = self.gen_a
     iterations = int(last_model_name[-11:-3])
     # Load discriminators
     last_model_name = get_model_list(checkpoint_dir, "dis")
     state_dict = torch.load(last_model_name)
     self.dis_a.load_state_dict(state_dict['a'])
     self.dis_b = self.dis_a
     # Load ID dis
     last_model_name = get_model_list(checkpoint_dir, "id")
     state_dict = torch.load(last_model_name)
     self.id_a.load_state_dict(state_dict['a'])
     self.id_b = self.id_a
     # Load optimizers
     try:
         state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
         self.dis_opt.load_state_dict(state_dict['dis'])
         self.gen_opt.load_state_dict(state_dict['gen'])
         self.id_opt.load_state_dict(state_dict['id'])
     except Exception:
         pass
     # Reinitialize schedulers
     self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
     self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
     print('Resume from iteration %d' % iterations)
     return iterations
Example #7
    def resume(self, checkpoint_dir, hyperparameters):
        """
        Resume training by loading the network parameters
        
        Arguments:
            checkpoint_dir {string} -- path to the directory where the checkpoints are saved
            hyperparameters {dictionary} -- dictionary with all hyperparameters
        
        Returns:
            int -- number of iterations (used by the optimizer)
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        if self.gen_state == 0:
            self.gen_a.load_state_dict(state_dict["a"])
            self.gen_b.load_state_dict(state_dict["b"])
        elif self.gen_state == 1:
            self.gen.load_state_dict(state_dict["2"])
        else:
            print("self.gen_state unknown value:", self.gen_state)

        # Load domain classifier
        if self.domain_classif == 1:
            last_model_name = get_model_list(checkpoint_dir, "domain_classif")
            state_dict = torch.load(last_model_name)
            self.domain_classifier.load_state_dict(state_dict["d"])

        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict["a"])
        self.dis_b.load_state_dict(state_dict["b"])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, "optimizer.pt"))
        self.dis_opt.load_state_dict(state_dict["dis"])
        self.gen_opt.load_state_dict(state_dict["gen"])

        if self.domain_classif == 1:
            self.dann_opt.load_state_dict(state_dict["dann"])
            self.dann_scheduler = get_scheduler(self.dann_opt, hyperparameters,
                                                iterations)
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print("Resume from iteration %d" % iterations)
        return iterations
Example #8
    def resume(self, checkpoint_dir, hyperparameters, resume_epoch):

        print("--> " + checkpoint_dir)

        # Load generator.
        last_model_name = get_model_list(checkpoint_dir, "gen", resume_epoch)
        # print('\n',last_model_name)
        state_dict = torch.load(last_model_name)
        self.gen.load_state_dict(state_dict)
        epochs = int(last_model_name[-11:-3])

        # Load supervised model.
        last_model_name = get_model_list(checkpoint_dir, "sup", resume_epoch)
        state_dict = torch.load(last_model_name)
        self.sup.load_state_dict(state_dict)

        return epochs
Example #9
    def resume(self, checkpoint_dir, hyperparameters):

        print("--> " + checkpoint_dir)

        # Load generator.
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen.load_state_dict(state_dict)
        epochs = int(last_model_name[-11:-3])

        # Load supervised model.
        last_model_name = get_model_list(checkpoint_dir, "sup")
        state_dict = torch.load(last_model_name)
        self.sup.load_state_dict(state_dict)

        # Load discriminator.
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis.load_state_dict(state_dict)

        # Load optimizers.
        last_model_name = get_model_list(checkpoint_dir, "opt")
        state_dict = torch.load(last_model_name)
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])

        for state in self.dis_opt.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

        for state in self.gen_opt.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

        # Reinitialize schedulers.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           epochs)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           epochs)

        print('Resume from epoch %d' % epochs)
        return epochs
Example #10
 def load(self, checkpoint_dir):
     last_model_name = get_model_list(checkpoint_dir, "enc")
     last_epoch = int(last_model_name.split('_')[1].split('.')[0])
     encoder = torch.load(last_model_name)
     self.model.load_state_dict(encoder["model"])
     self.opt.load_state_dict(encoder["opt"])
     self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
         self.opt, T_max=self.T_max, eta_min=0, last_epoch=last_epoch)
     print(f"Resume encoder from epoch {last_epoch + 1}")
     return last_epoch + 1
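A caveat on Example #10: constructing CosineAnnealingLR with last_epoch other than -1 requires every optimizer param group to carry an 'initial_lr' entry, and PyTorch raises a KeyError when it is missing. The restored optimizer state may already include it if a scheduler had been attached before saving; a defensive backfill, offered as an assumption about this setup rather than part of the original code:

    # Ensure 'initial_lr' exists before resuming a scheduler mid-run
    for group in self.opt.param_groups:
        group.setdefault('initial_lr', group['lr'])
    self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        self.opt, T_max=self.T_max, eta_min=0, last_epoch=last_epoch)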
Example #11
    def load_checkpoint(self, model_path=None):
        if not model_path:
            model_dir = self.model_dir
            model_path = get_model_list(model_dir, "gen")  # last model

        state_dict = torch.load(model_path, map_location=self.device)
        self.gen.load_state_dict(state_dict['gen'])

        epochs = int(model_path[-7:-3])
        print('Load from epoch %d' % epochs)

        return epochs
Example #12
    def resume(self, checkpoint_dir, hyperparameters, particular=False, checkpoint=''):
        # Load generators
        # if particular is False, load the most recent "vae" checkpoint
        if not particular:
            last_model_name = get_model_list(checkpoint_dir, "vae")
        else:
            if checkpoint == '':
                sys.exit('Specified checkpoint path is empty')
            last_model_name = os.path.join(checkpoint_dir, checkpoint)

        state_dict = torch.load(last_model_name)
        self.vae.load_state_dict(state_dict)
Example #13
 def resume(self, checkpoint_dir, hps):
     # Load generators
     last_model_name = get_model_list(checkpoint_dir, "gen")
     state_dict = torch.load(last_model_name, map_location=lambda storage, loc: storage)
     self.gen_a.load_state_dict(state_dict['a'])
     self.gen_b.load_state_dict(state_dict['b'])
     iterations = int(last_model_name[-11:-3])
     # Load discriminators
     last_model_name = get_model_list(checkpoint_dir, "dis")
     state_dict = torch.load(last_model_name, map_location=lambda storage, loc: storage)
     self.dis_a.load_state_dict(state_dict['a'])
     self.dis_b.load_state_dict(state_dict['b'])
     # Load optimizers
     state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'), map_location=lambda storage, loc: storage)
     self.dis_opt.load_state_dict(state_dict['dis'])
     self.gen_opt.load_state_dict(state_dict['gen'])
     # Reinitialize schedulers
     self.dis_scheduler = get_scheduler(self.dis_opt, hps, iterations)
     self.gen_scheduler = get_scheduler(self.gen_opt, hps, iterations)
     print('Resume from iteration %d' % iterations)
     return iterations
Example #14
 def resume(self, checkpoint_dir, hyperparameters):
     # Load generators
     last_model_name = get_model_list(checkpoint_dir, "gen")
     state_dict = torch.load(last_model_name)
     self.gen_a.load_state_dict(state_dict['a'])
     self.gen_b.load_state_dict(state_dict['b'])
     iterations = int(last_model_name[-11:-3])
     # Load discriminators
     last_model_name = get_model_list(checkpoint_dir, "dis")
     state_dict = torch.load(last_model_name)
     self.dis_a.load_state_dict(state_dict['a'])
     self.dis_b.load_state_dict(state_dict['b'])
     # Load optimizers
     state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
     self.dis_opt.load_state_dict(state_dict['dis'])
     self.gen_opt.load_state_dict(state_dict['gen'])
     # Reinitialize schedulers
     self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
     self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
     print('Resume from iteration %d' % iterations)
     return iterations
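Example #14 is the canonical MUNIT-style resume. For it to work, the matching save method must write the same dictionary keys and embed the zero-padded iteration in the filenames that get_model_list sorts; a sketch of that counterpart, following the convention these examples assume:

    def save(self, snapshot_dir, iterations):
        # Filenames end in an 8-digit iteration that resume() parses back out.
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'dis': self.dis_opt.state_dict(), 'gen': self.gen_opt.state_dict()}, opt_name)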
Example #15
 def resume(self, checkpoint_dir, param):
     # Load generators
     last_model_name = get_model_list(checkpoint_dir, "student")
     state_dict = torch.load(last_model_name)
     self.generator.load_state_dict(state_dict['student'])
     self.best_result = state_dict['best_result']
     epoch = int(last_model_name[-6:-3])
     # Load discriminators
     last_model_name = get_model_list(checkpoint_dir, "teacher")
     state_dict = torch.load(last_model_name)
     self.discriminator.load_state_dict(state_dict['teacher'])
     # Load optimizers
     state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
     self.dis_opt.load_state_dict(state_dict['teacher'])
     self.stu_opt.load_state_dict(state_dict['student'])
     # Re-initialize schedulers
     try:
         self.stu_scheduler = get_scheduler(self.stu_opt, param, epoch)
     except Exception as e:
         print('Warning: {}'.format(e))
     print('Resume from epoch %d' % epoch)
     return epoch
Example #16
 def resume(self, checkpoint_dir, hyperparameters):
     # Load generators
     last_model_name = get_model_list(checkpoint_dir, "gen")
     state_dict = torch.load(last_model_name)
     self.gen_AB.load_state_dict(state_dict['AB'])
     self.gen_BA.load_state_dict(state_dict['BA'])
     iterations = int(last_model_name[-11:-3])
     # Load discriminators
     last_model_name = get_model_list(checkpoint_dir, "dis")
     state_dict = torch.load(last_model_name)
     self.dis_A.load_state_dict(state_dict['A'])
     self.dis_B.load_state_dict(state_dict['B'])
     self.dis_2.load_state_dict(state_dict['2'])
     # Load optimizers
     state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
     self.dis_opt.load_state_dict(state_dict['dis'])
     self.gen_opt.load_state_dict(state_dict['gen'])
     # Reinitialize schedulers
     self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
     self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
     print('Resume from iteration %d' % iterations)
     return iterations
Example #17
    def resume(self, checkpoint_dir, hp):
        last_model_name = get_model_list(checkpoint_dir, "gen")
        if last_model_name:
            state_dict = torch.load(last_model_name)
            self.model.gen.load_state_dict(state_dict['gen'])
            self.model.gen_test.load_state_dict(state_dict['gen_test'])
            iterations = int(last_model_name[-11:-3])

            last_model_name = get_model_list(checkpoint_dir, "dis")
            state_dict = torch.load(last_model_name)
            self.model.dis.load_state_dict(state_dict['dis'])

            last_opt_name = get_model_list(checkpoint_dir, "optimizer")
            state_dict = torch.load(last_opt_name)
            self.dis_opt.load_state_dict(state_dict['dis'])
            self.gen_opt.load_state_dict(state_dict['gen'])

            self.dis_scheduler = get_scheduler(self.dis_opt, hp, iterations)
            self.gen_scheduler = get_scheduler(self.gen_opt, hp, iterations)
        else:
            iterations = 0
        print(f'Resume GAN from iteration {iterations}')
        return iterations
Example #18
 def resume(self, checkpoint_dir, options):
     # Load generators
     last_model_name = get_model_list(checkpoint_dir, "gen")
     state_dict = torch.load(last_model_name)
     self.encoder.load_state_dict(state_dict['a'])
     self.decoder.load_state_dict(state_dict['b'])
     iterations = int(last_model_name[-11:-3])
     # Load discriminators
     last_model_name = get_model_list(checkpoint_dir, "dis")
     state_dict = torch.load(last_model_name)
     self.discriminator.load_state_dict(state_dict['a'])
     # Load optimizers
     last_model_name = get_model_list(checkpoint_dir, "opt")
     state_dict = torch.load(last_model_name)
     self.gen_opt.load_state_dict(state_dict['a'])
     self.dis_opt.load_state_dict(state_dict['b'])
     # Reinitialize schedulers
     self.dis_scheduler = get_scheduler(self.dis_opt, options, iterations)
     self.gen_scheduler = get_scheduler(self.gen_opt, options, iterations)
     print('Resume from iteration %d' % iterations)
     del state_dict, last_model_name
     torch.cuda.empty_cache()
     return iterations
Example #19
    def resume(self, checkpoint_dir, param):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.model.load_state_dict(state_dict['i'])
        iterations = int(last_model_name[-11:-3])

        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.gen_scheduler = get_scheduler(self.gen_opt, param, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations
Example #20
 def resume(self, checkpoint_dir, hyperparameters):
     # Load generators
     last_model_name = get_model_list(checkpoint_dir, "gen")
     state_dict = torch.load(last_model_name)
     self.gen_a.load_state_dict(state_dict["a"])
     self.gen_b.load_state_dict(state_dict["b"])
     iterations = int(last_model_name[-11:-3])
     # Load discriminators
     last_model_name = get_model_list(checkpoint_dir, "dis")
     state_dict = torch.load(last_model_name)
     self.dis_a.load_state_dict(state_dict["a"])
     self.dis_b.load_state_dict(state_dict["b"])
     # Load optimizers
     state_dict = torch.load(os.path.join(checkpoint_dir, "optimizer.pt"))
     self.dis_opt.load_state_dict(state_dict["dis"])
     self.gen_opt.load_state_dict(state_dict["gen"])
     # Reinitialize schedulers
     self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                        iterations)
     self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                        iterations)
     print("Resume from iteration %d" % iterations)
     return iterations
Example #21
    def resume(self, checkpoint_dir, hyperparameters, gen_model=None, dis_model=None):
        # Load generators
        if gen_model is None:
            gen_model = get_model_list(checkpoint_dir, "gen") # last gen model
        gen_state_dict = torch.load(gen_model)
        self.gen_a.load_state_dict(gen_state_dict['a'])
        self.gen_b.load_state_dict(gen_state_dict['b'])
        iterations = int(gen_model[-11:-3])

        # Load discriminators
        if dis_model is None:
            dis_model = get_model_list(checkpoint_dir, "dis")
        dis_state_dict = torch.load(dis_model)
        self.dis_a.load_state_dict(dis_state_dict['a'])
        self.dis_b.load_state_dict(dis_state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations
Example #22
 def resume(self, checkpoint_dir, param):
     # Load generators
     last_model_name = get_model_list(checkpoint_dir, "gen")
     state_dict = torch.load(last_model_name)
     self.generator.load_state_dict(state_dict['generator'])
     self.best_result = state_dict['best_result']
     epoch = int(last_model_name[-6:-3])
     # Load discriminators
     last_model_name = get_model_list(checkpoint_dir, "dis")
     state_dict = torch.load(last_model_name)
     self.discriminator_bg.load_state_dict(state_dict['bg'])
     self.discriminator_rf.load_state_dict(state_dict['rf'])
     # Load optimizers
     state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
     self.dis_opt.load_state_dict(state_dict['dis'])
     self.gen_opt.load_state_dict(state_dict['gen'])
     # Reinitialize schedulers
     try:
         self.dis_scheduler = get_scheduler(self.dis_opt, param, epoch)
         self.gen_scheduler = get_scheduler(self.gen_opt, param, epoch)
     except Exception as e:
         print('Warning: {}'.format(e))
     print('Resume from epoch %d' % epoch)
     return epoch
Example #23
    def resume(self, checkpoint_dir, params, version=None):

        model_name = get_model_list(checkpoint_dir, 'model', version)

        state_dict = torch.load(model_name)
        self.model.load_state_dict(state_dict)
        iterations = int(model_name[-12:-4])
        if version is not None: assert version == iterations

        try:
            state_dict = torch.load(
                os.path.join(checkpoint_dir, 'optimizer.pth'))
            self.encoder_opt.load_state_dict(state_dict)
        except IOError:
            print('use a new optimizer')

        # reinitialize scheduler
        self.encoder_scheduler = get_scheduler(self.encoder_opt, params,
                                               iterations)
        print('Resume from iteration %d' % iterations)
        return iterations
Example #24
 def resume(self,
            checkpoint_dir,
            hyperparameters,
            need_opt=True,
            path=None):
     # Load generators
     if path is None:
         last_model_name = get_model_list(checkpoint_dir, "gen")
     else:
         last_model_name = path
     state_dict = torch.load(last_model_name)
     self.gen.load_state_dict(state_dict['a'])
     iterations = int(last_model_name[-11:-3])
     if need_opt:
         state_dict = torch.load(
             os.path.join(checkpoint_dir,
                          'optimizer_' + last_model_name[-11:-3] + '.pt'))
         self.gen_opt.load_state_dict(state_dict['gen'])
         self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                            iterations)
     print('Resume from iteration %d' % iterations)
     return iterations
Example #25
def collaborative_private_model_mnist_train(args):

    device = 'cuda' if args.gpu else 'cpu'
    # Set up the initial models
    train_dataset, test_dataset = get_public_dataset(args)
    models = {
        "2_layer_CNN": CNN_2layer_fc_model_removelogsoftmax,  # 字典的函数类型
        "3_layer_CNN": CNN_3layer_fc_model_removelogsoftmax
    }
    modelsindex = ["2_layer_CNN", "3_layer_CNN"]
    if args.new_collaborative_training:
        model_list, model_type_list = get_model_list(args.privateurl,
                                                     modelsindex, models)
    else:
        model_list, model_type_list = get_model_list(args.Collaborativeurl,
                                                     modelsindex, models)
    epoch_groups = MNIST_random(train_dataset, args.collaborative_epoch)

    train_loss = []
    test_accuracy = []
    for i in range(args.user_number):
        train_loss.append([])
    for i in range(args.user_number):
        test_accuracy.append([])

    for epoch in range(args.collaborative_epoch):

        train_batch_loss = []
        for i in range(args.user_number):
            train_batch_loss.append([])

        trainloader = DataLoader(DatasetSplit(train_dataset,
                                              list(epoch_groups[epoch])),
                                 batch_size=256,
                                 shuffle=True)

        for batch_idx, (images, labels) in enumerate(trainloader):
            images, labels = images.to(device), labels.to(device)
            # Initialize per-sample accumulators for the averaged model outputs
            temp_sum_result = [[0] * args.output_classes for _ in range(len(labels))]

            # Combine the outputs of all models
            for n, model in enumerate(model_list):
                with torch.no_grad():
                    model.to(device)
                    model.eval()
                    outputs = model(images)
                    pred_labels = outputs.tolist()  # convert to a Python list
                    # print(pred_labels.shape) # torch.Size([128, 16])
                    # _,pred_labels = torch.max(outputs,1)
                    # pred_labels = pred_labels.view(-1)
                    # print(pred_labels.shape) # torch.Size([2048])
                    # Accumulate each model's outputs across users
                    temp_sum_result = list_add(pred_labels, temp_sum_result)
            #         print(len(temp_sum_result))
            #         print(len(temp_sum_result[0]))
            # print(type(temp_sum_result))
            # print(type(temp_sum_result[0]))
            # temp_sum_result = torch.stack(temp_sum_result) # torch.Size([10, 128, 16])
            # temp_sum_result /=args.user_number
            # print(temp_sum_result.shape)
            # labels = torch.mean(temp_sum_result.float(),dim=0) # get the output
            # print(labels.shape) # torch.Size([128, 16])
            # Divide the summed outputs by the number of participating users
            temp_sum_result = get_avg_result(temp_sum_result, args.user_number)
            labels = torch.tensor(temp_sum_result)
            # print(labels.size())
            # print((labels[0]).size())
            labels = labels.to(device)
            for n, model in enumerate(model_list):
                model.to(device)
                model.train()
                if args.optimizer == 'sgd':
                    optimizer = torch.optim.SGD(model.parameters(),
                                                lr=args.lr,
                                                momentum=0.5)
                elif args.optimizer == 'adam':
                    optimizer = torch.optim.Adam(model.parameters(),
                                                 lr=args.lr,
                                                 weight_decay=1e-4)
                criterion = nn.L1Loss(reduction='mean').to(device)
                optimizer.zero_grad()
                outputs = model(images)  # torch.Size([128, 16])
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                if batch_idx % 10 == 0:
                    print(
                        'Collaborative training: Local Model {} Type {} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'
                        .format(n, model_type_list[n], epoch + 1,
                                batch_idx * len(images),
                                len(trainloader.dataset),
                                100. * batch_idx / len(trainloader),
                                loss.item()))
                train_batch_loss[n].append(loss.item())
                torch.save(
                    model.state_dict(),
                    'Src/CollaborativeModel/LocalModel{}Type{}.pkl'.format(
                        n, model_type_list[n]))
        for index in range(len(train_batch_loss)):
            loss_avg = sum(train_batch_loss[index]) / len(
                train_batch_loss[index])
            train_loss[index].append(loss_avg)

    plt.figure()
    for index in range(len(train_loss)):
        plt.plot(range(len(train_loss[index])), train_loss[index])
    plt.title('collaborative_train_losses')
    plt.xlabel('epochs')
    plt.ylabel('Train loss')
    plt.savefig('Src/Figure/collaborative_train_losses.png')
    plt.show()
    print('End Public Training')
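Note that Examples #25, #26, #29, and #31 call a get_model_list(url, modelsindex, models) whose signature differs from the checkpoint helper used above: it returns a list of instantiated networks together with their architecture labels. (Example #30's zero-argument get_model_list() is a further, unrelated variant.) A plausible reconstruction, purely an assumption inferred from how the return values and the saved 'LocalModel{n}Type{type}.pkl' filenames are used:

    def get_model_list(model_dir, modelsindex, models):
        # Hypothetical sketch: build one model per checkpoint file in model_dir.
        model_list, model_type_list = [], []
        for filename in sorted(os.listdir(model_dir)):
            model_type = next(t for t in modelsindex if t in filename)
            model = models[model_type]()  # constructor taken from the dict
            model.load_state_dict(
                torch.load(os.path.join(model_dir, filename)))
            model_list.append(model)
            model_type_list.append(model_type)
        return model_list, model_type_list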
Example #26
def private_dataset_train(args):
    device = 'cuda' if args.gpu else 'cpu'
    # Set up the initial models
    # Get the FEMNIST dataset!
    train_dataset, test_dataset = get_private_dataset_balanced(args)
    user_groups = FEMNIST_iid(train_dataset, args.user_number)

    models = {
        "2_layer_CNN": CNN_2layer_fc_model,  # 字典的函数类型
        "3_layer_CNN": CNN_3layer_fc_model
    }
    modelsindex = ["2_layer_CNN", "3_layer_CNN"]

    if args.new_private_training:
        model_list, model_type_list = get_model_list(args.initialurl,
                                                     modelsindex, models)
        #model_list,model_type_list = get_model_list('Src/EmptyModel',modelsindex,models)
    else:
        #model_list,model_type_list = get_model_list(args.privateurl,modelsindex,models)
        model_list, model_type_list = get_model_list('Src/ModelNonIdFemnist',
                                                     modelsindex, models)

    private_model_private_dataset_train_losses = []
    private_model_private_dataset_validation_losses = []
    for n, model in enumerate(model_list):
        print('train Local Model {} on Private Dataset'.format(n))
        model.to(device)
        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=args.lr,
                                        momentum=0.5)
        elif args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=args.lr,
                                         weight_decay=1e-4)
        trainloader = DataLoader(DatasetSplit(train_dataset,
                                              list(user_groups[n])),
                                 batch_size=32,
                                 shuffle=True)
        testloader = DataLoader(test_dataset, batch_size=128, shuffle=True)
        criterion = nn.NLLLoss().to(device)
        train_epoch_losses = []
        validation_epoch_losses = []
        print('Begin Private Training')
        earlyStopping = EarlyStopping(
            patience=5,
            verbose=True,
            path='Src/ModelNonIdFemnist/LocalModel{}Type{}.pkl'.format(
                n, model_type_list[n]))
        for epoch in range(args.privateepoch):
            model.train()
            train_batch_losses = []
            for batch_idx, (images, labels) in enumerate(trainloader):
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                if batch_idx % 5 == 0:
                    print(
                        'Local Model {} Type {} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'
                        .format(n, model_type_list[n], epoch + 1,
                                batch_idx * len(images),
                                len(trainloader.dataset),
                                100. * batch_idx / len(trainloader),
                                loss.item()))
                train_batch_losses.append(loss.item())
            loss_avg = sum(train_batch_losses) / len(train_batch_losses)
            train_epoch_losses.append(loss_avg)

            model.eval()
            val_batch_losses = []
            for batch_idx, (images, labels) in enumerate(testloader):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                if batch_idx % 5 == 0:
                    print(
                        'Local Model {} Type {} Val Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'
                        .format(n, model_type_list[n], epoch + 1,
                                batch_idx * len(images),
                                len(testloader.dataset),
                                100. * batch_idx / len(testloader),
                                loss.item()))
                val_batch_losses.append(loss.item())
            loss_avg = sum(val_batch_losses) / len(val_batch_losses)
            validation_epoch_losses.append(loss_avg)
            earlyStopping(loss_avg, model)
            if earlyStopping.early_stop:
                print("Early stopping")
                break
        private_model_private_dataset_train_losses.append(train_epoch_losses)
        private_model_private_dataset_validation_losses.append(
            validation_epoch_losses)

    plt.figure()
    for i, val in enumerate(private_model_private_dataset_train_losses):
        print(val)
        plt.plot(range(len(val)), val, label='model :' + str(i))
    plt.legend(loc='best')
    plt.title('private_model_private_non_iid_dataset_train_demo_losses')
    plt.xlabel('epochs')
    plt.ylabel('Train loss')
    x_major_locator = MultipleLocator(1)  # set the x-axis tick interval to 1
    ax = plt.gca()  # handle to the current axes
    ax.xaxis.set_major_locator(x_major_locator)  # place major x ticks at multiples of 1
    plt.xlim(0, args.privateepoch)
    plt.savefig(
        'Src/Figure/private_model_private_non_iid_dataset_train_demo_losses.png'
    )
    plt.show()

    plt.figure()
    for i, val in enumerate(private_model_private_dataset_validation_losses):
        print(val)
        plt.plot(range(len(val)), val, label='model :' + str(i))
    plt.legend(loc='best')
    plt.title('private_model_private_non_iid_dataset_validation_demo_losses')
    plt.xlabel('epochs')
    plt.ylabel('Validation loss')
    x_major_locator = MultipleLocator(1)  # set the x-axis tick interval to 1
    ax = plt.gca()  # handle to the current axes
    ax.xaxis.set_major_locator(x_major_locator)  # place major x ticks at multiples of 1
    plt.xlim(0, args.privateepoch)
    plt.savefig(
        'Src/Figure/private_model_private_non_iid_dataset_validation_demo_losses.png'
    )
    plt.show()

    print('End Private Training')
Example #27
data_loader = get_data_loader_folder(opts.input_folder,
                                     1,
                                     False,
                                     new_size=config['crop_image_height'],
                                     crop=False)

# config['vgg_model_path'] = opts.output_path
if opts.trainer == 'MUNIT':
    style_dim = config['gen']['style_dim']
    trainer = MUNIT_Trainer(config)
elif opts.trainer == 'UNIT':
    trainer = UNIT_Trainer(config)
else:
    sys.exit("Only support MUNIT|UNIT")

last_model_name = get_model_list(opts.checkpoint, "gen")
state_dict = torch.load(last_model_name)
trainer.gen_a.load_state_dict(state_dict['a'])
trainer.gen_b.load_state_dict(state_dict['b'])

last_model_name = get_model_list(opts.checkpoint, "submodel")
state_dict = torch.load(last_model_name)
trainer.a2b.load_state_dict(state_dict['a2b'])
trainer.b2a.load_state_dict(state_dict['b2a'])

trainer.cuda()
trainer.eval()
encode = trainer.gen_a.encode if opts.a2b else trainer.gen_b.encode  # encode function
decode = trainer.gen_b.decode if opts.a2b else trainer.gen_a.decode  # decode function

get_style = trainer.b2a
Example #28
    output_directory = os.path.join(opts.output_path + "/outputs", model_name)
    checkpoint_directory = os.path.join(output_directory, 'checkpoints')
    os.makedirs(checkpoint_directory, exist_ok=True)
    shutil.copy(opts.config, os.path.join(
        output_directory, 'config.yaml'))  # copy config file to output folder

    dataloader = DataLoader(
        WCDataset(config['dataset_path']),
        batch_size=config['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=config['num_workers'],
    )

    if opts.resume:  # opts.resume=False
        last_model_name = get_model_list(checkpoint_directory)
        iteration = int(last_model_name[:-4])
        net = get_net(last_model_name, dataloader.dataset.class_num)
        optimizer = get_optimizer(net, config)
        scheducer = get_scheducer(optimizer, config, iteration)
        print('Resume from iteration %d' % iteration)
    else:
        iteration = 0
        net = get_net(config['weight_path'], dataloader.dataset.class_num)
        optimizer = get_optimizer(net, config)
        scheducer = get_scheducer(optimizer, config)

    criterion = AngleLoss()
    # criterion = nn.CrossEntropyLoss()
    train(net, dataloader, criterion, optimizer, scheducer, config,
          train_writer, checkpoint_directory, iteration)
Example #29
def collaborative_private_model_femnist_train(args):
    device = 'cuda' if args.gpu else 'cpu'
    # Set up the initial models
    # Get the FEMNIST dataset!
    train_dataset, test_dataset = get_private_dataset_balanced(args)
    user_groups = FEMNIST_iid(train_dataset, args.user_number)

    models = {
        "2_layer_CNN": CNN_2layer_fc_model,  # 字典的函数类型
        "3_layer_CNN": CNN_3layer_fc_model
    }
    modelsindex = ["2_layer_CNN", "3_layer_CNN"]
    model_list, model_type_list = get_model_list(args.Collaborativeurl,
                                                 modelsindex, models)

    print('Begin Private Training')

    private_model_private_dataset_train_losses = []
    for n, model in enumerate(model_list):
        print('train Local Model {} on Private Dataset'.format(n))
        model.to(device)
        model.train()
        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=args.lr,
                                        momentum=0.5)
        elif args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=args.lr,
                                         weight_decay=1e-4)
        trainloader = DataLoader(DatasetSplit(train_dataset,
                                              list(user_groups[n])),
                                 batch_size=5,
                                 shuffle=True)
        criterion = nn.NLLLoss().to(device)
        train_epoch_losses = []
        for epoch in range(args.Communication_private_epoch):
            train_batch_losses = []
            for batch_idx, (images, labels) in enumerate(trainloader):
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                if batch_idx % 5 == 0:
                    print(
                        'Local Model {} Type {} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'
                        .format(n, model_type_list[n], epoch + 1,
                                batch_idx * len(images),
                                len(trainloader.dataset),
                                100. * batch_idx / len(trainloader),
                                loss.item()))
                train_batch_losses.append(loss.item())
            loss_avg = sum(train_batch_losses) / len(train_batch_losses)
            train_epoch_losses.append(loss_avg)
        torch.save(
            model.state_dict(),
            'Src/CollaborativeModel/LocalModel{}Type{}.pkl'.format(
                n, model_type_list[n]))
        private_model_private_dataset_train_losses.append(train_epoch_losses)
    plt.figure()
    for i, val in enumerate(private_model_private_dataset_train_losses):
        print(val)
        plt.plot(range(len(val)), val)
    plt.title('collaborative_private_model_private_dataset_train_losses')
    plt.xlabel('epochs')
    plt.ylabel('Train loss')
    plt.savefig(
        'Src/Figure/collaborative_private_model_private_dataset_train_losses.png'
    )
    plt.show()
    print('End Private Training')
Example #30
                        required=True)
    parser.add_argument('-s',
                        '--save',
                        help='Path to save models and stats',
                        default=None,
                        required=True)
    args = parser.parse_args()

    data_path = args.data
    save_path = args.save
    train_mode = args.mode
    num_epochs = args.epoch
    batch_size = args.batch

    start_time = str(int(time.time()))
    list_of_models = get_model_list()
    models_to_test = ['alexnet']

    use_gpu = torch.cuda.is_available()
    plot_colors = get_palet(len(models_to_test))
    accfinal = 0

    stats_file = save_path + '/' + start_time + '_' + os.path.basename(
        data_path) + '_' + args.mode + '_stats.csv'

    number_classes = len(glob.glob(data_path + '/train/*'))

    stats = []

    print('Data:', data_path)
    print('Mode:', train_mode)
Example #31
def continue_train_models(args):
    device = 'cuda' if args.gpu else 'cpu'
    # Set up the initial models
    train_dataset, test_dataset = get_public_dataset(args)
    models = {
        "2_layer_CNN": CNN_2layer_fc_model,  # 字典的函数类型
        "3_layer_CNN": CNN_3layer_fc_model
    }
    modelsindex = ["2_layer_CNN", "3_layer_CNN"]
    model_list, model_type_list = get_model_list(args.initialurl, modelsindex,
                                                 models)

    private_model_public_dataset_train_losses = []
    for n, model in enumerate(model_list):
        print('Continue training Local Model {}'.format(n))
        model.to(device)
        model.train()
        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=args.lr,
                                        momentum=0.5)
        elif args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=args.lr,
                                         weight_decay=1e-4)
        trainloader = DataLoader(train_dataset, batch_size=512, shuffle=True)
        criterion = nn.NLLLoss().to(device)
        train_epoch_losses = []
        print('Begin Public Training')
        for epoch in range(args.continue_epoch):
            train_batch_losses = []
            for batch_idx, (images, labels) in enumerate(trainloader):
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                if batch_idx % 50 == 0:
                    print(
                        'Local Model {} Type {} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'
                        .format(n, model_type_list[n], epoch + 1,
                                batch_idx * len(images),
                                len(trainloader.dataset),
                                100. * batch_idx / len(trainloader),
                                loss.item()))
                train_batch_losses.append(loss.item())
            loss_avg = sum(train_batch_losses) / len(train_batch_losses)
            train_epoch_losses.append(loss_avg)

        torch.save(
            model.state_dict(),
            'Src/Model/LocalModel{}Type{}Epoch{}.pkl'.format(
                n, model_type_list[n], args.epoch))
        private_model_public_dataset_train_losses.append(train_epoch_losses)

    plt.figure()
    for i, val in enumerate(private_model_public_dataset_train_losses):
        print(val)
        plt.plot(range(len(val)), val)
    plt.xlabel('epochs')
    plt.ylabel('Train loss')
    plt.savefig(
        'Src/Figure/private_model_public_dataset_train_continue_losses.png')
    plt.show()
    print('End Public Training')