# Module-level imports assumed for this snippet; `data` is the project-local
# module that provides DoubleDataset (not part of torch).
import numpy as np
from torch.utils.data import DataLoader, Subset

import data  # project-local module providing DoubleDataset

def train(self, train_set, valid_set=None, ref_set=None):
    print('==> Start training {}'.format(self.reconstruction_path))
    self.train_flag = True
    if self.early_stop:
        valid_loader = DataLoader(valid_set,
                                  batch_size=self.test_batch_size,
                                  shuffle=True,
                                  num_workers=2)
    for epoch in range(self.start_epoch, self.start_epoch + self.epochs):
        # Reshuffle the reference set each epoch, then pair it with the
        # training set so every batch yields (train, reference) samples.
        permuted_idx = np.random.permutation(len(ref_set))
        shuffled_ref_set = Subset(ref_set, permuted_idx)
        train_ref_set = data.DoubleDataset(train_set, shuffled_ref_set)
        train_ref_loader = DataLoader(train_ref_set,
                                      batch_size=self.train_batch_size,
                                      shuffle=True,
                                      num_workers=2)
        if self.train_flag:
            self.train_epoch(train_ref_loader, epoch)
            if self.use_scheduler:
                self.scheduler_enc.step()
                self.scheduler_dec.step()
            if self.early_stop:
                self.inference(valid_loader, epoch, type='valid')
        else:
            break
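Both examples rely on data.DoubleDataset, which is not a standard torch.utils.data class. A minimal sketch of a paired dataset consistent with how it is used above, assuming it simply returns one sample from each underlying dataset per index (the class body here is an illustration, not the project's actual implementation):

from torch.utils.data import Dataset

class DoubleDataset(Dataset):
    """Pairs two datasets so each item is (first_sample, second_sample)."""

    def __init__(self, first_set, second_set):
        self.first_set = first_set
        self.second_set = second_set

    def __len__(self):
        # Only iterate over indices valid for both datasets.
        return min(len(self.first_set), len(self.second_set))

    def __getitem__(self, idx):
        # e.g. a (train_sample, reference_sample) pair in the training loop above.
        return self.first_set[idx], self.second_set[idx]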
Example #2
    def train(self, train_set, valid_set=None, ref_set=None):
        print('==> Start training {}'.format(self.reconstruction_path))

        if self.resume:
            print('==> Resuming from checkpoint..')
            try:
                last_epoch = int(np.load(os.path.join(self.reconstruction_path, 'last_epoch.npy')))
                # print(last_epoch)
                self.load(last_epoch)
            except FileNotFoundError:
                print('No pre-trained model found; training from scratch')

        self.train_flag = True
        if self.early_stop:
            valid_loader = DataLoader(valid_set, batch_size=self.test_batch_size, shuffle=True, num_workers=2)

        # if ref_set is not None:
            # ref_set = self.transform(ref_set)
        for epoch in range(self.start_epoch, self.start_epoch + self.epochs):
            # Reshuffle the reference set each epoch and pair it with the training set.
            permuted_idx = np.random.permutation(len(ref_set))
            shuffled_ref_set = Subset(ref_set, permuted_idx)
            train_ref_set = data.DoubleDataset(train_set, shuffled_ref_set)
            train_ref_loader = DataLoader(train_ref_set, batch_size=self.train_batch_size, shuffle=True, num_workers=2)
            if self.train_flag:
                self.train_epoch(train_ref_loader, epoch)

                if self.early_stop:
                    self.inference(valid_loader, epoch, inference_type='valid')
                else:
                    # Step the LR schedulers only once disentanglement training has started.
                    if epoch >= self.disentanglement_start_epoch:
                        for encoder_name in self.encoder_name_list:
                            self.encoders_opt_scheduler[encoder_name].step()
                            self.class_discs_opt_scheduler[encoder_name].step()
                            self.membership_discs_opt_scheduler[encoder_name].step()
                        self.decoder_opt_scheduler.step()
                        if self.weights['real_fake'] > 0:
                            self.rf_disc_opt_scheduler.step()

                    if (epoch+1) % self.save_step_size == 0:
                        print('Save at {}'.format(epoch+1))

                        state = {
                            # 'best_valid_loss': loss,
                            'epoch': epoch,
                        }

                        state['dec'] = self.decoder.state_dict()
                        for encoder_name in self.encoder_name_list:
                            state['enc_' + encoder_name] = self.encoders[encoder_name].state_dict()
                            state['class_disc_' + encoder_name] = self.class_discs[encoder_name].state_dict()
                            state['membership_disc_' + encoder_name] = self.membership_discs[encoder_name].state_dict()

                        torch.save(state, os.path.join(self.reconstruction_path, 'ckpt{:03d}.pth'.format(epoch+1)))
                        # self.best_valid_loss = loss
                        # self.early_stop_count = 0
                        # self.best_class_acc_dict = self.class_acc_dict
                        # self.best_membership_acc_dict = self.membership_acc_dict

                        np.save(os.path.join(self.reconstruction_path, 'class_acc{:03d}.npy'.format(epoch+1)), self.class_acc_dict)
                        np.save(os.path.join(self.reconstruction_path, 'membership_acc{:03d}.npy'.format(epoch+1)), self.membership_acc_dict)
                        np.save(os.path.join(self.reconstruction_path, 'last_epoch.npy'), epoch)
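                        # Note: np.save pickles the accuracy dicts above as 0-d object arrays;
                        # read them back with np.load(path, allow_pickle=True).item().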
                        # vutils.save_image(recons, os.path.join(self.reconstruction_path, '{}.png'.format(epoch)), nrow=10)


                # if self.early_stop:
                #     val_loss = self.inference(valid_loader, epoch, inference_type='valid')
                #     for encoder_name in self.encoder_name_list:
                #         self.encoders_opt_scheduler[encoder_name].step(val_loss)
                #         self.class_discs_opt_scheduler[encoder_name].step(val_loss)
                #         self.membership_discs_opt_scheduler[encoder_name].step(val_loss)
                #     self.decoder_opt_scheduler.step(val_loss)
                #     self.rf_disc_opt_scheduler.step(val_loss)

            else:
                break
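The resume branch above calls self.load(last_epoch), which neither example shows. A minimal sketch of a load method consistent with the checkpoint file names and state keys written by the save block (the 'ckpt{:03d}.pth' naming, the +1 offset, and the start_epoch update are inferences from the code above; everything else is an assumption, and os/torch are assumed imported at module level):

    def load(self, epoch):
        # Hypothetical counterpart to the save block above: the checkpoint was
        # written as 'ckpt{:03d}.pth'.format(epoch + 1) while last_epoch.npy
        # stored `epoch`, so re-open the file for epoch + 1.
        ckpt_path = os.path.join(self.reconstruction_path, 'ckpt{:03d}.pth'.format(int(epoch) + 1))
        state = torch.load(ckpt_path)

        self.decoder.load_state_dict(state['dec'])
        for encoder_name in self.encoder_name_list:
            self.encoders[encoder_name].load_state_dict(state['enc_' + encoder_name])
            self.class_discs[encoder_name].load_state_dict(state['class_disc_' + encoder_name])
            self.membership_discs[encoder_name].load_state_dict(state['membership_disc_' + encoder_name])

        # Resume training from the epoch after the one stored in the checkpoint.
        self.start_epoch = state['epoch'] + 1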