def train(self, train_set, valid_set=None, ref_set=None):
    """Run the training loop over (train_set, ref_set) pairs.

    Each epoch pairs ``train_set`` with a freshly shuffled view of
    ``ref_set`` via ``data.DoubleDataset`` (presumably index-aligned
    pairing -- TODO confirm), runs ``self.train_epoch`` on a shuffled
    DataLoader, steps ``self.scheduler_enc`` / ``self.scheduler_dec`` when
    ``self.use_scheduler`` is set, and, when ``self.early_stop`` is set,
    runs ``self.inference`` on ``valid_set``.

    The loop runs epochs ``self.start_epoch`` .. ``self.start_epoch +
    self.epochs - 1`` and stops as soon as ``self.train_flag`` is False.
    The flag is set True here; presumably ``self.inference`` clears it
    when early stopping triggers -- TODO confirm.

    Args:
        train_set: Training dataset.
        valid_set: Validation dataset; only used when ``self.early_stop``.
        ref_set: Reference dataset paired with ``train_set``. Required
            despite the ``None`` default (kept for call compatibility).

    Raises:
        ValueError: If ``ref_set`` is None.
    """
    if ref_set is None:
        raise ValueError('ref_set is required for training')

    print('==> Start training {}'.format(self.reconstruction_path))
    self.train_flag = True

    valid_loader = None
    if self.early_stop:
        valid_loader = DataLoader(valid_set, batch_size=self.test_batch_size,
                                  shuffle=True, num_workers=2)

    for epoch in range(self.start_epoch, self.start_epoch + self.epochs):
        # Check first so no permutation / loader is built once training stopped.
        if not self.train_flag:
            break

        # Re-pair train and reference samples every epoch. Permute the
        # ORIGINAL ref_set: wrapping last epoch's Subset again would nest one
        # Subset level (and one index array) per epoch, making every
        # reference lookup cost O(epochs).
        epoch_ref_set = Subset(ref_set, np.random.permutation(len(ref_set)))
        train_ref_set = data.DoubleDataset(train_set, epoch_ref_set)
        train_ref_loader = DataLoader(train_ref_set, batch_size=self.train_batch_size,
                                      shuffle=True, num_workers=2)

        self.train_epoch(train_ref_loader, epoch)
        if self.use_scheduler:
            self.scheduler_enc.step()
            self.scheduler_dec.step()
        if self.early_stop:
            self.inference(valid_loader, epoch, type='valid')
def train(self, train_set, valid_set=None, ref_set=None):
    """Run the training loop, optionally resuming from the last saved epoch.

    When ``self.resume`` is set, ``last_epoch.npy`` is read from
    ``self.reconstruction_path`` and passed to ``self.load``. If a
    FileNotFoundError is raised, training starts from scratch.

    Each epoch pairs ``train_set`` with a freshly shuffled view of
    ``ref_set`` via ``data.DoubleDataset`` (presumably index-aligned
    pairing -- TODO confirm) and runs ``self.train_epoch``. Then:

    * with ``self.early_stop``: ``self.inference`` runs on ``valid_set``;
    * otherwise:
      * the optimizer schedulers are stepped once the epoch index reaches
        ``self.disentanglement_start_epoch``;
      * every ``self.save_step_size`` epochs, a checkpoint
        (``ckptNNN.pth``) is written to ``self.reconstruction_path``,
        together with the class/membership accuracy dicts and
        ``last_epoch.npy``.

    The loop stops as soon as ``self.train_flag`` is False. It is set True
    here; presumably ``self.inference`` clears it on early stop -- TODO
    confirm.

    Args:
        train_set: Training dataset.
        valid_set: Validation dataset; only used when ``self.early_stop``.
        ref_set: Reference dataset paired with ``train_set``. Required
            despite the ``None`` default (kept for call compatibility).

    Raises:
        ValueError: If ``ref_set`` is None.
    """
    if ref_set is None:
        raise ValueError('ref_set is required for training')

    print('==> Start training {}'.format(self.reconstruction_path))
    if self.resume:
        print('==> Resuming from checkpoint..')
        try:
            last_epoch = np.load(os.path.join(self.reconstruction_path, 'last_epoch.npy'))
            self.load(last_epoch)
        except FileNotFoundError:
            print('There is no pre-trained model; Train model from scratch')

    self.train_flag = True

    valid_loader = None
    if self.early_stop:
        valid_loader = DataLoader(valid_set, batch_size=self.test_batch_size,
                                  shuffle=True, num_workers=2)

    for epoch in range(self.start_epoch, self.start_epoch + self.epochs):
        # Check first so no permutation / loader is built once training stopped.
        if not self.train_flag:
            break

        # Re-pair train and reference samples every epoch. Permute the
        # ORIGINAL ref_set: wrapping last epoch's Subset again would nest one
        # Subset level (and one index array) per epoch, making every
        # reference lookup cost O(epochs).
        epoch_ref_set = Subset(ref_set, np.random.permutation(len(ref_set)))
        train_ref_set = data.DoubleDataset(train_set, epoch_ref_set)
        train_ref_loader = DataLoader(train_ref_set, batch_size=self.train_batch_size,
                                      shuffle=True, num_workers=2)

        self.train_epoch(train_ref_loader, epoch)

        if self.early_stop:
            self.inference(valid_loader, epoch, inference_type='valid')
            continue

        # Schedulers only advance from epoch index disentanglement_start_epoch on.
        if epoch > self.disentanglement_start_epoch - 1:
            for encoder_name in self.encoder_name_list:
                self.encoders_opt_scheduler[encoder_name].step()
                self.class_discs_opt_scheduler[encoder_name].step()
                self.membership_discs_opt_scheduler[encoder_name].step()
            # There is a single decoder (and real/fake discriminator)
            # scheduler, so step it once per epoch, not once per encoder.
            self.decoder_opt_scheduler.step()
            if self.weights['real_fake'] > 0:
                self.rf_disc_opt_scheduler.step()

        if (epoch + 1) % self.save_step_size == 0:
            print('Save at {}'.format(epoch + 1))
            # The decoder is shared across encoders: store its weights once,
            # even when encoder_name_list is empty.
            state = {
                'epoch': epoch,
                'dec': self.decoder.state_dict(),
            }
            for encoder_name in self.encoder_name_list:
                state['enc_' + encoder_name] = self.encoders[encoder_name].state_dict()
                state['class_disc_' + encoder_name] = self.class_discs[encoder_name].state_dict()
                state['membership_disc_' + encoder_name] = self.membership_discs[encoder_name].state_dict()
            torch.save(state, os.path.join(self.reconstruction_path,
                                           'ckpt{:03d}.pth'.format(epoch + 1)))
            # np.save on a dict stores a pickled 0-d object array; loading it
            # back requires np.load(..., allow_pickle=True).
            np.save(os.path.join(self.reconstruction_path, 'class_acc{:03d}.npy'.format(epoch + 1)),
                    self.class_acc_dict)
            np.save(os.path.join(self.reconstruction_path, 'membership_acc{:03d}.npy'.format(epoch + 1)),
                    self.membership_acc_dict)
            # Read back by the resume branch above.
            np.save(os.path.join(self.reconstruction_path, 'last_epoch.npy'), epoch)