Example #1
    def __init__(self, config, train: DataLoader, val: DataLoader):
        self.config = config
        self.train_dataset = train
        self.val_dataset = val
        self.adv_lambda = config['model']['adv_lambda']
        self.metric_counter = MetricCounter(config['experiment_desc'])
        self.warmup_epochs = config['warmup_num']
Example #2
    def __init__(self, config, train: DataLoader, val: DataLoader):
        self.config = config
        self.train_dataset = train
        self.val_dataset = val
        self.adv_lambda = config['model']['adv_lambda']
        self.metric_counter = MetricCounter(config['experiment_desc'])
        self.warmup_epochs = config['warmup_num']
        gpu_id = self.config['gpu_id']
        # gpu_id >= 0 so that CUDA device 0 is usable; negative ids fall back to CPU
        self.device = torch.device('cuda:{}'.format(gpu_id) if (torch.cuda.is_available() and gpu_id >= 0) else "cpu")
Example #3
    def __init__(self, config):
        self.config = config
        self.train_dataset = self._get_dataset(config, 'train')
        self.val_dataset = self._get_dataset(config, 'test')
        self.experiment_name = (
            f"{config['experiment_desc']}_{config['model']['g_name']}"
            f"_content{config['model']['content_coef']}"
            f"_feature{config['model']['feature_coef']}"
            f"_adv{config['model']['adv_coef']}"
        )
        self.metric_counter = MetricCounter(self.experiment_name)
        self.warmup_epochs = config['warmup_num']
Example #4
class Validator:
    def __init__(self, config, val: DataLoader):
        self.config = config
        self.val_dataset = val
        self.metric_counter = MetricCounter()

    def validate(self):
        self._init_params()
        self.netG.load_state_dict(torch.load('best_G_fpn.h5')['model'])
        # left in train mode on purpose: GAN generators are often validated with
        # batch statistics in their norm layers rather than running averages
        self.netG.train(True)

        with torch.no_grad():
            self._validate()
        torch.cuda.empty_cache()

        print(self.metric_counter.loss_message())

    def _validate(self):
        self.metric_counter.clear()
        epoch_size = self.config.get('val_batches_per_epoch') or len(
            self.val_dataset)
        tq = tqdm.tqdm(self.val_dataset, total=epoch_size)
        tq.set_description('Validation')
        i = 0
        total_psnr = 0
        total_ssim = 0
        total_samples = 0
        for data in tq:
            inputs, targets = self.model.get_input(data)
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = self.netG(inputs)
            curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(
                inputs, outputs, targets)
            total_ssim += curr_ssim * len(inputs)
            total_psnr += curr_psnr * len(inputs)
            total_samples += len(inputs)

            print("Metrcis:", curr_psnr, curr_ssim)
            print("Totals:", total_ssim, total_psnr, total_samples)
            print("nan:", np.isnan(img_for_vis).any())
            print("inf:", np.isinf(img_for_vis).any())
            self.metric_counter.add_metrics(curr_psnr, curr_ssim)
            if not i:
                self.metric_counter.add_image(img_for_vis, tag='val')
            i += 1
            if i > epoch_size:
                break
            self.metric_counter.write_to_tensorboard(i, validation=True)
            del inputs, targets, outputs

        print("PSNR", total_psnr / total_samples)
        print("SSIM", total_ssim / total_samples)

        tq.close()

    def _init_params(self):
        self.netG, netD = get_nets(self.config['model'])
        self.netG.cuda()
        self.model = get_model(self.config['model'])
Example #5
    def __init__(self, config):
        self.config = config
        self.dataset = self._get_dataset(config["dataroot"],
                                         config["seq_length"])
        dlen = len(self.dataset)
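        # 80/10/10 split; the last slot takes the remainder so no samples are lost to rounding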
        splitlen = [int(0.8 * dlen), int(0.1 * dlen), 0]
        splitlen[2] = dlen - sum(splitlen)

        self.train_dataset, self.val_dataset, self.test_dataset = data.random_split(
            self.dataset, splitlen)

        self.experiment_name = f"{config['experiment_desc']}_{config['model']['model_n']}"
        self.metric_counter = MetricCounter(self.experiment_name,
                                            self.config["print_every"])
Example #6
def train():
    prepare_directories()
    metric_counter = MetricCounter(exp_name=LOG_DIR)

    parameters = dict(audio_configs)
    parameters["text_cleaner"] = configs["text_cleaner"]
    parameters["outputs_per_step"] = configs["r"]
    train_dataset = TextSpeechDataset(root_dir=data_configs["data_path"],
                                      annotations_file=data_configs["annotations_train"], parameters=parameters)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True,
                              collate_fn=train_dataset.collate_fn,
                              num_workers=num_workers, drop_last=False, pin_memory=True)

    val_dataset = TextSpeechDataset(root_dir=data_configs["data_path"],
                                    annotations_file=data_configs["annotations_val"], parameters=parameters)
    val_loader = DataLoader(dataset=val_dataset, batch_size=eval_batch_size, num_workers=num_workers,
                            collate_fn=val_dataset.collate_fn, drop_last=False, pin_memory=True)

    model = Tacotron(embedding_dim=configs.pop("embedding_size"),
                     linear_dim=audio_configs["frequency"],
                     mel_dim=audio_configs["mels_size"],
                     r=configs.pop("r"))

    if use_cuda:
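        # DataParallel splits each batch across all visible GPUs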
        model = torch.nn.DataParallel(model.to("cuda"))

    optimizer = optim.Adam(params=model.parameters(), lr=train_configs["lr"])
    criterion = L1LossMasked()

    if args.resume:
        model.load_state_dict(torch.load(args.resume))

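    # number of linear-spectrogram bins below ~3 kHz (Nyquist = sample_rate / 2);
    # typically used to up-weight the perceptually important low frequencies in the linear loss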
    n_priority_freq = int(3000 / (audio_configs["sample_rate"] * 0.5) * audio_configs["frequency"])
    for epoch in range(train_configs["epochs"]):
        audio_signal = run_epoch(model, train_loader, optimizer, criterion, metric_counter, epoch, n_priority_freq)
        run_validate(model, val_loader, criterion, metric_counter, n_priority_freq)
        if metric_counter.update_best_model():
            torch.save(model.state_dict(),
                       os.path.join(WEIGHTS_SAVE_PATH, f"best_{configs['experiment_name']}.pth.tar"))
            audio_signal = train_dataset.ap.spectrogram_to_wav(audio_signal.T)
            metric_counter.write_audio_to_tensorboard("Audio", audio_signal, epoch, audio_configs["sample_rate"])

        torch.save(model.state_dict(),
                   os.path.join(WEIGHTS_SAVE_PATH, f"last_{configs['experiment_name']}.pth.tar"))
        print(metric_counter.loss_message())
        logging.debug(
            f"Experiment Name: {configs['experiment_name']}, Epoch: {epoch}, Loss: {metric_counter.loss_message()}")
Example #7
class Trainer(object):
    def __init__(self, config):
        self.config = config
        self.train_dataset = self._get_dataset(config, 'train')
        self.val_dataset = self._get_dataset(config, 'test')
        self.experiment_name = (
            f"{config['experiment_desc']}_{config['model']['g_name']}"
            f"_content{config['model']['content_coef']}"
            f"_feature{config['model']['feature_coef']}"
            f"_adv{config['model']['adv_coef']}"
        )
        self.metric_counter = MetricCounter(self.experiment_name)
        self.warmup_epochs = config['warmup_num']

    def train(self):
        self._init_params()
        for epoch in range(0, self.config['num_epochs']):
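            # once warmup ends, unfreeze the full generator and rebuild its optimizer/scheduler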
            if epoch == self.warmup_epochs and self.warmup_epochs != 0:
                self.netG.module.unfreeze()
                self.optimizer_G = self._get_optim(self.netG.parameters())
                self.scheduler_G = self._get_scheduler(self.optimizer_G)
            self._run_epoch(epoch)
            self._validate(epoch)
            self.scheduler_G.step()
            self.scheduler_D.step()

            if self.metric_counter.update_best_model():
                torch.save({'model': self.netG.state_dict()},
                           'best_{}.h5'.format(self.experiment_name))
            torch.save({'model': self.netG.state_dict()},
                       'last_{}.h5'.format(self.experiment_name))
            print(self.metric_counter.loss_message())
            logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s" %
                          (self.experiment_name, epoch,
                           self.metric_counter.loss_message()))

    def _run_epoch(self, epoch):
        self.metric_counter.clear()
        for param_group in self.optimizer_G.param_groups:
            lr = param_group['lr']
        tq = tqdm.tqdm(self.train_dataset.dataloader)
        tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
        i = 0
        for data in tq:
            inputs, targets = self.model.get_input(data)
            outputs = self.netG(inputs)
            loss_D = self._update_d(outputs, targets)
            self.optimizer_G.zero_grad()
            loss_content, loss_vgg = self.criterionG[0](
                outputs, targets), self.criterionG[1](outputs, targets)
            loss_adv = self.adv_trainer.lossG(outputs, targets)
            loss_G = self.config['model']['content_coef'] * loss_content + \
               self.config['model']['adv_coef'] * loss_adv + \
               self.config['model']['feature_coef'] * loss_vgg
            loss_G.backward()
            self.optimizer_G.step()
            self.metric_counter.add_losses(loss_G.item(), loss_content.item(),
                                           loss_D)
            curr_psnr, curr_ssim = self.model.get_acc(outputs, targets)
            self.metric_counter.add_metrics(curr_psnr, curr_ssim)
            tq.set_postfix(loss=self.metric_counter.loss_message())
            i += 1
            if i == 3:
                self.metric_counter.images_to_tensorboard(
                    [inputs[0], targets[0], outputs[0]], epoch)
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch)

    def _validate(self, epoch):
        self.metric_counter.clear()
        tq = tqdm.tqdm(self.val_dataset.dataloader)
        tq.set_description('Validation')
        i = 0
        for data in tq:
            i += 1
            inputs, targets = self.model.get_input(data)
            outputs = self.netG(inputs)
            loss_content, loss_vgg = self.criterionG[0](
                outputs, targets), self.criterionG[1](outputs, targets)
            loss_adv = self.adv_trainer.lossG(outputs, targets)
            loss_G = self.config['model']['content_coef'] * loss_content + \
               self.config['model']['adv_coef'] * loss_adv + \
               self.config['model']['feature_coef'] * loss_vgg
            self.metric_counter.add_losses(loss_G.item(), loss_content.item(),
                                           loss_vgg.item())
            curr_psnr, curr_ssim = self.model.get_acc(outputs, targets)
            self.metric_counter.add_metrics(curr_psnr, curr_ssim)


            # if i == 3:
            #     self.metric_counter.images_to_tensorboard(
            #         [inputs[0], targets[0], outputs[0]], epoch)
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch, validation=True)

    def _get_dataset(self, config, filename):
        return CustomDataLoader(config, filename)

    def _update_d(self, outputs, targets):
        if self.config['model']['d_name'] == 'no_gan':
            return 0
        self.optimizer_D.zero_grad()
        loss_D = self.config['model']['adv_coef'] * self.adv_trainer.lossD(
            outputs, targets)
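        # retain_graph=True: the generator update that follows reuses this forward graph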
        loss_D.backward(retain_graph=True)
        self.optimizer_D.step()
        return loss_D.item()

    def _get_optim(self, params):
        if self.config['optimizer']['name'] == 'adam':
            optimizer = optim.Adam(params, lr=self.config['optimizer']['lr'])
        elif self.config['optimizer']['name'] == 'sgd':
            optimizer = optim.SGD(params, lr=self.config['optimizer']['lr'])
        elif self.config['optimizer']['name'] == 'adadelta':
            optimizer = optim.Adadelta(params,
                                       lr=self.config['optimizer']['lr'])
        else:
            raise ValueError("Optimizer [%s] not recognized." %
                             self.config['optimizer']['name'])
        return optimizer

    def _get_scheduler(self, optimizer):
        if self.config['scheduler']['name'] == 'plateau':
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='min',
                patience=self.config['scheduler']['patience'],
                factor=self.config['scheduler']['factor'],
                min_lr=self.config['scheduler']['min_lr'])
        elif self.config['scheduler']['name'] == 'sgdr':
            scheduler = WarmRestart(optimizer)
        elif self.config['scheduler']['name'] == 'linear':
            scheduler = LinearDecay(
                optimizer,
                min_lr=self.config['scheduler']['min_lr'],
                num_epochs=self.config['num_epochs'],
                start_epoch=self.config['scheduler']['start_epoch'])
        else:
            raise ValueError("Scheduler [%s] not recognized." %
                             self.config['scheduler']['name'])
        return scheduler

    def _get_adversarial_trainer(self, D_name, netD, criterionD):
        if D_name == 'no_gan':
            return AdversarialTrainerFactory.createModel(
                'NoAdversarialTrainer')
        elif D_name == 'patch_gan' or D_name == 'multi_scale':
            return AdversarialTrainerFactory.createModel(
                'SingleAdversarialTrainer', netD, criterionD)
        elif D_name == 'double_gan':
            return AdversarialTrainerFactory.createModel(
                'DoubleAdversarialTrainer', netD, criterionD)
        else:
            raise ValueError("Discriminator Network [%s] not recognized." %
                             D_name)

    def _init_params(self):
        self.criterionG, criterionD = get_loss(self.config['model'])
        self.netG, netD = get_nets(self.config['model'])
        self.netG.cuda()
        self.adv_trainer = self._get_adversarial_trainer(
            self.config['model']['d_name'], netD, criterionD)
        self.model = get_model(self.config['model'])
        self.optimizer_G = self._get_optim(
            filter(lambda p: p.requires_grad, self.netG.parameters()))
        self.optimizer_D = self._get_optim(self.adv_trainer.get_params())
        self.scheduler_G = self._get_scheduler(self.optimizer_G)
        self.scheduler_D = self._get_scheduler(self.optimizer_D)
Example #8
class Trainer:
    def __init__(self, config, train: DataLoader, val: DataLoader):
        self.config = config
        self.train_dataset = train
        self.val_dataset = val
        self.adv_lambda = config['model']['adv_lambda']
        self.metric_counter = MetricCounter(config['experiment_desc'])
        self.warmup_epochs = config['warmup_num']
        gpu_id = self.config['gpu_id']
        # gpu_id >= 0 so that CUDA device 0 is usable; negative ids fall back to CPU
        self.device = torch.device('cuda:{}'.format(gpu_id) if (torch.cuda.is_available() and gpu_id >= 0) else "cpu")

    def train(self):
        self._init_params()
        for epoch in range(0, self.config['num_epochs']):
            if epoch == self.warmup_epochs and self.warmup_epochs != 0:
                self.netG.module.unfreeze()
                self.optimizer_G = self._get_optim(self.netG.parameters())
                self.scheduler_G = self._get_scheduler(self.optimizer_G)
            self._run_epoch(epoch)
            self._validate(epoch)
            self.scheduler_G.step()
            self.scheduler_D.step()

            if self.metric_counter.update_best_model():
                torch.save({
                    'model': self.netG.state_dict()
                }, 'best_{}.h5'.format(self.config['experiment_desc']))
            torch.save({
                'model': self.netG.state_dict()
            }, 'last_{}.h5'.format(self.config['experiment_desc']))
            print(self.metric_counter.loss_message())
            logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s" % (
                self.config['experiment_desc'], epoch, self.metric_counter.loss_message()))

    def _run_epoch(self, epoch):
        self.metric_counter.clear()
        for param_group in self.optimizer_G.param_groups:
            lr = param_group['lr']

        epoch_size = self.config.get('train_batches_per_epoch') or len(self.train_dataset)
        tq = tqdm.tqdm(self.train_dataset, total=epoch_size)
        tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
        i = 0
        for data in tq:
            inputs, targets = self.model.get_input(data)
            outputs = self.netG(inputs)
            loss_D = self._update_d(outputs, targets)
            self.optimizer_G.zero_grad()
            loss_content = self.criterionG(outputs, targets)
            loss_adv = self.adv_trainer.loss_g(outputs, targets)
            loss_G = loss_content + self.adv_lambda * loss_adv
            loss_G.backward()
            self.optimizer_G.step()
            self.metric_counter.add_losses(loss_G.item(), loss_content.item(), loss_D)
            curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
            self.metric_counter.add_metrics(curr_psnr, curr_ssim)
            tq.set_postfix(loss=self.metric_counter.loss_message())
            if not i:
                self.metric_counter.add_image(img_for_vis, tag='train')
            i += 1
            if i > epoch_size:
                break
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch)

    def _validate(self, epoch):
        self.metric_counter.clear()
        epoch_size = self.config.get('val_batches_per_epoch') or len(self.val_dataset)
        tq = tqdm.tqdm(self.val_dataset, total=epoch_size)
        tq.set_description('Validation')
        i = 0
        for data in tq:
            inputs, targets = self.model.get_input(data)
            outputs = self.netG(inputs)
            loss_content = self.criterionG(outputs, targets)
            loss_adv = self.adv_trainer.loss_g(outputs, targets)
            loss_G = loss_content + self.adv_lambda * loss_adv
            self.metric_counter.add_losses(loss_G.item(), loss_content.item())
            curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
            self.metric_counter.add_metrics(curr_psnr, curr_ssim)
            if not i:
                self.metric_counter.add_image(img_for_vis, tag='val')
            i += 1
            if i > epoch_size:
                break
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch, validation=True)

    def _update_d(self, outputs, targets):
        if self.config['model']['d_name'] == 'no_gan':
            return 0
        self.optimizer_D.zero_grad()
        loss_D = self.adv_lambda * self.adv_trainer.loss_d(outputs, targets)
        loss_D.backward(retain_graph=True)
        self.optimizer_D.step()
        return loss_D.item()

    def _get_optim(self, params):
        if self.config['optimizer']['name'] == 'adam':
            optimizer = optim.Adam(params, lr=self.config['optimizer']['lr'])
        elif self.config['optimizer']['name'] == 'sgd':
            optimizer = optim.SGD(params, lr=self.config['optimizer']['lr'])
        elif self.config['optimizer']['name'] == 'adadelta':
            optimizer = optim.Adadelta(params, lr=self.config['optimizer']['lr'])
        else:
            raise ValueError("Optimizer [%s] not recognized." % self.config['optimizer']['name'])
        return optimizer

    def _get_scheduler(self, optimizer):
        if self.config['scheduler']['name'] == 'plateau':
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                             mode='min',
                                                             patience=self.config['scheduler']['patience'],
                                                             factor=self.config['scheduler']['factor'],
                                                             min_lr=self.config['scheduler']['min_lr'])
        elif self.config['scheduler']['name'] == 'sgdr':
            scheduler = WarmRestart(optimizer)
        elif self.config['scheduler']['name'] == 'linear':
            scheduler = LinearDecay(optimizer,
                                    min_lr=self.config['scheduler']['min_lr'],
                                    num_epochs=self.config['num_epochs'],
                                    start_epoch=self.config['scheduler']['start_epoch'])
        else:
            raise ValueError("Scheduler [%s] not recognized." % self.config['scheduler']['name'])
        return scheduler

    @staticmethod
    def _get_adversarial_trainer(d_name, net_d, criterion_d):
        if d_name == 'no_gan':
            return GANFactory.create_model('NoGAN')
        elif d_name == 'patch_gan' or d_name == 'multi_scale':
            return GANFactory.create_model('SingleGAN', net_d, criterion_d)
        elif d_name == 'double_gan':
            return GANFactory.create_model('DoubleGAN', net_d, criterion_d)
        else:
            raise ValueError("Discriminator Network [%s] not recognized." % d_name)

    def _init_params(self):
        self.criterionG, criterionD = get_loss(self.config['model'])
        self.netG, netD = get_nets(self.config['model'])
        self.netG.to(self.device)
        self.adv_trainer = self._get_adversarial_trainer(self.config['model']['d_name'], netD, criterionD)
        self.model = get_model(self.config['model'])
        self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters()))
        self.optimizer_D = self._get_optim(self.adv_trainer.get_params())
        self.scheduler_G = self._get_scheduler(self.optimizer_G)
        self.scheduler_D = self._get_scheduler(self.optimizer_D)
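
A minimal launch sketch for the Trainer above; the config filename and the get_dataloaders helper are assumptions for illustration, not part of the original code:

import yaml

# hypothetical config path; any YAML file providing the keys Trainer reads will do
with open('config.yaml') as f:
    config = yaml.safe_load(f)

train_loader, val_loader = get_dataloaders(config)  # hypothetical helper returning train/val DataLoaders
trainer = Trainer(config, train=train_loader, val=val_loader)
trainer.train()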
Example #9
class Trainer(object):
    def __init__(self, config):
        self.config = config
        self.dataset = self._get_dataset(config["dataroot"],
                                         config["seq_length"])
        dlen = len(self.dataset)
        splitlen = [int(0.8 * dlen), int(0.1 * dlen), 0]
        splitlen[2] = dlen - sum(splitlen)

        self.train_dataset, self.val_dataset, self.test_dataset = data.random_split(
            self.dataset, splitlen)

        self.experiment_name = f"{config['experiment_desc']}_{config['model']['model_n']}"
        self.metric_counter = MetricCounter(self.experiment_name,
                                            self.config["print_every"])

    def test(self):
        # build loaders, model, optimizer and loss before loading weights
        self._init_params()
        PATH = f'model/pretrained/best_{self.config["experiment_desc"]}.pth'
        self.model.load_state_dict(torch.load(PATH)['model'])
        self.model.cuda()
        self.model.eval()
        self._validate(-1, True)

    def train(self):

        self._init_params()
        # PATH = f'model/pretrained/best_{self.config["experiment_desc"]}.pth'
        # self.model.load_state_dict(torch.load(PATH)['model'])
        self.model.cuda()

        for epoch in range(0, self.epochs):
            if self.config['model']['model_n'] == "semichar_rnn":
                self._run_epoch(epoch)
                self._validate(epoch)
            if self.config['model']['model_n'] == "seq2seq+attention":
                self._run_epoch_seq2seq(epoch)
                self._validate_seq2seq(epoch)

            if self.metric_counter.update_best_model():
                torch.save(
                    {
                        'model': self.model.state_dict(),
                        'optimizer': self.optimizer.state_dict(),
                        'dataset_params': self.dataset.get_params(),
                        'n_hidden': self.config['model']['n_hidden'],
                        'n_layers': self.config['model']['n_layers']
                    },
                    f'model/pretrained/best_{self.config["experiment_desc"]}.pth'
                )

            print(self.metric_counter.loss_message())
            logging.debug(
                f"Experiment Name: {self.config['experiment_desc']}, Epoch: {epoch}, Loss: {self.metric_counter.loss_message()}"
            )

    def _run_epoch_seq2seq(self, epoch):
        self.metric_counter.clear()

        counter = 0
        for lenX, X, leny, y, y1hot in self.loader_train:
            h = None
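            # sort the batch by input length (descending), the order packed-sequence RNNs expect by default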
            lenX, perm_idx = lenX.sort(0, descending=True)
            X = X[perm_idx]
            y = y[perm_idx]
            leny = leny[perm_idx]
            y1hot = y1hot[perm_idx]

            X, y = X.float().cuda(), y.float().cuda()
            lenX, leny = lenX.cuda(), leny.cuda()
            y1hot = y1hot.cuda()

            self.optimizer.zero_grad()
            out = self.model(h, lenX, X, leny, y)
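            # trim targets to the decoder's output length, then flatten both to token level for the loss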
            tar = y1hot[:, :out.size(1)]
            tar = tar.contiguous().view(tar.nelement())
            cur = out.contiguous().view(tar.nelement(), -1)
            loss = self.loss_fn(cur, tar)

            loss.backward()

            acc = self.calc_acc(cur, tar)
            nn.utils.clip_grad_norm_(self.model.parameters(),
                                     self.config['clip'])
            self.optimizer.step()

            self.metric_counter.add_losses(loss)
            self.metric_counter.add_acc(acc)
            if not counter % self.config['print_every']:
                print(
                    f"Epoch: {epoch}; Train Step: {counter}: {self.metric_counter.loss_message()}"
                )

            counter += 1

        self.metric_counter.write_to_tensorboard(epoch)

    def _validate_seq2seq(self, epoch):
        self.metric_counter.clear()

        counter = 0
        self.model.eval()
        loader = self.loader_val
        for lenX, X, leny, y, y1hot in loader:
            h = None
            lenX, perm_idx = lenX.sort(0, descending=True)
            X = X[perm_idx]
            y = y[perm_idx]
            leny = leny[perm_idx]
            y1hot = y1hot[perm_idx]

            X, y = X.float().cuda(), y.float().cuda()
            lenX, leny = lenX.cuda(), leny.cuda()
            y1hot = y1hot.cuda()

            out = self.model(h, lenX, X, leny, y)
            tar = y1hot[:, :out.size(1)]
            tar = tar.contiguous().view(tar.nelement())
            cur = out.contiguous().view(tar.nelement(), -1)
            loss = self.loss_fn(cur, tar)

            acc = self.calc_acc(cur, tar)

            self.metric_counter.add_losses(loss)
            self.metric_counter.add_acc(acc)
            if not counter % self.config['print_every']:
                print(
                    f"Epoch: {epoch}; Valid Step: {counter}: {self.metric_counter.loss_message()}"
                )

            counter += 1

        self.metric_counter.write_to_tensorboard(epoch, validation=True)
        print(self.metric_counter.loss_message())

        self.model.train()

    def _run_epoch(self, epoch):

        self.metric_counter.clear()
        h = None
        counter = 0
        for X, y in self.loader_train:
            X, y = X.cuda(), y.cuda()

            self.optimizer.zero_grad()
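            # carry the hidden state across batches, trimmed to the current batch size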
            if h is not None:
                h = tuple([x[:, :X.size(0)].contiguous() for x in h])
            output, h = self.model(X, h)

            loss = self.loss_fn(output, y.view(y.nelement()).long())
            loss.backward()

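            # detach the hidden state so gradients do not flow across batch boundaries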
            h = tuple([i.detach_() for i in h])
            acc = self.calc_acc(output, y.view(y.nelement()))

            nn.utils.clip_grad_norm_(self.model.parameters(),
                                     self.config['clip'])
            self.optimizer.step()

            self.metric_counter.add_losses(loss)
            self.metric_counter.add_acc(acc)
            if not counter % self.config['print_every']:
                print(
                    f"Epoch: {epoch}; Train Step: {counter}: {self.metric_counter.loss_message()}"
                )

            counter += 1

        self.metric_counter.write_to_tensorboard(epoch)

    def _validate(self, epoch, test=False):
        h = None
        self.model.eval()
        self.metric_counter.clear()
        counter = 0
        loader = (self.loader_test if test else self.loader_val)

        for X, y in loader:
            X, y = X.cuda(), y.cuda()
            if h is not None:
                h = tuple([x[:, :X.size(0)].contiguous() for x in h])
            output, h = self.model(X, h)

            loss = self.loss_fn(output, y.view(y.nelement()).long())
            acc = self.calc_acc(output, y.view(y.nelement()))

            self.metric_counter.add_losses(loss)
            self.metric_counter.add_acc(acc)
            if not counter % self.config['print_every']:
                print(
                    f"Epoch: {epoch}; Valid Step: {counter}: {self.metric_counter.loss_message()}"
                )
            counter += 1

        if not test:
            self.metric_counter.write_to_tensorboard(epoch, validation=True)
        else:
            print(self.metric_counter.loss_message())

        self.model.train()

    def calc_acc(self, y_pred, y):
        # take the argmax over the float logits; casting to int before argmax would corrupt the scores
        acc_output = y_pred.cpu().detach().numpy()
        acc = np.equal(acc_output.argmax(1), y.cpu().numpy())
        return acc.sum() / acc_output.shape[0]

    def _get_dataset(self, dataroot, seq_length):
        if self.config['dataset'] == "autocomplete":
            return AutoCompleteDataset(dataroot, seq_length)
        if self.config['model']['model_n'] == "semichar_rnn":
            return SemicharDataset(dataroot, seq_length)
        if self.config['model']['model_n'] == "seq2seq+attention":
            return seq2seqDataset(dataroot, seq_length, 50, 1000)
        raise ValueError(
            f"Dataset for model {self.config['model']['model_n']} not recognized.")

    def _get_optim(self, params):
        if self.config['optimizer']['name'] == 'adam':
            optimizer = optim.Adam(params, lr=self.config['optimizer']['lr'])
        elif self.config['optimizer']['name'] == 'sgd':
            optimizer = optim.SGD(params, lr=self.config['optimizer']['lr'])
        else:
            raise ValueError(
                f"Optimizer {self.config['optimizer']['name']} not recognized."
            )
        return optimizer

    def _init_params(self):

        self.loader_train = data.DataLoader(
            self.train_dataset,
            batch_size=self.config['batch_size'],
            shuffle=True)
        self.loader_val = data.DataLoader(self.val_dataset,
                                          batch_size=self.config['batch_size'],
                                          shuffle=False)

        self.loader_test = data.DataLoader(
            self.test_dataset,
            batch_size=self.config['batch_size'],
            shuffle=False)
        self.model = get_model(self.config['model']['model_n'],
                               self.dataset.get_params(),
                               self.config['model']['n_hidden'],
                               self.config['model']['n_layers'],
                               self.config['model']['drop_prob'],
                               self.config['model']['grad_clip'],
                               self.config['model']['attn_model'],
                               self.config['model']['dim'],
                               self.config['model']['vs'])

        self.epochs = self.config['num_epochs']
        self.optimizer = self._get_optim(self.model.parameters())
        self.loss_fn = get_loss(self.config['model']['loss'])