Example #1: joint training of an encoder, WaveNet decoder, and latent-space domain discriminator
import os
from itertools import chain
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils import clip_grad_value_
from tqdm import tqdm

# Project-local helpers used below (provided elsewhere in the repository):
# DatasetSet, Encoder, WaveNet, ZDiscriminator, LossMeter,
# cross_entropy_loss, create_output_dir, wrap


class Trainer:
    def __init__(self, args):
        self.args = args
        self.args.n_datasets = len(self.args.data)
        self.expPath = Path('checkpoints') / args.expName

        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        self.logger = create_output_dir(args, self.expPath)
        self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]
        assert not args.distributed or len(self.data) == int(
            os.environ['WORLD_SIZE']
        ), "Number of datasets must match number of nodes"

        self.losses_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.loss_d_right = LossMeter('d')
        self.loss_total = LossMeter('total')

        self.evals_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.eval_d_right = LossMeter('eval d')
        self.eval_total = LossMeter('eval total')

        self.encoder = Encoder(args)
        self.decoder = WaveNet(args)
        self.discriminator = ZDiscriminator(args)

        if args.checkpoint:
            checkpoint_args_path = os.path.dirname(
                args.checkpoint) + '/args.pth'
            checkpoint_args = torch.load(checkpoint_args_path)

            self.start_epoch = checkpoint_args[-1] + 1
            states = torch.load(args.checkpoint)

            self.encoder.load_state_dict(states['encoder_state'])
            self.decoder.load_state_dict(states['decoder_state'])
            self.discriminator.load_state_dict(states['discriminator_state'])

            self.logger.info('Loaded checkpoint parameters')
        else:
            self.start_epoch = 0

        if args.distributed:
            self.encoder.cuda()
            self.encoder = torch.nn.parallel.DistributedDataParallel(
                self.encoder)
            self.discriminator.cuda()
            self.discriminator = torch.nn.parallel.DistributedDataParallel(
                self.discriminator)
            self.logger.info('Created DistributedDataParallel')
        else:
            self.encoder = torch.nn.DataParallel(self.encoder).cuda()
            self.discriminator = torch.nn.DataParallel(
                self.discriminator).cuda()
        self.decoder = torch.nn.DataParallel(self.decoder).cuda()

        self.model_optimizer = optim.Adam(chain(self.encoder.parameters(),
                                                self.decoder.parameters()),
                                          lr=args.lr)
        self.d_optimizer = optim.Adam(self.discriminator.parameters(),
                                      lr=args.lr)

        if args.checkpoint and args.load_optimizer:
            self.model_optimizer.load_state_dict(
                states['model_optimizer_state'])
            self.d_optimizer.load_state_dict(states['d_optimizer_state'])

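        # Seed last_epoch so the exponential decay resumes from the
        # checkpointed epoch instead of restarting at the base learning rate.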
        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
            self.model_optimizer, args.lr_decay)
        self.lr_manager.last_epoch = self.start_epoch
        self.lr_manager.step()

    def eval_batch(self, x, x_aug, dset_num):
        x, x_aug = x.float(), x_aug.float()

        z = self.encoder(x)
        y = self.decoder(x, z)
        z_logits = self.discriminator(z)

        z_classification = torch.max(z_logits, dim=1)[1]

        z_accuracy = (z_classification == dset_num).float().mean()

        self.eval_d_right.add(z_accuracy.data.item())

        # discriminator_right = F.cross_entropy(z_logits, dset_num).mean()
        discriminator_right = F.cross_entropy(
            z_logits,
            torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()
        recon_loss = cross_entropy_loss(y, x)

        self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        total_loss = discriminator_right.data.item() * self.args.d_lambda + \
                     recon_loss.mean().data.item()

        self.eval_total.add(total_loss)

        return total_loss

    def train_batch(self, x, x_aug, dset_num):
        x, x_aug = x.float(), x_aug.float()

        # Optimize D - discriminator right
        z = self.encoder(x)
        z_logits = self.discriminator(z)
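        # Every sample in this batch comes from dataset dset_num, so the
        # domain labels form a constant vector of length batch_size.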
        discriminator_right = F.cross_entropy(
            z_logits,
            torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()
        loss = discriminator_right * self.args.d_lambda
        self.d_optimizer.zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.discriminator.parameters(),
                             self.args.grad_clip)

        self.d_optimizer.step()

        # Track the discriminator loss so train_losses() has data to report.
        self.loss_d_right.add(discriminator_right.data.item())

        # optimize G - reconstructs well, discriminator wrong
        z = self.encoder(x_aug)
        y = self.decoder(x, z)
        z_logits = self.discriminator(z)
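        # Negating the cross-entropy rewards the encoder for confusing the
        # discriminator, pushing z toward a domain-invariant representation.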
        discriminator_wrong = -F.cross_entropy(
            z_logits,
            torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()

        if not (-100 < discriminator_right.data.item() < 100):
            self.logger.debug(f'z_logits: {z_logits.detach().cpu().numpy()}')
            self.logger.debug(f'dset_num: {dset_num}')

        recon_loss = cross_entropy_loss(y, x)
        self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        loss = (recon_loss.mean() + self.args.d_lambda * discriminator_wrong)

        self.model_optimizer.zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.encoder.parameters(), self.args.grad_clip)
            clip_grad_value_(self.decoder.parameters(), self.args.grad_clip)
        self.model_optimizer.step()

        self.loss_total.add(loss.data.item())

        return loss.data.item()

    def train_epoch(self, epoch):
        for meter in self.losses_recon:
            meter.reset()
        self.loss_d_right.reset()
        self.loss_total.reset()

        self.encoder.train()
        self.decoder.train()
        self.discriminator.train()

        n_batches = self.args.epoch_len

        with tqdm(total=n_batches,
                  desc='Train epoch %d' % epoch) as train_enum:
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 3:
                    break

                if self.args.distributed:
                    assert self.args.rank < self.args.n_datasets, \
                        "Number of workers must equal the number of datasets"
                    # dset_num = (batch_num + self.args.rank) % self.args.n_datasets
                    dset_num = self.args.rank
                else:
                    dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].train_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.train_batch(x, x_aug, dset_num)

                train_enum.set_description(
                    f'Train (loss: {batch_loss:.2f}) epoch {epoch}')
                train_enum.update()

    def evaluate_epoch(self, epoch):
        for meter in self.evals_recon:
            meter.reset()
        self.eval_d_right.reset()
        self.eval_total.reset()

        self.encoder.eval()
        self.decoder.eval()
        self.discriminator.eval()

        n_batches = int(np.ceil(self.args.epoch_len / 10))

        with tqdm(total=n_batches) as valid_enum, \
                torch.no_grad():
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 10:
                    break

                if self.args.distributed:
                    assert self.args.rank < self.args.n_datasets, \
                        "Number of workers must equal the number of datasets"
                    dset_num = self.args.rank
                else:
                    dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].valid_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.eval_batch(x, x_aug, dset_num)

                valid_enum.set_description(
                    f'Test (loss: {batch_loss:.2f}) epoch {epoch}')
                valid_enum.update()

    @staticmethod
    def format_losses(meters):
        losses = [meter.summarize_epoch() for meter in meters]
        return ', '.join('{:.4f}'.format(x) for x in losses)

    def train_losses(self):
        meters = [*self.losses_recon, self.loss_d_right]
        return self.format_losses(meters)

    def eval_losses(self):
        meters = [*self.evals_recon, self.eval_d_right]
        return self.format_losses(meters)

    def train(self):
        best_eval = float('inf')

        # Begin!
        for epoch in range(self.start_epoch,
                           self.start_epoch + self.args.epochs):
            self.logger.info(
                f'Starting epoch, Rank {self.args.rank}, Dataset: {self.args.data[self.args.rank]}'
            )
            self.train_epoch(epoch)
            self.evaluate_epoch(epoch)

            self.logger.info(
                'Epoch %s Rank %s - Train loss: (%s), Test loss (%s)',
                epoch, self.args.rank, self.train_losses(), self.eval_losses())
            self.lr_manager.step()
            val_loss = self.eval_total.summarize_epoch()

            if val_loss < best_eval:
                self.save_model(f'bestmodel_{self.args.rank}.pth')
                best_eval = val_loss

            if not self.args.per_epoch:
                self.save_model(f'lastmodel_{self.args.rank}.pth')
            else:
                self.save_model(f'lastmodel_{epoch}_rank_{self.args.rank}.pth')

            if self.args.is_master:
                torch.save([self.args, epoch], '%s/args.pth' % self.expPath)

            self.logger.debug('Ended epoch')

    def save_model(self, filename):
        save_path = self.expPath / filename

        torch.save(
            {
                'encoder_state': self.encoder.module.state_dict(),
                'decoder_state': self.decoder.module.state_dict(),
                'discriminator_state': self.discriminator.module.state_dict(),
                'model_optimizer_state': self.model_optimizer.state_dict(),
                'dataset': self.args.rank,
                'd_optimizer_state': self.d_optimizer.state_dict()
            }, save_path)

        self.logger.debug(f'Saved model to {save_path}')
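

# A minimal single-process driver for the Trainer above (a sketch: the flag
# names mirror the attributes the class reads, e.g. args.data, args.expName,
# args.epoch_len, and are assumptions rather than the project's actual CLI).
import argparse
from pathlib import Path


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=Path, nargs='+', required=True)
    parser.add_argument('--expName', required=True)
    parser.add_argument('--checkpoint', default=None)
    parser.add_argument('--load-optimizer', action='store_true')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--epoch-len', type=int, default=10000)
    parser.add_argument('--seq-len', type=int, default=16000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--lr-decay', type=float, default=0.98)
    parser.add_argument('--d-lambda', type=float, default=1e-2)
    parser.add_argument('--grad-clip', type=float, default=None)
    parser.add_argument('--distributed', action='store_true')
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--per-epoch', action='store_true')
    parser.add_argument('--short', action='store_true')
    args = parser.parse_args()
    args.is_master = True  # single-process run: this process is the master

    Trainer(args).train()


if __name__ == '__main__':
    main()
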
Example #2: decoder-only training on top of a pretrained encoder checkpoint
import os
from pathlib import Path

import numpy as np
import torch
import torch.optim as optim
from torch.nn.utils import clip_grad_value_
from tqdm import tqdm

# Project-local helpers used below (provided elsewhere in the repository):
# DatasetSet, Encoder, WaveNet, LossMeter, cross_entropy_loss,
# create_output_dir, wrap


class Trainer:
    def __init__(self, args):
        self.args = args
        self.args.n_datasets = len(self.args.data)
        self.expPath = Path('checkpoints') / args.expName

        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        self.logger = create_output_dir(args, self.expPath)
        self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]

        self.losses_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.loss_total = LossMeter('total')

        self.evals_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.eval_total = LossMeter('eval total')

        self.encoder = Encoder(args)
        self.decoder = WaveNet(args)
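        # Only the decoder is trained in this variant; the encoder is loaded
        # from a checkpoint below and used as a fixed feature extractor (its
        # output is detached in train_batch).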

        assert args.checkpoint, 'you MUST pass a checkpoint for the encoder'

        if args.continue_training:
            checkpoint_args_path = os.path.dirname(
                args.checkpoint) + '/args.pth'
            checkpoint_args = torch.load(checkpoint_args_path)

            self.start_epoch = checkpoint_args[-1] + 1
        else:
            self.start_epoch = 0

        states = torch.load(args.checkpoint)
        self.encoder.load_state_dict(states['encoder_state'])
        if args.continue_training:
            self.decoder.load_state_dict(states['decoder_state'])
        self.logger.info('Loaded checkpoint parameters')

        self.encoder = torch.nn.DataParallel(self.encoder).cuda()
        self.decoder = torch.nn.DataParallel(self.decoder).cuda()

        self.model_optimizer = optim.Adam(self.decoder.parameters(),
                                          lr=args.lr)

        if args.continue_training:
            self.model_optimizer.load_state_dict(
                states['model_optimizer_state'])

        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
            self.model_optimizer, args.lr_decay)
        self.lr_manager.last_epoch = self.start_epoch
        self.lr_manager.step()

    def eval_batch(self, x, x_aug, dset_num):
        x, x_aug = x.float(), x_aug.float()

        z = self.encoder(x)
        y = self.decoder(x, z)

        recon_loss = cross_entropy_loss(y, x)
        self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        total_loss = recon_loss.mean().data.item()
        self.eval_total.add(total_loss)

        return total_loss

    def train_batch(self, x, x_aug, dset_num):
        x, x_aug = x.float(), x_aug.float()

        # optimize G - reconstructs well
        z = self.encoder(x_aug)
        z = z.detach()  # stop gradients
        y = self.decoder(x, z)

        recon_loss = cross_entropy_loss(y, x)
        self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        loss = recon_loss.mean()
        self.model_optimizer.zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.decoder.parameters(), self.args.grad_clip)
        self.model_optimizer.step()

        self.loss_total.add(loss.data.item())

        return loss.data.item()

    def train_epoch(self, epoch):
        for meter in self.losses_recon:
            meter.reset()
        self.loss_total.reset()

        self.encoder.eval()
        self.decoder.train()

        n_batches = self.args.epoch_len

        with tqdm(total=n_batches,
                  desc='Train epoch %d' % epoch) as train_enum:
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 3:
                    break

                dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].train_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.train_batch(x, x_aug, dset_num)

                train_enum.set_description(
                    f'Train (loss: {batch_loss:.2f}) epoch {epoch}')
                train_enum.update()

    def evaluate_epoch(self, epoch):
        for meter in self.evals_recon:
            meter.reset()
        self.eval_total.reset()

        self.encoder.eval()
        self.decoder.eval()

        n_batches = int(np.ceil(self.args.epoch_len / 10))

        with tqdm(total=n_batches) as valid_enum, \
                torch.no_grad():
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 10:
                    break

                dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].valid_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.eval_batch(x, x_aug, dset_num)

                valid_enum.set_description(
                    f'Test (loss: {batch_loss:.2f}) epoch {epoch}')
                valid_enum.update()

    @staticmethod
    def format_losses(meters):
        losses = [meter.summarize_epoch() for meter in meters]
        return ', '.join('{:.4f}'.format(x) for x in losses)

    def train_losses(self):
        meters = [*self.losses_recon]
        return self.format_losses(meters)

    def eval_losses(self):
        meters = [*self.evals_recon]
        return self.format_losses(meters)

    def train(self):
        best_eval = float('inf')

        # Begin!
        for epoch in range(self.start_epoch,
                           self.start_epoch + self.args.epochs):
            self.logger.info(
                f'Starting epoch, Rank {self.args.rank}, Dataset: {self.args.data[self.args.rank]}'
            )
            self.train_epoch(epoch)
            self.evaluate_epoch(epoch)

            self.logger.info(
                'Epoch %s Rank %s - Train loss: (%s), Test loss (%s)',
                epoch, self.args.rank, self.train_losses(), self.eval_losses())
            self.lr_manager.step()
            val_loss = self.eval_total.summarize_epoch()

            if val_loss < best_eval:
                self.save_model(f'bestmodel_{self.args.rank}.pth')
                best_eval = val_loss

            if not self.args.per_epoch:
                self.save_model(f'lastmodel_{self.args.rank}.pth')
            else:
                self.save_model(f'lastmodel_{epoch}_rank_{self.args.rank}.pth')

            torch.save([self.args, epoch], '%s/args.pth' % self.expPath)

            self.logger.debug('Ended epoch')

    def save_model(self, filename):
        save_path = self.expPath / filename

        states = torch.load(self.args.checkpoint)

        torch.save(
            {
                'encoder_state': states['encoder_state'],
                'decoder_state': self.decoder.module.state_dict(),
                'model_optimizer_state': self.model_optimizer.state_dict(),
                'dataset': self.args.rank,
            }, save_path)

        self.logger.debug(f'Saved model to {save_path}')
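

# A standalone illustration of the z = z.detach() call in train_batch above
# (a minimal sketch, not project code): detaching the encoder output blocks
# gradients from the reconstruction loss, so only the decoder gets updated.
import torch

enc = torch.nn.Linear(4, 2)          # stand-in for the pretrained encoder
dec = torch.nn.Linear(2, 4)          # stand-in for the trainable decoder

x = torch.randn(8, 4)
z = enc(x).detach()                  # gradients stop here
loss = dec(z).pow(2).mean()
loss.backward()

print(enc.weight.grad)               # None: the encoder received no gradient
print(dec.weight.grad is not None)   # True: the decoder is still trained
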
Example #3: fine-tuning with a frozen encoder and only the last decoder layers trainable
from itertools import chain
from pathlib import Path

import numpy as np
import torch
import torch.optim as optim
from tqdm import tqdm

# Project-local helpers used below (provided elsewhere in the repository):
# DatasetSet, Encoder, WaveNet, LossMeter, cross_entropy_loss,
# create_output_dir, extract_id, wrap


class Finetuner:
    def __init__(self, args):
        self.args = args
        self.args.n_datasets = len(args.data)
        self.modelPath = Path('checkpoints') / args.expName

        self.logger = create_output_dir(args, self.modelPath)
        self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]

        self.losses_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.loss_total = LossMeter('total')

        self.evals_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.eval_total = LossMeter('eval total')

        self.start_epoch = 0

        #torch.manual_seed(args.seed)
        #torch.cuda.manual_seed(args.seed)

        # Locate the pretrained checkpoint that matches the requested decoder id
        checkpoint = args.checkpoint.parent.glob(args.checkpoint.name +
                                                 '_*.pth')
        checkpoint = [c for c in checkpoint
                      if extract_id(c) in args.decoder][0]

        model_args = torch.load(args.checkpoint.parent / 'args.pth')[0]
        states = torch.load(checkpoint)

        self.encoder = Encoder(model_args)
        self.encoder.load_state_dict(states['encoder_state'])

        # Freeze the encoder
        for param in self.encoder.parameters():
            param.requires_grad = False

        self.decoder = WaveNet(model_args)
        self.decoder.load_state_dict(states['decoder_state'])

        # Freeze every decoder layer except the last args.decoder_update ones
        for param in self.decoder.layers[:-args.decoder_update].parameters():
            param.requires_grad = False

        self.encoder = torch.nn.DataParallel(self.encoder).cuda()
        self.decoder = torch.nn.DataParallel(self.decoder).cuda()
        self.model_optimizer = optim.Adam(chain(self.encoder.parameters(),
                                                self.decoder.parameters()),
                                          lr=args.lr)

        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
            self.model_optimizer, args.lr_decay)
        self.lr_manager.step()

    def train_batch(self, x, x_aug, dset_num):
        """Train a batch on the reconstruction loss only (no discriminator)."""
        x = x.float()
        x_aug = x_aug.float()
        z = self.encoder(x_aug)
        y = self.decoder(x, z)

        recon_loss = cross_entropy_loss(y, x)
        self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())
        loss = recon_loss.mean()

        self.model_optimizer.zero_grad()
        loss.backward()
        self.model_optimizer.step()
        self.loss_total.add(loss.data.item())

        return loss.data.item()

    def train_epoch(self, epoch):
        for meter in self.losses_recon:
            meter.reset()
        self.loss_total.reset()

        self.decoder.train()

        n_batches = self.args.epoch_len

        with tqdm(total=n_batches,
                  desc='Train epoch %d' % epoch) as train_enum:
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 3:
                    break

                if self.args.distributed:
                    assert self.args.rank < self.args.n_datasets, \
                        "Number of workers must equal the number of datasets"
                    # dset_num = (batch_num + self.args.rank) % self.args.n_datasets
                    dset_num = self.args.rank
                else:
                    dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].train_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.train_batch(x, x_aug, dset_num)

                train_enum.set_description(
                    f'Train (loss: {batch_loss:.2f}) epoch {epoch}')
                train_enum.update()

    def eval_batch(self, x, x_aug, dset_num):
        x, x_aug = x.float(), x_aug.float()
        z = self.encoder(x)
        y = self.decoder(x, z)

        recon_loss = cross_entropy_loss(y, x)
        self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        total_loss = recon_loss.mean().data.item()
        self.eval_total.add(total_loss)

        return total_loss

    def evaluate_epoch(self, epoch):
        for meter in self.evals_recon:
            meter.reset()
        self.eval_total.reset()

        self.encoder.eval()
        self.decoder.eval()

        n_batches = int(np.ceil(self.args.epoch_len / 10))

        with tqdm(total=n_batches) as valid_enum, torch.no_grad():
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 10:
                    break

                if self.args.distributed:
                    assert self.args.rank < self.args.n_datasets, \
                        "Number of workers must equal the number of datasets"
                    dset_num = self.args.rank
                else:
                    dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].valid_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.eval_batch(x, x_aug, dset_num)

                valid_enum.set_description(
                    f'Test (loss: {batch_loss:.2f}) epoch {epoch}')
                valid_enum.update()

    @staticmethod
    def format_losses(meters):
        losses = [meter.summarize_epoch() for meter in meters]
        return ', '.join('{:.4f}'.format(x) for x in losses)

    def train_losses(self):
        meters = [*self.losses_recon]
        return self.format_losses(meters)

    def eval_losses(self):
        meters = [*self.evals_recon]
        return self.format_losses(meters)

    def finetune(self):
        best_eval = float('inf')

        for epoch in range(self.start_epoch,
                           self.start_epoch + self.args.epochs):
            self.logger.info(
                f'Starting epoch, Rank {self.args.rank}, Dataset: {self.args.data[self.args.rank]}'
            )
            self.train_epoch(epoch)
            self.evaluate_epoch(epoch)

            self.logger.info(
                'Epoch %s Rank %s - Train loss: (%s), Test loss (%s)',
                epoch, self.args.rank, self.train_losses(), self.eval_losses())
            self.lr_manager.step()
            val_loss = self.eval_total.summarize_epoch()

            if val_loss < best_eval:
                self.save_model(f'bestmodel_{self.args.rank}.pth')
                best_eval = val_loss

            if not self.args.per_epoch:
                self.save_model(f'lastmodel_{self.args.rank}.pth')
            else:
                self.save_model(f'lastmodel_{epoch}_rank_{self.args.rank}.pth')

            if self.args.is_master:
                torch.save([self.args, epoch], '%s/args.pth' % self.modelPath)
            self.logger.debug('Ended epoch')

    def save_model(self, filename):
        save_path = self.modelPath / filename

        torch.save(
            {
                'encoder_state': self.encoder.module.state_dict(),
                'decoder_state': self.decoder.module.state_dict(),
                'model_optimizer_state': self.model_optimizer.state_dict(),
                'dataset': self.args.rank,
            }, save_path)

        self.logger.debug(f'Saved model to {save_path}')
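

# The Finetuner above hands every parameter to Adam even though the encoder
# and most decoder layers are frozen via requires_grad = False. A leaner,
# behavior-equivalent setup (a sketch using the same attribute names as the
# class above) passes only the trainable parameters to the optimizer:
from itertools import chain

import torch.optim as optim


def build_finetune_optimizer(encoder, decoder, lr):
    # Frozen parameters never receive gradients, so skipping them does not
    # change the updates; it only avoids carrying dead optimizer state.
    trainable = [p for p in chain(encoder.parameters(), decoder.parameters())
                 if p.requires_grad]
    return optim.Adam(trainable, lr=lr)


# Usage inside Finetuner.__init__, replacing the chain(...) call above:
#     self.model_optimizer = build_finetune_optimizer(self.encoder,
#                                                     self.decoder, args.lr)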