Example #1
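The four snippets below are trainer methods excerpted from a larger project. They assume the usual imports (torch, torch.nn as nn, torch.nn.functional as F, torch.optim as optim, numpy as np, tqdm, and clip_grad_norm_ from torch.nn.utils) plus project-specific helpers (Logger, KLAnnealer, CosineAnnealingLRWithRestart) that are not shown here.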
    def _train(self, model, train_loader, val_loader=None):
        criterions = {
            # Reconstruction loss on the decoder's token logits
            'autoencoder': nn.CrossEntropyLoss(),
            # Non-saturating generator loss: -E[log sigmoid(D(z))]
            'generator': lambda t: -torch.mean(F.logsigmoid(t)),
            'discriminator': nn.BCEWithLogitsLoss()
        }

        optimizers = {
            'autoencoder':
            torch.optim.Adam(list(model.encoder.parameters()) +
                             list(model.decoder.parameters()),
                             lr=self.config.lr),
            'generator':
            torch.optim.Adam(model.encoder.parameters(), lr=self.config.lr),
            'discriminator':
            torch.optim.Adam(model.discriminator.parameters(),
                             lr=self.config.lr)
        }
        schedulers = {
            k: torch.optim.lr_scheduler.StepLR(v, self.config.step_size,
                                               self.config.gamma)
            for k, v in optimizers.items()
        }
        device = torch.device(self.config.device)
        log = Logger()

        for epoch in range(self.config.train_epochs):
            tqdm_data = tqdm(train_loader,
                             desc='Training (epoch #{})'.format(epoch))
            log.append(
                self._train_epoch(model, tqdm_data, criterions, optimizers))
            # Since PyTorch 1.1, LR schedulers step after the epoch's
            # optimizer updates rather than before them
            for scheduler in schedulers.values():
                scheduler.step()
            log.write(self.config.log_file)
            if val_loader is not None:
                tqdm_data = tqdm(val_loader,
                                 desc='Validation (epoch #{})'.format(epoch))
                self._train_epoch(model, tqdm_data, criterions)

            if epoch % self.config.save_frequency == 0:
                # Checkpoint on CPU; model_save is assumed to end in '.pt',
                # so [:-3] strips the suffix before tagging the epoch
                model.to('cpu')
                torch.save(
                    model.state_dict(),
                    self.config.model_save[:-3] + '_{0:03d}.pt'.format(epoch))
                model.to(device)
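Note: all four examples log through a Logger helper that is not shown. A minimal stand-in consistent with how it is used here (dict-wise append, per-key column access, [-1] for the last row, and write/save to a CSV file) could look like the sketch below; the project's real class may differ.

    import csv

    class Logger(list):
        # Hypothetical stand-in: a list of row dicts with column access
        # and a CSV dump. Assumes every row shares the same keys.

        def __getitem__(self, key):
            if isinstance(key, str):
                # Column access, e.g. log['kl_loss'] -> [0.12, 0.11, ...]
                return [row[key] for row in self]
            return list.__getitem__(self, key)

        def save(self, path):
            with open(path, 'w', newline='') as f:
                writer = csv.DictWriter(f, fieldnames=list(self[-1].keys()))
                writer.writeheader()
                writer.writerows(self)

        write = save  # Example #1 calls write(); treated as an alias here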
Example #2
    def fit(self, model, data):
        def get_params():
            return (p for p in model.parameters() if p.requires_grad)

        model.train()
        log = Logger()
        device = torch.device(self.config.device)  # used by the restore below
        n_epoch = self.config.num_epochs

        optimizer = optim.Adam(get_params(), lr=self.config.lr)
        for epoch in range(n_epoch):
            if epoch < self.config.kl_start:
                kl_w = 0
            else:
                kl_w = self.config.kl_w

            word_acc, topo_acc, assm_acc, steo_acc, all_kl = 0, 0, 0, 0, 0
            with tqdm.tqdm(data) as train_dataloader:
                train_dataloader.set_description('Train (epoch #{})'.format(epoch))

                for it, batch in enumerate(train_dataloader):
                    model.zero_grad()
                    loss, kl_div, wacc, tacc, sacc, dacc = model(batch, kl_w)
                    loss.backward()
                    optimizer.step()

                    word_acc += wacc
                    topo_acc += tacc
                    assm_acc += sacc
                    steo_acc += dacc
                    all_kl += kl_div

                    postfix = {'kl': all_kl / (it + 1),
                               'word': word_acc / (it + 1) * 100,
                               'topo': topo_acc / (it + 1) * 100,
                               'assm': assm_acc / (it + 1) * 100,
                               'steo': steo_acc / (it + 1) * 100}

                    train_dataloader.set_postfix(postfix)
            log.append(postfix)
            log.save(self.config.log_file)
            if epoch % self.config.save_frequency == 0:
                model.to('cpu')
                torch.save(
                    model.state_dict(),
                    self.config.model_save[:-3] + '_{0:03d}.pt'.format(epoch))
                model.to(device)
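The KL warm-up in this example is a simple step schedule: the KL term is off for the first kl_start epochs and then held at kl_w. As a standalone helper (purely illustrative, not from the source):

    def kl_weight(epoch, kl_start, kl_w):
        # Step-function warm-up: 0 until kl_start, then a constant kl_w
        return 0.0 if epoch < kl_start else kl_w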
Example #3
    def fit(self, model, data):
        def get_params():
            return (p for p in model.parameters() if p.requires_grad)

        if isinstance(data, tuple):
            train_dataloader = data[0]
            val_dataloader = data[1]
        else:
            train_dataloader = data
            val_dataloader = None

        num_epochs = self.config.num_epochs
        device = torch.device(self.config.device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(get_params(), lr=self.config.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    self.config.step_size,
                                                    self.config.gamma)
        elog = Logger()
        for epoch in range(num_epochs):
            model.train()
            # Wrap the loader in a fresh tqdm each epoch; rebinding the
            # loader variable itself would wrap tqdm around tqdm from the
            # second epoch on
            tqdm_data = tqdm.tqdm(train_dataloader)
            tqdm_data.set_description('Train (epoch #{})'.format(epoch))

            loss = self._pass_data(model, tqdm_data, criterion,
                                   optimizer)
            elog.append({'loss': loss})
            if val_dataloader is not None:
                tqdm_data = tqdm.tqdm(val_dataloader)
                tqdm_data.set_description(
                    'Validation (epoch #{})'.format(epoch))

                self._pass_data(model, tqdm_data, criterion)

            if epoch % self.config.save_frequency == 0:
                model.to('cpu')
                torch.save(
                    model.state_dict(),
                    self.config.model_save[:-3] + '_{0:03d}.pt'.format(epoch))
                model.to(device)
            elog.save(self.config.log_file)
            scheduler.step()  # step once per epoch, after the optimizer updates
        torch.save(model.state_dict(), self.config.model_save)
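This fit() accepts either a single dataloader or a (train, validation) tuple. A usage sketch, with Trainer standing in for whatever class actually hosts the method:

    trainer = Trainer(config)                       # hypothetical host class
    trainer.fit(model, train_loader)                # training only
    trainer.fit(model, (train_loader, val_loader))  # training + validation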
Example #4
    def fit(self, model, data):
        def get_params():
            return (p for p in model.vae.parameters() if p.requires_grad)

        model.train()

        n_epoch = self._n_epoch()
        kl_annealer = KLAnnealer(n_epoch, self.config)

        optimizer = optim.Adam(get_params(), lr=self.config.lr_start)
        lr_annealer = CosineAnnealingLRWithRestart(optimizer, self.config)

        device = torch.device(self.config.device)
        n_last = self.config.n_last
        elog, ilog = Logger(), Logger()

        for epoch in range(n_epoch):
            # Epoch start
            kl_weight = kl_annealer(epoch)

            # Iters
            T = tqdm.tqdm(data)
            for i, x in enumerate(T):
                # Forward
                kl_loss, recon_loss = model(x)
                loss = kl_weight * kl_loss + recon_loss

                # Backward
                optimizer.zero_grad()
                loss.backward()
                clip_grad_norm_(get_params(), self.config.grad_clipping)
                optimizer.step()

                # Log
                lr = optimizer.param_groups[0]['lr']
                ilog.append({
                    'epoch': epoch,
                    'kl_loss': kl_loss.item(),
                    'recon_loss': recon_loss.item(),
                    'loss': loss.item(),
                    'kl_weight': kl_weight,
                    'lr': lr
                })

                # Update T
                kl_loss_value = np.mean(ilog['kl_loss'][-n_last:])
                recon_loss_value = np.mean(ilog['recon_loss'][-n_last:])
                loss_value = np.mean(ilog['loss'][-n_last:])
                postfix = [
                    f'loss={loss_value:.5f}', f'(kl={kl_loss_value:.5f}',
                    f'recon={recon_loss_value:.5f})',
                    f'klw={kl_weight:.5f} lr={lr:.5f}'
                ]
                T.set_postfix_str(' '.join(postfix))
                T.set_description(f'Train (epoch #{epoch})')
                T.refresh()

            # Log
            elog.append({
                **{k: v for k, v in ilog[-1].items() if 'loss' not in k},
                'kl_loss': kl_loss_value,
                'recon_loss': recon_loss_value,
                'loss': loss_value
            })

            # Save model at each epoch
            if epoch % self.config.save_frequency == 0:
                model.to('cpu')
                torch.save(
                    model.state_dict(),
                    self.config.model_save[:-3] + '_{0:03d}.pt'.format(epoch))
                model.to(device)

            elog.save(self.config.log_file)

            # Epoch end
            lr_annealer.step()

        torch.save(model.state_dict(), self.config.model_save)
        return elog, ilog
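KLAnnealer and CosineAnnealingLRWithRestart come from the surrounding project and are not shown. A minimal linear annealer that matches the call site kl_annealer(epoch) might look like this; the config field names kl_w_start and kl_w_end are assumptions:

    class KLAnnealer:
        # Linearly ramps the KL weight from w_start to w_end over n_epoch
        # epochs. Field names on config are assumed, not from the source.

        def __init__(self, n_epoch, config):
            self.w_start = config.kl_w_start
            self.w_end = config.kl_w_end
            self.n_epoch = n_epoch

        def __call__(self, epoch):
            frac = epoch / max(self.n_epoch - 1, 1)
            return self.w_start + frac * (self.w_end - self.w_start)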