Example #1
import pathlib
import time

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter  # or: from tensorboardX import SummaryWriter
from tqdm import tqdm

# Wavenet, DataLoader and natural_sort_key are project-local helpers that are
# not part of this snippet.


class Trainer():
    def __init__(self, args):
        self.args = args
        self.train_writer = SummaryWriter('Logs/train')
        self.test_writer = SummaryWriter('Logs/test')
        self.wavenet = Wavenet(args, self.train_writer)
        # Scale the effective batch size by the number of visible GPUs.
        self.train_data_loader = DataLoader(
            args.batch_size * torch.cuda.device_count(), args.shuffle,
            args.num_workers, True)
        self.test_data_loader = DataLoader(
            args.batch_size * torch.cuda.device_count(), args.shuffle,
            args.num_workers, False)
        # Total number of training steps across all epochs.
        self.wavenet.total = len(self.train_data_loader) * self.args.num_epochs
        self.load_last_checkpoint(self.args.resume)

    def load_last_checkpoint(self, resume=0):
        # Either resume from an explicitly requested step, or pick up the
        # newest checkpoint pair found in the Checkpoints directory.
        if resume > 0:
            self.wavenet.load('Checkpoints/' + str(resume) + '_large.pkl',
                              'Checkpoints/' + str(resume) + '_small.pkl')
        else:
            checkpoint_list = [
                str(i) for i in pathlib.Path('Checkpoints').glob('**/*.pkl')
            ]
            if len(checkpoint_list) > 0:
                checkpoint_list.sort(key=natural_sort_key)
                self.wavenet.load(checkpoint_list[-2], checkpoint_list[-1])

    def run(self):
        with tqdm(range(self.args.num_epochs), dynamic_ncols=True) as pbar1:
            for epoch in pbar1:
                # Training pass over the full training set.
                with tqdm(self.train_data_loader,
                          total=len(self.train_data_loader),
                          dynamic_ncols=True) as pbar2:
                    for i, (x, nonzero, diff, nonzero_diff,
                            condition) in enumerate(pbar2):
                        step = i + epoch * len(self.train_data_loader)
                        current_large_loss, current_small_loss = self.wavenet.train(
                            x.cuda(non_blocking=True),
                            nonzero.cuda(non_blocking=True),
                            diff.cuda(non_blocking=True),
                            nonzero_diff.cuda(non_blocking=True),
                            condition.cuda(non_blocking=True),
                            step=step,
                            train=True)
                        pbar2.set_postfix(ll=current_large_loss,
                                          sl=current_small_loss)
                # Evaluation pass over the test set, without gradient tracking.
                with torch.no_grad():
                    test_loss_large = test_loss_small = 0
                    with tqdm(self.test_data_loader,
                              total=len(self.test_data_loader),
                              dynamic_ncols=True) as pbar2:
                        for x, nonzero, diff, nonzero_diff, condition in pbar2:
                            current_large_loss, current_small_loss = self.wavenet.train(
                                x.cuda(non_blocking=True),
                                nonzero.cuda(non_blocking=True),
                                diff.cuda(non_blocking=True),
                                nonzero_diff.cuda(non_blocking=True),
                                condition.cuda(non_blocking=True),
                                train=False)
                            test_loss_large += current_large_loss
                            test_loss_small += current_small_loss
                            pbar2.set_postfix(ll=current_large_loss,
                                              sl=current_small_loss)
                    test_loss_large /= len(self.test_data_loader)
                    test_loss_small /= len(self.test_data_loader)
                    pbar1.set_postfix(ll=test_loss_large, sl=test_loss_small)
                    # Log test losses, a sampled image and a checkpoint once per epoch.
                    end_step = (epoch + 1) * len(self.train_data_loader)
                    sampled_image = self.sample(num=1, name=end_step)
                    self.test_writer.add_scalar('Test/Testing large loss',
                                                test_loss_large, end_step)
                    self.test_writer.add_scalar('Test/Testing small loss',
                                                test_loss_small, end_step)
                    self.test_writer.add_image('Score/Sampled', sampled_image,
                                               end_step)
                    self.wavenet.save(end_step)
        self.test_writer.close()
        self.train_writer.close()

    def sample(self, num, name=None):
        # Generate the default name at call time instead of baking a fixed
        # timestamp into the function signature.
        if name is None:
            name = 'Sample_{}'.format(int(time.time()))
        image = None
        for _ in tqdm(range(num), dynamic_ncols=True):
            # Seed the generation with a random item from the training dataset.
            init, nonzero, diff, nonzero_diff, condition = self.train_data_loader.dataset[
                np.random.randint(len(self.train_data_loader.dataset))]
            image = self.wavenet.sample(
                name,
                temperature=self.args.temperature,
                init=torch.Tensor(init).cuda(non_blocking=True),
                nonzero=torch.Tensor(nonzero).cuda(non_blocking=True),
                diff=torch.Tensor(diff).cuda(non_blocking=True),
                nonzero_diff=torch.Tensor(nonzero_diff).cuda(
                    non_blocking=True),
                condition=torch.Tensor(condition).cuda(non_blocking=True),
                length=self.args.length)
        return image
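
The snippet sorts checkpoint paths with a natural_sort_key helper that is not shown here. A minimal sketch of such a key, assuming checkpoint file names embed the training step (e.g. 2000_large.pkl should sort after 100_large.pkl), could look like this:

import re

def natural_sort_key(path):
    # Split the path into digit and non-digit runs so that numeric parts
    # compare as integers: '100_large.pkl' sorts before '2000_large.pkl'.
    return [int(part) if part.isdigit() else part
            for part in re.split(r'(\d+)', str(path))]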
Example #2
# This variant needs the same imports and project-local helpers as Example #1
# (pathlib, time, numpy, torch, tqdm, SummaryWriter, plus the DataLoader,
# Wavenet and natural_sort_key utilities).
class Trainer():
    def __init__(self, args):
        self.args = args
        self.train_writer = SummaryWriter('Logs/train')
        self.test_writer = SummaryWriter('Logs/test')
        # Here the model is built from individual hyperparameters rather than
        # from the whole args namespace.
        self.wavenet = Wavenet(args.layer_size, args.stack_size, args.channels,
                               args.residual_channels, args.dilation_channels,
                               args.skip_channels, args.end_channels,
                               args.out_channels, args.learning_rate,
                               self.train_writer)
        self.train_data_loader = DataLoader(
            args.batch_size * torch.cuda.device_count(),
            self.wavenet.receptive_field, args.shuffle, args.num_workers, True)
        self.test_data_loader = DataLoader(
            args.batch_size * torch.cuda.device_count(),
            self.wavenet.receptive_field, args.shuffle, args.num_workers,
            False)

    def load_last_checkpoint(self):
        # Load the newest checkpoint found in the Checkpoints directory, if any.
        checkpoint_list = [
            str(i) for i in pathlib.Path('Checkpoints').glob('**/*.pkl')
        ]
        if len(checkpoint_list) > 0:
            checkpoint_list.sort(key=natural_sort_key)
            self.wavenet.load(checkpoint_list[-1])

    def run(self):
        self.load_last_checkpoint()
        for epoch in tqdm(range(self.args.num_epochs)):
            # Training pass over the full training set.
            for i, (sample,
                    real) in tqdm(enumerate(self.train_data_loader),
                                  total=len(self.train_data_loader)):
                step = i + epoch * len(self.train_data_loader)
                self.wavenet.train(
                    sample.cuda(), real.cuda(), step, True,
                    self.args.num_epochs * len(self.train_data_loader))
            # Evaluation pass over the test set, without gradient tracking.
            with torch.no_grad():
                test_loss = 0
                for _, (sample,
                        real) in tqdm(enumerate(self.test_data_loader),
                                      total=len(self.test_data_loader)):
                    test_loss += self.wavenet.train(sample.cuda(),
                                                    real.cuda(),
                                                    train=False)
                test_loss /= len(self.test_data_loader)
                tqdm.write('Testing step Loss: {}'.format(test_loss))
                end_step = (epoch + 1) * len(self.train_data_loader)
                # Sample from a random training item and log everything once per epoch.
                sample_init, _ = self.train_data_loader.dataset[
                    np.random.randint(len(self.train_data_loader.dataset))]
                sampled_image = self.wavenet.sample(end_step, init=sample_init)
                self.test_writer.add_scalar('Testing loss', test_loss,
                                            end_step)
                self.test_writer.add_image('Sampled', sampled_image, end_step)
                self.wavenet.save(end_step)

    def sample(self, num):
        self.load_last_checkpoint()
        with torch.no_grad():
            for _ in tqdm(range(num)):
                # Seed each sample with a random item from the training dataset.
                sample_init, _ = self.train_data_loader.dataset[
                    np.random.randint(len(self.train_data_loader.dataset))]
                self.wavenet.sample('Sample_{}'.format(int(time.time())),
                                    self.args.temperature, sample_init)
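
Neither example shows how the args namespace is built. A minimal driver sketch for the second Trainer, using an argparse.Namespace with purely illustrative values (the real project's attribute defaults are not given in these snippets), might look like this:

import argparse

# Illustrative hyperparameter values only; the actual defaults are not shown above.
args = argparse.Namespace(
    layer_size=10, stack_size=5,
    channels=256, residual_channels=512, dilation_channels=512,
    skip_channels=512, end_channels=512, out_channels=256,
    learning_rate=1e-3,
    batch_size=4, shuffle=True, num_workers=2,
    num_epochs=100, temperature=1.0)

trainer = Trainer(args)
trainer.run()        # train, evaluate and checkpoint once per epoch
# trainer.sample(1)  # or: generate from the most recent checkpoint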