import pathlib
import time

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter  # tensorboardX also provides this API
from tqdm import tqdm

# Project-local imports (module paths are assumptions; adjust to the repo
# layout): Wavenet wraps the "large"/"small" model pair and its optimizers,
# and DataLoader wraps the score dataset. natural_sort_key, used to order
# checkpoints, is sketched further down in this file.
from wavenet import Wavenet
from data import DataLoader


class Trainer():
    def __init__(self, args):
        self.args = args
        self.train_writer = SummaryWriter('Logs/train')
        self.test_writer = SummaryWriter('Logs/test')
        self.wavenet = Wavenet(args, self.train_writer)
        # Scale the batch size by the number of visible GPUs.
        self.train_data_loader = DataLoader(
            args.batch_size * torch.cuda.device_count(),
            args.shuffle, args.num_workers, True)
        self.test_data_loader = DataLoader(
            args.batch_size * torch.cuda.device_count(),
            args.shuffle, args.num_workers, False)
        # Total number of optimizer steps, used by the model internally.
        self.wavenet.total = len(self.train_data_loader) * self.args.num_epochs
        self.load_last_checkpoint(self.args.resume)

    def load_last_checkpoint(self, resume=0):
        if resume > 0:
            # Resume from an explicit step; each checkpoint is a large/small pair.
            self.wavenet.load('Checkpoints/' + str(resume) + '_large.pkl',
                              'Checkpoints/' + str(resume) + '_small.pkl')
        else:
            # Otherwise load the newest pair; for the same step prefix,
            # '<step>_large.pkl' sorts before '<step>_small.pkl'.
            checkpoint_list = [str(i) for i in pathlib.Path('Checkpoints').glob('**/*.pkl')]
            if len(checkpoint_list) > 0:
                checkpoint_list.sort(key=natural_sort_key)
                self.wavenet.load(checkpoint_list[-2], checkpoint_list[-1])

    def run(self):
        with tqdm(range(self.args.num_epochs), dynamic_ncols=True) as pbar1:
            for epoch in pbar1:
                with tqdm(self.train_data_loader, total=len(self.train_data_loader),
                          dynamic_ncols=True) as pbar2:
                    for i, (x, nonzero, diff, nonzero_diff, condition) in enumerate(pbar2):
                        step = i + epoch * len(self.train_data_loader)
                        current_large_loss, current_small_loss = self.wavenet.train(
                            x.cuda(non_blocking=True),
                            nonzero.cuda(non_blocking=True),
                            diff.cuda(non_blocking=True),
                            nonzero_diff.cuda(non_blocking=True),
                            condition.cuda(non_blocking=True),
                            step=step, train=True)
                        pbar2.set_postfix(ll=current_large_loss, sl=current_small_loss)

                with torch.no_grad():
                    test_loss_large = test_loss_small = 0
                    with tqdm(self.test_data_loader, total=len(self.test_data_loader),
                              dynamic_ncols=True) as pbar2:
                        for x, nonzero, diff, nonzero_diff, condition in pbar2:
                            current_large_loss, current_small_loss = self.wavenet.train(
                                x.cuda(non_blocking=True),
                                nonzero.cuda(non_blocking=True),
                                diff.cuda(non_blocking=True),
                                nonzero_diff.cuda(non_blocking=True),
                                condition.cuda(non_blocking=True),
                                train=False)
                            test_loss_large += current_large_loss
                            test_loss_small += current_small_loss
                            pbar2.set_postfix(ll=current_large_loss, sl=current_small_loss)
                    test_loss_large /= len(self.test_data_loader)
                    test_loss_small /= len(self.test_data_loader)
                    # tqdm.write('Testing step Large Loss: {}'.format(test_loss_large))
                    # tqdm.write('Testing step Small Loss: {}'.format(test_loss_small))
                    pbar1.set_postfix(ll=test_loss_large, sl=test_loss_small)

                    end_step = (epoch + 1) * len(self.train_data_loader)
                    sampled_image = self.sample(num=1, name=end_step)
                    self.test_writer.add_scalar('Test/Testing large loss',
                                                test_loss_large, end_step)
                    self.test_writer.add_scalar('Test/Testing small loss',
                                                test_loss_small, end_step)
                    self.test_writer.add_image('Score/Sampled', sampled_image, end_step)
                    self.wavenet.save(end_step)
        self.test_writer.close()
        self.train_writer.close()

    def sample(self, num, name=None):
        # Default arguments are evaluated once at definition time, which would
        # freeze the timestamp; take it per call instead.
        if name is None:
            name = 'Sample_{}'.format(int(time.time()))
        for _ in tqdm(range(num), dynamic_ncols=True):
            # Index by dataset length, not by the loader's batch count.
            init, nonzero, diff, nonzero_diff, condition = self.train_data_loader.dataset[
                np.random.randint(len(self.train_data_loader.dataset))]
            image = self.wavenet.sample(
                name, temperature=self.args.temperature,
                init=torch.Tensor(init).cuda(non_blocking=True),
                nonzero=torch.Tensor(nonzero).cuda(non_blocking=True),
                diff=torch.Tensor(diff).cuda(non_blocking=True),
                nonzero_diff=torch.Tensor(nonzero_diff).cuda(non_blocking=True),
                condition=torch.Tensor(condition).cuda(non_blocking=True),
                length=self.args.length)
        return image
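
# Both Trainer variants in this file order checkpoints with natural_sort_key,
# which is referenced but not defined here. The following is a common
# implementation (an assumption, not necessarily the project's own helper):
# it splits digit runs out of the path so that 'Checkpoints/900_small.pkl'
# sorts before 'Checkpoints/1800_small.pkl'.
import re

def natural_sort_key(s):
    # Alternate non-digit / digit chunks; digit chunks compare numerically.
    return [int(chunk) if chunk.isdigit() else chunk.lower()
            for chunk in re.split(r'(\d+)', str(s))]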
# A second, single-model variant of the Trainer: one checkpoint file, a scalar
# test loss, and an explicit receptive_field passed to the data loaders.
class Trainer():
    def __init__(self, args):
        self.args = args
        self.train_writer = SummaryWriter('Logs/train')
        self.test_writer = SummaryWriter('Logs/test')
        self.wavenet = Wavenet(args.layer_size, args.stack_size, args.channels,
                               args.residual_channels, args.dilation_channels,
                               args.skip_channels, args.end_channels,
                               args.out_channels, args.learning_rate,
                               self.train_writer)
        # The loaders need the receptive field to frame each sample correctly.
        self.train_data_loader = DataLoader(
            args.batch_size * torch.cuda.device_count(),
            self.wavenet.receptive_field, args.shuffle, args.num_workers, True)
        self.test_data_loader = DataLoader(
            args.batch_size * torch.cuda.device_count(),
            self.wavenet.receptive_field, args.shuffle, args.num_workers, False)

    def load_last_checkpoint(self):
        checkpoint_list = [str(i) for i in pathlib.Path('Checkpoints').glob('**/*.pkl')]
        if len(checkpoint_list) > 0:
            checkpoint_list.sort(key=natural_sort_key)
            self.wavenet.load(checkpoint_list[-1])

    def run(self):
        self.load_last_checkpoint()
        for epoch in tqdm(range(self.args.num_epochs)):
            for i, (sample, real) in tqdm(enumerate(self.train_data_loader),
                                          total=len(self.train_data_loader)):
                step = i + epoch * len(self.train_data_loader)
                self.wavenet.train(
                    sample.cuda(), real.cuda(), step, True,
                    self.args.num_epochs * len(self.train_data_loader))
            with torch.no_grad():
                test_loss = 0
                for _, (sample, real) in tqdm(enumerate(self.test_data_loader),
                                              total=len(self.test_data_loader)):
                    test_loss += self.wavenet.train(sample.cuda(), real.cuda(),
                                                    train=False)
                test_loss /= len(self.test_data_loader)
                tqdm.write('Testing step Loss: {}'.format(test_loss))
                end_step = (epoch + 1) * len(self.train_data_loader)
                # Index by dataset length, not by the loader's batch count.
                sample_init, _ = self.train_data_loader.dataset[
                    np.random.randint(len(self.train_data_loader.dataset))]
                sampled_image = self.wavenet.sample(end_step, init=sample_init)
                self.test_writer.add_scalar('Testing loss', test_loss, end_step)
                self.test_writer.add_image('Sampled', sampled_image, end_step)
                self.wavenet.save(end_step)

    def sample(self, num):
        self.load_last_checkpoint()
        with torch.no_grad():
            for _ in tqdm(range(num)):
                sample_init, _ = self.train_data_loader.dataset[
                    np.random.randint(len(self.train_data_loader.dataset))]
                self.wavenet.sample('Sample_{}'.format(int(time.time())),
                                    self.args.temperature, sample_init)
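
# A minimal sketch of a training entry point. It drives the Trainer defined
# immediately above (the second definition shadows the first when both live in
# one module). The flag names mirror the attributes that Trainer reads from
# `args`; the defaults are illustrative assumptions, not values from the
# original project.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Train the Wavenet score model.')
    parser.add_argument('--layer_size', type=int, default=10)
    parser.add_argument('--stack_size', type=int, default=5)
    parser.add_argument('--channels', type=int, default=256)
    parser.add_argument('--residual_channels', type=int, default=512)
    parser.add_argument('--dilation_channels', type=int, default=512)
    parser.add_argument('--skip_channels', type=int, default=512)
    parser.add_argument('--end_channels', type=int, default=256)
    parser.add_argument('--out_channels', type=int, default=256)
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--batch_size', type=int, default=2)       # per-GPU batch size
    parser.add_argument('--num_workers', type=int, default=4)      # loader worker processes
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--temperature', type=float, default=1.0)  # sampling temperature
    parser.add_argument('--shuffle', action='store_true')
    args = parser.parse_args()

    Trainer(args).run()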