def __init__(self, config): self.config = config self.flag_gan = False self.train_count = 0 self.pretraining_step_size = self.config.pretraining_step_size self.batch_size = self.config.batch_size self.logger = self.set_logger() # define dataloader self.dataset = NoteDataset(self.config.root_path, self.config) self.dataloader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=False, num_workers=1, pin_memory=self.config.pin_memory, collate_fn=self.make_batch) # define models ( generator and discriminator) self.generator = Model() self.discriminator = BarDiscriminator() self.discriminator_feature = BarFeatureDiscriminator() self.z_discriminator_phrase = PhraseZDiscriminator() self.z_discriminator_bar = BarZDiscriminator() # define loss self.loss_generator = Loss().cuda() self.loss_disc = DLoss().cuda() self.loss_feature_disc = DLoss().cuda() self.loss_bar = DLoss().cuda() self.loss_phrase = DLoss().cuda() # define lr self.lr_generator = self.config.learning_rate self.lr_discriminator = self.config.learning_rate self.lr_discriminator_feature = self.config.learning_rate self.lr_Zdiscriminator_bar = self.config.learning_rate self.lr_Zdiscriminator_phrase = self.config.learning_rate # define optimizer self.opt_generator = torch.optim.Adam(self.generator.parameters(), lr=self.lr_generator) self.opt_discriminator = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr_discriminator) self.opt_discriminator_feature = torch.optim.Adam(self.discriminator_feature.parameters(), lr=self.lr_discriminator_feature) self.opt_Zdiscriminator_bar = torch.optim.Adam(self.z_discriminator_bar.parameters(), lr=self.lr_Zdiscriminator_bar) self.opt_Zdiscriminator_phrase = torch.optim.Adam(self.z_discriminator_phrase.parameters(), lr=self.lr_Zdiscriminator_phrase) # define optimize scheduler self.scheduler_generator = torch.optim.lr_scheduler.ReduceLROnPlateau(self.opt_generator, mode='min', factor=0.8, cooldown=6) self.scheduler_discriminator = torch.optim.lr_scheduler.ReduceLROnPlateau(self.opt_discriminator, mode='min', factor=0.8, cooldown=6) self.scheduler_discriminator_feature = torch.optim.lr_scheduler.ReduceLROnPlateau(self.opt_discriminator_feature, mode='min', factor=0.8, cooldown=6) self.scheduler_Zdiscriminator_bar = torch.optim.lr_scheduler.ReduceLROnPlateau(self.opt_Zdiscriminator_bar, mode='min', factor=0.8, cooldown=6) self.scheduler_Zdiscriminator_phrase = torch.optim.lr_scheduler.ReduceLROnPlateau(self.opt_Zdiscriminator_phrase, mode='min', factor=0.8, cooldown=6) # initialize counter self.iteration = 0 self.epoch = 0 self.manual_seed = random.randint(1, 10000) torch.manual_seed(self.manual_seed) torch.cuda.manual_seed_all(self.manual_seed) random.seed(self.manual_seed) print("seed: ", self.manual_seed) # cuda setting gpu_list = list(range(self.config.gpu_cnt)) self.generator = nn.DataParallel(self.generator, device_ids=gpu_list) self.discriminator = nn.DataParallel(self.discriminator, device_ids=gpu_list) self.discriminator_feature = nn.DataParallel(self.discriminator_feature, device_ids=gpu_list) self.z_discriminator_bar = nn.DataParallel(self.z_discriminator_bar, device_ids=gpu_list) self.z_discriminator_phrase = nn.DataParallel(self.z_discriminator_phrase, device_ids=gpu_list) self.generator = self.generator.cuda() self.discriminator = self.discriminator.cuda() self.discriminator_feature = self.discriminator_feature.cuda() self.z_discriminator_bar = self.z_discriminator_bar.cuda() self.z_discriminator_phrase = self.z_discriminator_phrase.cuda() # Model Loading from the latest checkpoint if not found start from scratch. self.load_checkpoint(self.config.checkpoint_file) # Summary Writer self.summary_writer = SummaryWriter(log_dir=os.path.join(self.config.root_path, self.config.summary_dir), comment='BarGen') print('Number of generator parameters: {}'.format(sum([p.data.nelement() for p in self.generator.parameters()]))) print('Number of discriminator parameters: {}'.format(sum([p.data.nelement() for p in self.discriminator.parameters()]))) print('Number of feature discriminator parameters: {}'.format(sum([p.data.nelement() for p in self.discriminator_feature.parameters()]))) print('Number of barZ discriminator parameters: {}'.format(sum([p.data.nelement() for p in self.z_discriminator_bar.parameters()]))) print('Number of phraseZ discriminator parameters: {}'.format(sum([p.data.nelement() for p in self.z_discriminator_phrase.parameters()])))
def __init__(self, config): hvd.init() torch.cuda.set_device(hvd.local_rank()) self.config = config self.flag_gan = False self.train_count = 0 self.pretraining_step_size = self.config.pretraining_step_size self.batch_size = self.config.batch_size self.logger = self.set_logger() # define dataloader self.dataset = NoteDataset(self.config.root_path, self.config) train_sampler = torch.utils.data.distributed.DistributedSampler( dataset, num_replicas=hvd.size(), rank=hvd.rank()) self.dataloader = DataLoader(self.dataset, batch_size=self.batch_size, num_workers=1, sampler=train_sampler, collate_fn=make_batch) # define models ( generator and discriminator) self.generator = Model() self.discriminator = BarDiscriminator() self.discriminator_feature = BarFeatureDiscriminator() self.z_discriminator_phrase = PhraseZDiscriminator() self.z_discriminator_bar = BarZDiscriminator() self.generator = self.generator.cuda() self.discriminator = self.discriminator.cuda() self.discriminator_feature = self.discriminator_feature.cuda() self.z_discriminator_bar = self.z_discriminator_bar.cuda() self.z_discriminator_phrase = self.z_discriminator_phrase.cuda() # define loss self.loss_generator = Loss().cuda() self.loss_disc = DLoss().cuda() self.loss_feature_disc = DLoss().cuda() self.loss_bar = DLoss().cuda() self.loss_phrase = DLoss().cuda() # define lr self.lr_generator = self.config.learning_rate self.lr_discriminator = self.config.learning_rate self.lr_discriminator_feature = self.config.learning_rate self.lr_Zdiscriminator_bar = self.config.learning_rate self.lr_Zdiscriminator_phrase = self.config.learning_rate # define optimizer self.opt_generator = torch.optim.Adam(self.generator.parameters(), lr=self.lr_generator) self.opt_discriminator = torch.optim.Adam( self.discriminator.parameters(), lr=self.lr_discriminator) self.opt_discriminator_feature = torch.optim.Adam( self.discriminator_feature.parameters(), lr=self.lr_discriminator_feature) self.opt_Zdiscriminator_bar = torch.optim.Adam( self.z_discriminator_bar.parameters(), lr=self.lr_Zdiscriminator_bar) self.opt_Zdiscriminator_phrase = torch.optim.Adam( self.z_discriminator_phrase.parameters(), lr=self.lr_Zdiscriminator_phrase) self.opt_generator = hvd.DistributedOptimizer( self.generator.parameters(), named_parameters=self.opt_generator) self.opt_discriminator = hvd.DistributedOptimizer( self.discriminator.parameters(), named_parameters=self.opt_discriminator) self.opt_discriminator_feature = hvd.DistributedOptimizer( self.discriminator_feature.parameters(), named_parameters=self.opt_discriminator_feature) self.opt_Zdiscriminator_bar = hvd.DistributedOptimizer( self.z_discriminator_bar.parameters(), named_parameters=self.opt_Zdiscriminator_bar) self.opt_Zdiscriminator_phrase = hvd.DistributedOptimizer( self.z_discriminator_phrase.parameters(), named_parameters=self.opt_Zdiscriminator_phrase) # define optimize scheduler self.scheduler_generator = torch.optim.lr_scheduler.ReduceLROnPlateau( self.opt_generator, mode='min', factor=0.8, cooldown=6) self.scheduler_discriminator = torch.optim.lr_scheduler.ReduceLROnPlateau( self.opt_discriminator, mode='min', factor=0.8, cooldown=6) self.scheduler_discriminator_feature = torch.optim.lr_scheduler.ReduceLROnPlateau( self.opt_discriminator_feature, mode='min', factor=0.8, cooldown=6) self.scheduler_Zdiscriminator_bar = torch.optim.lr_scheduler.ReduceLROnPlateau( self.opt_Zdiscriminator_bar, mode='min', factor=0.8, cooldown=6) self.scheduler_Zdiscriminator_phrase = torch.optim.lr_scheduler.ReduceLROnPlateau( self.opt_Zdiscriminator_phrase, mode='min', factor=0.8, cooldown=6) # initialize counter self.iteration = 0 self.epoch = 0 self.manual_seed = random.randint(1, 10000) torch.manual_seed(self.manual_seed) torch.cuda.manual_seed_all(self.manual_seed) random.seed(self.manual_seed) print("seed: ", self.manual_seed) hvd.broadcast_parameters(self.generator.state_dict(), root_rank=0) hvd.broadcast_parameters(self.discriminator.state_dict(), root_rank=0) hvd.broadcast_parameters(self.discriminator_feature.state_dict(), root_rank=0) hvd.broadcast_parameters(self.z_discriminator_bar.state_dict(), root_rank=0) hvd.broadcast_parameters(self.z_discriminator_phrase.state_dict(), root_rank=0) # Model Loading from the latest checkpoint if not found start from scratch. if hvd.rank() == 0: self.load_checkpoint(self.config.checkpoint_file) # Summary Writer self.summary_writer = SummaryWriter(log_dir=os.path.join( self.config.root_path, self.config.summary_dir), comment='BarGen') print('Number of generator parameters: {}'.format( sum([p.data.nelement() for p in self.generator.parameters()]))) print('Number of discriminator parameters: {}'.format( sum([p.data.nelement() for p in self.discriminator.parameters()]))) print('Number of feature discriminator parameters: {}'.format( sum([ p.data.nelement() for p in self.discriminator_feature.parameters() ]))) print('Number of barZ discriminator parameters: {}'.format( sum([ p.data.nelement() for p in self.z_discriminator_bar.parameters() ]))) print('Number of phraseZ discriminator parameters: {}'.format( sum([ p.data.nelement() for p in self.z_discriminator_phrase.parameters() ])))
class BarGen(object): def __init__(self, config): self.config = config self.flag_gan = False self.train_count = 0 self.pretraining_step_size = self.config.pretraining_step_size self.batch_size = self.config.batch_size self.logger = self.set_logger() # define dataloader self.dataset = NoteDataset(self.config.root_path, self.config) self.dataloader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=False, num_workers=1, pin_memory=self.config.pin_memory, collate_fn=self.make_batch) # define models ( generator and discriminator) self.generator = Model() self.discriminator = BarDiscriminator() self.discriminator_feature = BarFeatureDiscriminator() self.z_discriminator_phrase = PhraseZDiscriminator() self.z_discriminator_bar = BarZDiscriminator() # define loss self.loss_generator = Loss().cuda() self.loss_disc = DLoss().cuda() self.loss_feature_disc = DLoss().cuda() self.loss_bar = DLoss().cuda() self.loss_phrase = DLoss().cuda() # define lr self.lr_generator = self.config.learning_rate self.lr_discriminator = self.config.learning_rate self.lr_discriminator_feature = self.config.learning_rate self.lr_Zdiscriminator_bar = self.config.learning_rate self.lr_Zdiscriminator_phrase = self.config.learning_rate # define optimizer self.opt_generator = torch.optim.Adam(self.generator.parameters(), lr=self.lr_generator) self.opt_discriminator = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr_discriminator) self.opt_discriminator_feature = torch.optim.Adam(self.discriminator_feature.parameters(), lr=self.lr_discriminator_feature) self.opt_Zdiscriminator_bar = torch.optim.Adam(self.z_discriminator_bar.parameters(), lr=self.lr_Zdiscriminator_bar) self.opt_Zdiscriminator_phrase = torch.optim.Adam(self.z_discriminator_phrase.parameters(), lr=self.lr_Zdiscriminator_phrase) # define optimize scheduler self.scheduler_generator = torch.optim.lr_scheduler.ReduceLROnPlateau(self.opt_generator, mode='min', factor=0.8, cooldown=6) self.scheduler_discriminator = torch.optim.lr_scheduler.ReduceLROnPlateau(self.opt_discriminator, mode='min', factor=0.8, cooldown=6) self.scheduler_discriminator_feature = torch.optim.lr_scheduler.ReduceLROnPlateau(self.opt_discriminator_feature, mode='min', factor=0.8, cooldown=6) self.scheduler_Zdiscriminator_bar = torch.optim.lr_scheduler.ReduceLROnPlateau(self.opt_Zdiscriminator_bar, mode='min', factor=0.8, cooldown=6) self.scheduler_Zdiscriminator_phrase = torch.optim.lr_scheduler.ReduceLROnPlateau(self.opt_Zdiscriminator_phrase, mode='min', factor=0.8, cooldown=6) # initialize counter self.iteration = 0 self.epoch = 0 self.manual_seed = random.randint(1, 10000) torch.manual_seed(self.manual_seed) torch.cuda.manual_seed_all(self.manual_seed) random.seed(self.manual_seed) print("seed: ", self.manual_seed) # cuda setting gpu_list = list(range(self.config.gpu_cnt)) self.generator = nn.DataParallel(self.generator, device_ids=gpu_list) self.discriminator = nn.DataParallel(self.discriminator, device_ids=gpu_list) self.discriminator_feature = nn.DataParallel(self.discriminator_feature, device_ids=gpu_list) self.z_discriminator_bar = nn.DataParallel(self.z_discriminator_bar, device_ids=gpu_list) self.z_discriminator_phrase = nn.DataParallel(self.z_discriminator_phrase, device_ids=gpu_list) self.generator = self.generator.cuda() self.discriminator = self.discriminator.cuda() self.discriminator_feature = self.discriminator_feature.cuda() self.z_discriminator_bar = self.z_discriminator_bar.cuda() self.z_discriminator_phrase = self.z_discriminator_phrase.cuda() # Model Loading from the latest checkpoint if not found start from scratch. self.load_checkpoint(self.config.checkpoint_file) # Summary Writer self.summary_writer = SummaryWriter(log_dir=os.path.join(self.config.root_path, self.config.summary_dir), comment='BarGen') print('Number of generator parameters: {}'.format(sum([p.data.nelement() for p in self.generator.parameters()]))) print('Number of discriminator parameters: {}'.format(sum([p.data.nelement() for p in self.discriminator.parameters()]))) print('Number of feature discriminator parameters: {}'.format(sum([p.data.nelement() for p in self.discriminator_feature.parameters()]))) print('Number of barZ discriminator parameters: {}'.format(sum([p.data.nelement() for p in self.z_discriminator_bar.parameters()]))) print('Number of phraseZ discriminator parameters: {}'.format(sum([p.data.nelement() for p in self.z_discriminator_phrase.parameters()]))) def get_lr(self, optimizer): for param_group in optimizer.param_groups: return param_group['lr'] def set_logger(self): logger = logging.getLogger() logger.setLevel(logging.DEBUG) formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") file_handler = logging.FileHandler(filename="train_epoch.log") file_handler.setLevel(logging.WARNING) file_handler.setFormatter(formatter) logger.addHandler(file_handler) return logger def make_batch(self, samples): note = np.concatenate([sample['note'] for sample in samples], axis=0) pre_note = np.concatenate([sample['pre_note'] for sample in samples], axis=0) pre_phrase = np.concatenate([sample['pre_phrase'] for sample in samples], axis=0) position = np.concatenate([sample['position'] for sample in samples], axis=0) return tuple([torch.tensor(note, dtype=torch.float), torch.tensor(pre_note, dtype=torch.float), torch.tensor(pre_phrase, dtype=torch.float), torch.tensor(position, dtype=torch.long)]) def free(self, module: nn.Module): for p in module.parameters(): p.requires_grad = True def frozen(self, module: nn.Module): for p in module.parameters(): p.requires_grad = False def load_checkpoint(self, file_name): filename = os.path.join(self.config.root_path, self.config.checkpoint_dir, file_name) try: print("Loading checkpoint '{}'".format(filename)) checkpoint = torch.load(filename) self.generator.load_state_dict(checkpoint['generator_state_dict']) self.opt_generator.load_state_dict(checkpoint['generator_optimizer']) self.discriminator.load_state_dict(checkpoint['discriminator_state_dict']) self.opt_discriminator.load_state_dict(checkpoint['disc_optimizer']) self.discriminator_feature.load_state_dict(checkpoint['discriminator_feature_state_dict']) self.opt_discriminator_feature.load_state_dict(checkpoint['disc_feature_optimizer']) self.z_discriminator_bar.load_state_dict(checkpoint['z_discriminator_bar_state_dict']) self.opt_Zdiscriminator_bar.load_state_dict(checkpoint['opt_Zdiscriminator_bar_optimizer']) self.z_discriminator_phrase.load_state_dict(checkpoint['z_discriminator_phrase_state_dict']) self.opt_Zdiscriminator_phrase.load_state_dict(checkpoint['opt_Zdiscriminator_phrase_optimizer']) except OSError as e: print("No checkpoint exists from '{}'. Skipping...".format(self.config.checkpoint_dir)) print("**First time to train**") def save_checkpoint(self, file_name, epoch): tmp_name = os.path.join(self.config.root_path, self.config.checkpoint_dir, 'checkpoint_{}.pth.tar'.format(epoch)) # file_name = os.path.join(self.config.root_path, self.config.checkpoint_dir, file_name) state = { 'generator_state_dict': self.generator.state_dict(), 'generator_optimizer': self.opt_generator.state_dict(), 'discriminator_state_dict': self.discriminator.state_dict(), 'disc_optimizer': self.opt_discriminator.state_dict(), 'discriminator_feature_state_dict': self.discriminator_feature.state_dict(), 'disc_feature_optimizer': self.opt_discriminator_feature.state_dict(), 'z_discriminator_bar_state_dict': self.z_discriminator_bar.state_dict(), 'opt_Zdiscriminator_bar_optimizer': self.opt_Zdiscriminator_bar.state_dict(), 'z_discriminator_phrase_state_dict': self.z_discriminator_phrase.state_dict(), 'opt_Zdiscriminator_phrase_optimizer': self.opt_Zdiscriminator_phrase.state_dict(), } torch.save(state, tmp_name) shutil.copyfile(tmp_name, os.path.join(self.config.root_path, self.config.checkpoint_dir, 'checkpoint.pth.tar')) def run(self): try: self.train() except KeyboardInterrupt: print("You have entered CTRL+C.. Wait to finalize") def record_image(self, image_sample, origin_image, outputs): self.summary_writer.add_image("train/sample 1", image_sample[0].reshape(1, 96, 60), self.epoch) self.summary_writer.add_image("train/sample 2", image_sample[1].reshape(1, 96, 60), self.epoch) self.summary_writer.add_image("train/sample 3", image_sample[2].reshape(1, 96, 60), self.epoch) image_sample = torch.gt(image_sample, 0.3).type('torch.cuda.FloatTensor') self.summary_writer.add_image("train/sample_binarization 1", image_sample[0].reshape(1, 96, 60), self.epoch) self.summary_writer.add_image("train/sample_binarization 2", image_sample[1].reshape(1, 96, 60), self.epoch) self.summary_writer.add_image("train/sample_binarization 3", image_sample[2].reshape(1, 96, 60), self.epoch) self.summary_writer.add_image("train/origin 1", origin_image[0].reshape(1, 96, 60), self.epoch) self.summary_writer.add_image("train/origin 2", origin_image[1].reshape(1, 96, 60), self.epoch) self.summary_writer.add_image("train/origin 3", origin_image[2].reshape(1, 96, 60), self.epoch) self.summary_writer.add_image("eval/generated 1", outputs[0].reshape(1, 96 * 4, 60), self.epoch) self.summary_writer.add_image("eval/generated 2", outputs[1].reshape(1, 96 * 4, 60), self.epoch) def train(self): for _ in range(self.config.epoch): self.epoch += 1 self.train_epoch() if self.epoch > self.pretraining_step_size + 20: self.save_checkpoint(self.config.checkpoint_file, self.epoch) def train_epoch(self): if self.epoch > self.pretraining_step_size: self.train_count += 1 tqdm_batch = tqdm(self.dataloader, total=self.dataset.num_iterations, desc="epoch-{}".format(self.epoch)) image_sample = None origin_image = None Tensor = torch.cuda.FloatTensor avg_generator_loss = AverageMeter() avg_discriminator_loss = AverageMeter() avg_feature_discriminator_loss = AverageMeter() avg_barZ_disc_loss = AverageMeter() avg_phraseZ_disc_loss = AverageMeter() for curr_it, (note, pre_note, pre_phrase, position) in enumerate(tqdm_batch): self.iteration += 1 note = note.cuda(async=self.config.async_loading) pre_note = pre_note.cuda(async=self.config.async_loading) pre_phrase = pre_phrase.cuda(async=self.config.async_loading) position = position.cuda(async=self.config.async_loading) note = Variable(note) pre_note = Variable(pre_note) pre_phrase = Variable(pre_phrase) position = Variable(position) origin_image = note valid_target = Variable(Tensor(note.size(0)).fill_(1.0), requires_grad=False) fake_target = Variable(Tensor(note.size(0)).fill_(0.0), requires_grad=False) if self.epoch <= self.pretraining_step_size: image_sample = self.train_pretrain(note, pre_note, pre_phrase, position, avg_generator_loss) else: if self.flag_gan: image_sample = self.train_gan(note, pre_note, pre_phrase, position, avg_generator_loss, avg_discriminator_loss, avg_feature_discriminator_loss, fake_target, valid_target, curr_it) else: image_sample = self.train_wae(note, pre_note, pre_phrase, position, avg_generator_loss, avg_barZ_disc_loss, avg_phraseZ_disc_loss, fake_target, valid_target, curr_it) tqdm_batch.close() if self.flag_gan and self.train_count >= 100: self.flag_gan = not self.flag_gan self.train_count = 0 elif not self.flag_gan and self.train_count >= 50: self.flag_gan = not self.flag_gan self.train_count = 0 with torch.no_grad(): self.generator.eval() self.discriminator.eval() self.discriminator_feature.eval() self.z_discriminator_bar.eval() self.z_discriminator_phrase.eval() outputs = [] pre_phrase = torch.zeros(1, 1, 384, 60, dtype=torch.float32) pre_bar = torch.zeros(1, 1, 96, 60, dtype=torch.float32) phrase_idx = [330] + [i for i in range(10 - 2, -1, -1)] for idx in range(10): bar_set = [] for _ in range(4): pre_bar, _ = self.generator(torch.randn(1, 1152, dtype=torch.float32).cuda(), pre_bar.cuda(), pre_phrase, torch.from_numpy(np.array([phrase_idx[idx]])), False) pre_bar = torch.gt(pre_bar, 0.3).type('torch.FloatTensor') # 1, 1, 96, 96 bar_set.append(np.reshape(pre_bar.numpy(), [96, 60])) pre_phrase = np.concatenate(bar_set, axis=0) outputs.append(pre_phrase) pre_phrase = torch.from_numpy(np.reshape(pre_phrase, [1, 1, 96 * 4, 60])).float().cuda() self.record_image(image_sample[:3], origin_image[:3], outputs) self.scheduler_generator.step(avg_generator_loss.val) if self.epoch > self.pretraining_step_size: self.scheduler_discriminator.step(avg_discriminator_loss.val) self.scheduler_discriminator_feature.step(avg_feature_discriminator_loss.val) self.scheduler_Zdiscriminator_bar.step(avg_barZ_disc_loss.val) self.scheduler_Zdiscriminator_phrase.step(avg_phraseZ_disc_loss.val) self.logger.warning( 'loss info - gen: {}, barZ disc: {}, phraseZ disc: {}, bar disc: {}, bar_seq disc: {}'.format( avg_generator_loss.val, avg_barZ_disc_loss.val, avg_phraseZ_disc_loss.val, avg_feature_discriminator_loss.val, avg_discriminator_loss.val) ) self.logger.warning( 'lr info - gen: {}, barZ disc: {}, phraseZ disc: {}, bar disc: {}, bar_seq disc: {}'.format( self.get_lr(self.opt_generator), self.get_lr(self.opt_Zdiscriminator_bar), self.get_lr(self.opt_Zdiscriminator_phrase), self.get_lr(self.opt_discriminator_feature), self.get_lr(self.opt_discriminator)) ) def train_pretrain(self, note, pre_note, pre_phrase, position, avg_generator_loss): self.generator.train() self.discriminator.eval() self.discriminator_feature.eval() self.z_discriminator_bar.eval() self.z_discriminator_phrase.eval() self.generator.zero_grad() #################### Generator #################### self.free(self.generator) self.frozen(self.discriminator) self.frozen(self.discriminator_feature) self.frozen(self.z_discriminator_bar) self.frozen(self.z_discriminator_phrase) gen_note, z, pre_z, phrase_feature, _ = self.generator(note, pre_note, pre_phrase, position) #### Generator Loss #### loss = self.loss_generator(gen_note, note, True) loss.backward() self.opt_generator.step() avg_generator_loss.update(loss) self.summary_writer.add_scalar("pre-train/Generator_loss", avg_generator_loss.val, self.epoch) return gen_note[:3] def train_wae(self, note, pre_note, pre_phrase, position, avg_generator_loss, avg_barZ_disc_loss, avg_phraseZ_disc_loss, fake_target, valid_target, curr_it): self.generator.train() self.z_discriminator_bar.train() self.z_discriminator_phrase.train() self.discriminator.eval() self.discriminator_feature.eval() self.generator.zero_grad() self.z_discriminator_bar.zero_grad() self.z_discriminator_phrase.zero_grad() if (self.epoch + curr_it) % 2: #################### Discriminator #################### self.free(self.z_discriminator_bar) self.free(self.z_discriminator_phrase) self.frozen(self.generator) self.frozen(self.discriminator) self.frozen(self.discriminator_feature) _, z, pre_z, phrase_feature, _ = self.generator(note, pre_note, pre_phrase, position) #### Phrase Feature ### phrase_fake = (torch.randn(phrase_feature.size(0), phrase_feature.size(1)) * self.config.sigma).cuda() d_phrase_fake = self.z_discriminator_phrase(phrase_fake).view(-1) d_phrase_real = self.z_discriminator_phrase(phrase_feature).view(-1) phraseZ_dics_loss = self.loss_phrase(d_phrase_real, fake_target) + \ self.loss_phrase(d_phrase_fake, valid_target) #### Bar Feature #### bar_fake = (torch.randn(z.size(0), z.size(1)) * self.config.sigma).cuda() d_bar_fake = self.z_discriminator_bar(bar_fake).view(-1) d_bar_real = self.z_discriminator_bar(z).view(-1) barZ_dics_loss = self.loss_bar(d_bar_real, fake_target) + self.loss_bar(d_bar_fake, valid_target) ####################### phraseZ_dics_loss.backward() barZ_dics_loss.backward() self.opt_Zdiscriminator_bar.step() self.opt_Zdiscriminator_phrase.step() avg_barZ_disc_loss.update(barZ_dics_loss) avg_phraseZ_disc_loss.update(phraseZ_dics_loss) #################### Generator #################### self.free(self.generator) self.frozen(self.z_discriminator_bar) self.frozen(self.z_discriminator_phrase) self.frozen(self.discriminator) self.frozen(self.discriminator_feature) gen_note, z, pre_z, phrase_feature, _ = self.generator(note, pre_note, pre_phrase, position) #### Phrase Encoder Loss ### d_phrase_real = self.z_discriminator_phrase(phrase_feature).view(-1) loss = self.loss_phrase(d_phrase_real, valid_target) #### Bar Encoder Loss #### d_bar_real1 = self.z_discriminator_bar(z).view(-1) d_bar_real2 = self.z_discriminator_bar(pre_z).view(-1) loss += self.loss_bar(d_bar_real1, valid_target) + self.loss_bar(d_bar_real2, valid_target) #### Generator Loss #### loss += self.loss_generator(gen_note, note, False) loss.backward() self.opt_generator.step() avg_generator_loss.update(loss) self.summary_writer.add_scalar("train-wae/Generator_loss", avg_generator_loss.val, self.epoch) self.summary_writer.add_scalar("train-wae/Bar_Z_Discriminator_loss", avg_barZ_disc_loss.val, self.epoch) self.summary_writer.add_scalar("train-wae/Phrase_Z_discriminator_loss", avg_phraseZ_disc_loss.val, self.epoch) return gen_note[:3] def train_gan(self, note, pre_note, pre_phrase, position, avg_generator_loss, avg_discriminator_loss, avg_feature_discriminator_loss, fake_target, valid_target, curr_it): self.generator.train() self.discriminator.train() self.discriminator_feature.train() self.z_discriminator_bar.eval() self.z_discriminator_phrase.eval() self.generator.zero_grad() self.discriminator.zero_grad() self.discriminator_feature.zero_grad() if (self.epoch + curr_it) % 2: #################### Discriminator #################### self.free(self.discriminator) self.free(self.discriminator_feature) self.frozen(self.generator) self.frozen(self.z_discriminator_bar) self.frozen(self.z_discriminator_phrase) gen_note, z, pre_z, phrase_feature, gen_z = self.generator(note, pre_note, pre_phrase, position) #### Note #### fake_note = torch.cat((pre_note, gen_note), dim=2) real_note = torch.cat((pre_note, note), dim=2) d_note_fake = self.discriminator(fake_note).view(-1) d_note_real = self.discriminator(real_note).view(-1) note_disc_loss = self.loss_disc(d_note_real, fake_target) + self.loss_disc(d_note_fake, valid_target) #### Note Feature #### d_feature_fake = self.discriminator_feature(gen_z).view(-1) d_feature_real = self.discriminator_feature(z).view(-1) feature_disc_loss = self.loss_disc(d_feature_real, fake_target) + \ self.loss_disc(d_feature_fake, valid_target) ####################### note_disc_loss.backward() feature_disc_loss.backward() self.opt_discriminator.step() self.opt_discriminator_feature.step() avg_discriminator_loss.update(note_disc_loss) avg_feature_discriminator_loss.update(feature_disc_loss) #################### Generator #################### self.free(self.generator) self.frozen(self.discriminator) self.frozen(self.discriminator_feature) self.frozen(self.z_discriminator_bar) self.frozen(self.z_discriminator_phrase) noise = torch.normal(0, 1.5, size=(note.size(0), 1152)) gen_note, gen_z = self.generator(noise, pre_note, pre_phrase, position, False) #### Gan Loss #### d_note_fake = self.discriminator(torch.cat((pre_note, gen_note), dim=2)).view(-1) loss = self.loss_disc(d_note_fake, valid_target) d_feature_fake = self.discriminator_feature(gen_z).view(-1) loss += self.loss_disc(d_feature_fake, valid_target) loss.backward() self.opt_generator.step() avg_generator_loss.update(loss) self.summary_writer.add_scalar("train-gan/Generator_loss", avg_generator_loss.val, self.epoch) self.summary_writer.add_scalar("train-gan/Bar_Discriminator_loss", avg_discriminator_loss.val, self.epoch) self.summary_writer.add_scalar("train-gan/Bar_Feature_discriminator_loss", avg_feature_discriminator_loss.val, self.epoch) return gen_note[:3]
class BarGen(object): def __init__(self, config): self.config = config self.pretraining_step_size = self.config.pretraining_step_size self.batch_size = self.config.batch_size self.logger = self.set_logger() # define dataloader self.dataset = NoteDataset(self.config.root_path, self.config) self.dataloader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=False, num_workers=1, pin_memory=self.config.pin_memory, collate_fn=self.make_batch) # define models ( generator and discriminator) self.generator = Model() self.discriminator = BarDiscriminator() self.z_discriminator_phrase = PhraseZDiscriminator() self.z_discriminator_bar = BarZDiscriminator() # define loss self.loss_gen = Loss().cuda() self.loss_disc = DLoss().cuda() # define lr self.lr_gen1 = self.config.learning_rate self.lr_gen2 = self.config.learning_rate self.lr_discriminator = self.config.learning_rate self.lr_Zdiscriminator_bar = self.config.learning_rate self.lr_Zdiscriminator_phrase = self.config.learning_rate # define optimizer self.opt_gen1 = torch.optim.Adam(self.generator.parameters(), lr=self.lr_gen1) self.opt_gen2 = torch.optim.Adam(self.generator.parameters(), lr=self.lr_gen2) self.opt_discriminator = torch.optim.Adam( self.discriminator.parameters(), lr=self.lr_discriminator) self.opt_Zdiscriminator_bar = torch.optim.Adam( self.z_discriminator_bar.parameters(), lr=self.lr_Zdiscriminator_bar) self.opt_Zdiscriminator_phrase = torch.optim.Adam( self.z_discriminator_phrase.parameters(), lr=self.lr_Zdiscriminator_phrase) # define optimize scheduler self.scheduler_gen1 = torch.optim.lr_scheduler.ReduceLROnPlateau( self.opt_gen1, mode='min', factor=0.8, cooldown=6) self.scheduler_gen2 = torch.optim.lr_scheduler.ReduceLROnPlateau( self.opt_gen2, mode='min', factor=0.8, cooldown=6) self.scheduler_discriminator = torch.optim.lr_scheduler.ReduceLROnPlateau( self.opt_discriminator, mode='min', factor=0.8, cooldown=6) self.scheduler_Zdiscriminator_bar = torch.optim.lr_scheduler.ReduceLROnPlateau( self.opt_Zdiscriminator_bar, mode='min', factor=0.8, cooldown=6) self.scheduler_Zdiscriminator_phrase = torch.optim.lr_scheduler.ReduceLROnPlateau( self.opt_Zdiscriminator_phrase, mode='min', factor=0.8, cooldown=6) # initialize counter self.iteration = 0 self.epoch = 0 self.manual_seed = random.randint(1, 10000) torch.manual_seed(self.manual_seed) torch.cuda.manual_seed_all(self.manual_seed) random.seed(self.manual_seed) print("seed: ", self.manual_seed) # cuda setting gpu_list = list(range(self.config.gpu_cnt)) self.generator = nn.DataParallel(self.generator, device_ids=gpu_list) self.discriminator = nn.DataParallel(self.discriminator, device_ids=gpu_list) self.z_discriminator_bar = nn.DataParallel(self.z_discriminator_bar, device_ids=gpu_list) self.z_discriminator_phrase = nn.DataParallel( self.z_discriminator_phrase, device_ids=gpu_list) self.generator = self.generator.cuda() self.discriminator = self.discriminator.cuda() self.z_discriminator_bar = self.z_discriminator_bar.cuda() self.z_discriminator_phrase = self.z_discriminator_phrase.cuda() # Model Loading from the latest checkpoint if not found start from scratch. self.load_checkpoint(self.config.checkpoint_file) # Summary Writer self.summary_writer = SummaryWriter(log_dir=os.path.join( self.config.root_path, self.config.summary_dir), comment='BarGen') print('Number of generator parameters: {}'.format( sum([p.data.nelement() for p in self.generator.parameters()]))) print('Number of discriminator parameters: {}'.format( sum([p.data.nelement() for p in self.discriminator.parameters()]))) print('Number of barZ discriminator parameters: {}'.format( sum([ p.data.nelement() for p in self.z_discriminator_bar.parameters() ]))) print('Number of phraseZ discriminator parameters: {}'.format( sum([ p.data.nelement() for p in self.z_discriminator_phrase.parameters() ]))) def set_logger(self): logger = logging.getLogger() logger.setLevel(logging.DEBUG) formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - %(message)s") file_handler = logging.FileHandler(filename="train_epoch.log") file_handler.setLevel(logging.DEBUG) file_handler.setFormatter(formatter) logger.addHandler(file_handler) return logger def make_batch(self, samples): note = np.concatenate([sample['note'] for sample in samples], axis=0) pre_note = np.concatenate([sample['pre_note'] for sample in samples], axis=0) pre_phrase = np.concatenate( [sample['pre_phrase'] for sample in samples], axis=0) position = np.concatenate([sample['position'] for sample in samples], axis=0) return tuple([ torch.tensor(note, dtype=torch.float), torch.tensor(pre_note, dtype=torch.float), torch.tensor(pre_phrase, dtype=torch.float), torch.tensor(position, dtype=torch.long) ]) def free(self, module: nn.Module): for p in module.parameters(): p.requires_grad = True def frozen(self, module: nn.Module): for p in module.parameters(): p.requires_grad = False def load_checkpoint(self, file_name): filename = os.path.join(self.config.root_path, self.config.checkpoint_dir, file_name) try: print("Loading checkpoint '{}'".format(filename)) checkpoint = torch.load(filename) self.generator.load_state_dict(checkpoint['generator_state_dict']) self.opt_gen1.load_state_dict(checkpoint['gen_optimizer1']) self.opt_gen2.load_state_dict(checkpoint['gen_optimizer2']) self.discriminator.load_state_dict( checkpoint['discriminator_state_dict']) self.opt_discriminator.load_state_dict( checkpoint['disc_optimizer']) self.z_discriminator_bar.load_state_dict( checkpoint['z_discriminator_bar_state_dict']) self.opt_Zdiscriminator_bar.load_state_dict( checkpoint['opt_Zdiscriminator_bar_optimizer']) self.z_discriminator_phrase.load_state_dict( checkpoint['z_discriminator_phrase_state_dict']) self.opt_Zdiscriminator_phrase.load_state_dict( checkpoint['opt_Zdiscriminator_phrase_optimizer']) except OSError as e: print("No checkpoint exists from '{}'. Skipping...".format( self.config.checkpoint_dir)) print("**First time to train**") def save_checkpoint(self, file_name, epoch): tmp_name = os.path.join(self.config.root_path, self.config.checkpoint_dir, 'checkpoint_{}.pth.tar'.format(epoch)) # file_name = os.path.join(self.config.root_path, self.config.checkpoint_dir, file_name) state = { 'epoch': self.epoch, 'generator_state_dict': self.generator.state_dict(), 'gen_optimizer1': self.opt_gen1.state_dict(), 'gen_optimizer2': self.opt_gen2.state_dict(), 'discriminator_state_dict': self.discriminator.state_dict(), 'disc_optimizer': self.opt_discriminator.state_dict(), 'z_discriminator_bar_state_dict': self.z_discriminator_bar.state_dict(), 'opt_Zdiscriminator_bar_optimizer': self.opt_Zdiscriminator_bar.state_dict(), 'z_discriminator_phrase_state_dict': self.z_discriminator_phrase.state_dict(), 'opt_Zdiscriminator_phrase_optimizer': self.opt_Zdiscriminator_phrase.state_dict(), 'lr_gen': self.config.learning_rate, 'lr_discriminator': self.config.learning_rate, 'lr_Zdiscriminator_bar': self.config.learning_rate, 'lr_Zdiscriminator_phrase': self.config.learning_rate, } torch.save(state, tmp_name) shutil.copyfile( tmp_name, os.path.join(self.config.root_path, self.config.checkpoint_dir, 'checkpoint.pth.tar')) def run(self): try: self.train() except KeyboardInterrupt: print("You have entered CTRL+C.. Wait to finalize") def train(self): for epoch in range(self.config.epoch): self.epoch += 1 self.train_epoch() if self.epoch > self.pretraining_step_size + 50: self.save_checkpoint(self.config.checkpoint_file, self.epoch) def train_epoch(self): tqdm_batch = tqdm(self.dataloader, total=self.dataset.num_iterations, desc="epoch-{}".format(self.epoch)) self.generator.train() self.discriminator.train() self.z_discriminator_bar.train() self.z_discriminator_phrase.train() image_sample = None Tensor = torch.cuda.FloatTensor avg_gen_loss = AverageMeter() avg_disc_loss = AverageMeter() avg_barZ_disc_loss = AverageMeter() avg_phraseZ_disc_loss = AverageMeter() div_flag = random.randrange(2, 5) for curr_it, (note, pre_note, pre_phrase, position) in enumerate(tqdm_batch): note = note.cuda(async=self.config.async_loading) pre_note = pre_note.cuda(async=self.config.async_loading) pre_phrase = pre_phrase.cuda(async=self.config.async_loading) position = position.cuda(async=self.config.async_loading) note = Variable(note) pre_note = Variable(pre_note) pre_phrase = Variable(pre_phrase) position = Variable(position) valid_target = Variable(Tensor(note.size(0)).fill_(1.0), requires_grad=False) fake_target = Variable(Tensor(note.size(0)).fill_(0.0), requires_grad=False) fake_target_double = Variable(Tensor(note.size(0) * 2).fill_(0.0), requires_grad=False) self.iteration += 1 #################### self.generator.zero_grad() self.discriminator.zero_grad() self.z_discriminator_bar.zero_grad() self.z_discriminator_phrase.zero_grad() if (curr_it + self.epoch ) % div_flag == 1 and self.epoch > self.pretraining_step_size: #################### Discriminator #################### self.free(self.discriminator) self.free(self.z_discriminator_bar) self.free(self.z_discriminator_phrase) self.frozen(self.generator) gen_note, z, pre_z, phrase_feature = self.generator( note, pre_note, pre_phrase, position) #### Phrase Feature ### phrase_fake = torch.randn(phrase_feature.size(0), phrase_feature.size(1)).cuda() d_phrase_fake = self.z_discriminator_phrase(phrase_fake).view( -1) d_phrase_real = self.z_discriminator_phrase( phrase_feature).view(-1) phraseZ_dics_loss = self.loss_disc( d_phrase_real, valid_target) + self.loss_disc( d_phrase_fake, fake_target) #### Bar Feature #### bar_fake = torch.randn(z.size(0) * 2, z.size(1)).cuda() d_bar_fake = self.z_discriminator_bar(bar_fake).view(-1) d_bar_real1 = self.z_discriminator_bar(z).view(-1) d_bar_real2 = self.z_discriminator_bar(pre_z).view(-1) barZ_dics_loss = self.loss_disc(d_bar_real1, valid_target) + self.loss_disc(d_bar_real2, valid_target) +\ self.loss_disc(d_bar_fake, fake_target_double) #### Generated Bar #### fake_note = torch.gt(gen_note, 0.3).type('torch.cuda.FloatTensor') fake_note = torch.cat((pre_note, fake_note), dim=2) d_fake = self.discriminator(fake_note).view(-1) real_note = torch.cat((pre_note, note), dim=2) d_real = self.discriminator(real_note).view(-1) disc_loss = self.loss_disc(d_fake, fake_target) + self.loss_disc( d_real, valid_target) ####################### disc_loss.backward() phraseZ_dics_loss.backward() barZ_dics_loss.backward() self.opt_discriminator.step() self.opt_Zdiscriminator_bar.step() self.opt_Zdiscriminator_phrase.step() avg_disc_loss.update(disc_loss) avg_barZ_disc_loss.update(barZ_dics_loss) avg_phraseZ_disc_loss.update(phraseZ_dics_loss) #################### Generator #################### self.free(self.generator) self.frozen(self.discriminator) self.frozen(self.z_discriminator_bar) self.frozen(self.z_discriminator_phrase) gen_note, z, pre_z, phrase_feature = self.generator( note, pre_note, pre_phrase, position) image_sample = gen_note origin_image = note #### add GAN Loss ### if self.epoch > self.pretraining_step_size: gen_loss = self.loss_gen(gen_note, note, False) gen_loss = self.loss_disc( self.z_discriminator_phrase(phrase_feature).view(-1), valid_target) gen_loss += self.loss_disc(self.z_discriminator_bar(z).view(-1), valid_target) + \ self.loss_disc(self.z_discriminator_bar(pre_z).view(-1), valid_target) fake_note = torch.gt(gen_note, 0.3).type('torch.cuda.FloatTensor') fake_note = torch.cat((pre_note, fake_note), dim=2) d_fake = self.discriminator(fake_note).view(-1) gen_loss += self.loss_disc(d_fake, valid_target) gen_loss.backward() self.opt_gen2.step() else: gen_loss = self.loss_gen(gen_note, note, True) gen_loss.backward() self.opt_gen1.step() avg_gen_loss.update(gen_loss.item()) if self.epoch > self.pretraining_step_size: self.summary_writer.add_scalar("train/Generator_loss", avg_gen_loss.val, self.iteration) self.summary_writer.add_scalar("train/Discriminator_loss", avg_disc_loss.val, self.iteration) self.summary_writer.add_scalar( "train/Bar_Z_Discriminator_loss", avg_barZ_disc_loss.val, self.iteration) self.summary_writer.add_scalar( "train/Phrase_Z_discriminator_loss", avg_phraseZ_disc_loss.val, self.iteration) else: self.summary_writer.add_scalar("pre_train/Generator_loss", avg_gen_loss.val, self.iteration) tqdm_batch.close() self.summary_writer.add_image("generated/sample 1_1", image_sample[0].reshape(1, 96, 60), self.epoch) self.summary_writer.add_image("generated/sample 1_2", image_sample[1].reshape(1, 96, 60), self.epoch) self.summary_writer.add_image("generated/sample 1_3", image_sample[2].reshape(1, 96, 60), self.epoch) image_sample = torch.gt(image_sample, 0.3).type('torch.cuda.FloatTensor') self.summary_writer.add_image("generated/sample 2_1", image_sample[0].reshape(1, 96, 60), self.epoch) self.summary_writer.add_image("generated/sample 2_2", image_sample[1].reshape(1, 96, 60), self.epoch) self.summary_writer.add_image("generated/sample 2_3", image_sample[2].reshape(1, 96, 60), self.epoch) self.summary_writer.add_image("generated/origin 2_1", origin_image[0].reshape(1, 96, 60), self.epoch) self.summary_writer.add_image("generated/origin 2_2", origin_image[1].reshape(1, 96, 60), self.epoch) self.summary_writer.add_image("generated/origin 2_3", origin_image[2].reshape(1, 96, 60), self.epoch) if self.epoch > self.pretraining_step_size: self.scheduler_gen1.step(avg_gen_loss.val) else: self.scheduler_gen2.step(avg_gen_loss.val) self.scheduler_discriminator.step(avg_disc_loss.val) self.scheduler_Zdiscriminator_bar.step(avg_barZ_disc_loss.val) self.scheduler_Zdiscriminator_phrase.step(avg_phraseZ_disc_loss.val) self.logger.debug( 'pre_train lr: {}, generator lr: {}, bar disc lr: {}, barZ disc lr: {}, phraseZ disc lr: {}' .format(self.lr_gen1, self.lr_gen2, self.lr_discriminator, self.lr_Zdiscriminator_bar, self.lr_Zdiscriminator_phrase))