class Trainer:
    """Adversarial image-restoration trainer driven by a nested ``config`` dict.

    Each epoch trains the generator ``netG`` and a discriminator.  The
    discriminator step goes through ``adv_trainer`` and is skipped when
    ``config['model']['d_name'] == 'no_gan'``.  The trainer then validates,
    steps both LR schedulers and saves ``best_<experiment_desc>.h5`` (on a new
    best model) and ``last_<experiment_desc>.h5`` (every epoch).

    Args:
        config: nested configuration dict (``model``, ``optimizer``,
            ``scheduler``, ``experiment_desc``, ``warmup_num``, ``gpu_id``,
            ``num_epochs`` and the optional ``train_batches_per_epoch`` /
            ``val_batches_per_epoch``).
        train: iterable of training batches.
        val: iterable of validation batches.
    """

    def __init__(self, config, train: DataLoader, val: DataLoader):
        self.config = config
        self.train_dataset = train
        self.val_dataset = val
        self.adv_lambda = config['model']['adv_lambda']
        self.metric_counter = MetricCounter(config['experiment_desc'])
        self.warmup_epochs = config['warmup_num']
        gpu_id = self.config['gpu_id']
        # gpu_id is used as a CUDA device index and index 0 is a valid device;
        # only a negative id (or a machine without CUDA) falls back to the CPU.
        use_cuda = torch.cuda.is_available() and gpu_id >= 0
        self.device = torch.device('cuda:{}'.format(gpu_id) if use_cuda else "cpu")

    def train(self):
        """Run ``config['num_epochs']`` epochs of training + validation with checkpointing."""
        self._init_params()
        experiment = self.config['experiment_desc']
        for epoch in range(self.config['num_epochs']):
            # When warm-up ends, unfreeze the generator and rebuild its optimizer
            # and scheduler so the newly trainable weights are optimized too.
            # netG is expected to be a wrapper exposing `.module.unfreeze()`
            # (presumably DataParallel -- confirm in get_nets).
            if self.warmup_epochs != 0 and epoch == self.warmup_epochs:
                self.netG.module.unfreeze()
                self.optimizer_G = self._get_optim(self.netG.parameters())
                self.scheduler_G = self._get_scheduler(self.optimizer_G)
            self._run_epoch(epoch)
            self._validate(epoch)
            # NOTE(review): ReduceLROnPlateau ('plateau') requires a metric
            # argument in step(); these argument-less calls only suit the
            # 'sgdr' / 'linear' schedulers.
            self.scheduler_G.step()
            self.scheduler_D.step()

            if self.metric_counter.update_best_model():
                torch.save({'model': self.netG.state_dict()}, 'best_{}.h5'.format(experiment))
            torch.save({'model': self.netG.state_dict()}, 'last_{}.h5'.format(experiment))
            print(self.metric_counter.loss_message())
            logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s",
                          experiment, epoch, self.metric_counter.loss_message())

    def _generator_loss(self, outputs, targets):
        """Return ``(loss_G, loss_content)``, where loss_G = content + adv_lambda * adversarial."""
        loss_content = self.criterionG(outputs, targets)
        loss_adv = self.adv_trainer.loss_g(outputs, targets)
        return loss_content + self.adv_lambda * loss_adv, loss_content

    def _run_epoch(self, epoch):
        """Train for one epoch of at most ``epoch_size`` batches."""
        self.metric_counter.clear()
        lr = self.optimizer_G.param_groups[-1]['lr']
        # Optional cap on batches per epoch; falls back to the whole loader.
        epoch_size = self.config.get('train_batches_per_epoch') or len(self.train_dataset)
        tq = tqdm.tqdm(self.train_dataset, total=epoch_size)
        tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
        for i, data in enumerate(tq):
            inputs, targets = self.model.get_input(data)
            outputs = self.netG(inputs)
            loss_D = self._update_d(outputs, targets)
            self.optimizer_G.zero_grad()
            loss_G, loss_content = self._generator_loss(outputs, targets)
            loss_G.backward()
            self.optimizer_G.step()
            self.metric_counter.add_losses(loss_G.item(), loss_content.item(), loss_D)
            curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(
                inputs, outputs, targets)
            self.metric_counter.add_metrics(curr_psnr, curr_ssim)
            tq.set_postfix(loss=self.metric_counter.loss_message())
            if i == 0:
                self.metric_counter.add_image(img_for_vis, tag='train')
            # Stop once exactly epoch_size batches have been processed.
            if i + 1 >= epoch_size:
                break
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch)

    def _validate(self, epoch):
        """Evaluate on at most ``epoch_size`` validation batches and log the metrics."""
        self.metric_counter.clear()
        epoch_size = self.config.get('val_batches_per_epoch') or len(self.val_dataset)
        tq = tqdm.tqdm(self.val_dataset, total=epoch_size)
        tq.set_description('Validation')
        # No backward pass happens here, so don't build autograd graphs.
        with torch.no_grad():
            for i, data in enumerate(tq):
                inputs, targets = self.model.get_input(data)
                outputs = self.netG(inputs)
                loss_G, loss_content = self._generator_loss(outputs, targets)
                self.metric_counter.add_losses(loss_G.item(), loss_content.item())
                curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(
                    inputs, outputs, targets)
                self.metric_counter.add_metrics(curr_psnr, curr_ssim)
                if i == 0:
                    self.metric_counter.add_image(img_for_vis, tag='val')
                if i + 1 >= epoch_size:
                    break
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch, validation=True)

    def _update_d(self, outputs, targets):
        """Take one discriminator step; return its scaled loss as a float (0 without a GAN)."""
        if self.config['model']['d_name'] == 'no_gan':
            return 0
        self.optimizer_D.zero_grad()
        loss_D = self.adv_lambda * self.adv_trainer.loss_d(outputs, targets)
        # Keep the graph through `outputs`: the generator loss is backpropagated
        # through it again right after this call in _run_epoch.
        loss_D.backward(retain_graph=True)
        self.optimizer_D.step()
        return loss_D.item()

    def _get_optim(self, params):
        """Build the optimizer named by ``config['optimizer']['name']``.

        Raises:
            ValueError: if the name is not 'adam', 'sgd' or 'adadelta'.
        """
        optimizers = {'adam': optim.Adam, 'sgd': optim.SGD, 'adadelta': optim.Adadelta}
        name = self.config['optimizer']['name']
        if name not in optimizers:
            raise ValueError("Optimizer [%s] not recognized." % name)
        return optimizers[name](params, lr=self.config['optimizer']['lr'])

    def _get_scheduler(self, optimizer):
        """Build the LR scheduler named by ``config['scheduler']['name']``.

        Raises:
            ValueError: if the name is not 'plateau', 'sgdr' or 'linear'.
        """
        sched_cfg = self.config['scheduler']
        name = sched_cfg['name']
        if name == 'plateau':
            return optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='min',
                patience=sched_cfg['patience'],
                factor=sched_cfg['factor'],
                min_lr=sched_cfg['min_lr'])
        # The scheduler (not optimizer) name selects SGDR; the old optimizer-name
        # check made this branch unreachable.
        if name == 'sgdr':
            return WarmRestart(optimizer)
        if name == 'linear':
            return LinearDecay(optimizer,
                               min_lr=sched_cfg['min_lr'],
                               num_epochs=self.config['num_epochs'],
                               start_epoch=sched_cfg['start_epoch'])
        raise ValueError("Scheduler [%s] not recognized." % name)

    @staticmethod
    def _get_adversarial_trainer(d_name, net_d, criterion_d):
        """Map a discriminator name to the matching GANFactory model.

        Raises:
            ValueError: for an unknown discriminator name.
        """
        if d_name == 'no_gan':
            return GANFactory.create_model('NoGAN')
        if d_name in ('patch_gan', 'multi_scale'):
            return GANFactory.create_model('SingleGAN', net_d, criterion_d)
        if d_name == 'double_gan':
            return GANFactory.create_model('DoubleGAN', net_d, criterion_d)
        raise ValueError("Discriminator Network [%s] not recognized." % d_name)

    def _init_params(self):
        """Create losses, networks, adversarial trainer, optimizers and schedulers."""
        self.criterionG, criterionD = get_loss(self.config['model'])
        self.netG, netD = get_nets(self.config['model'])
        self.netG.to(self.device)
        self.adv_trainer = self._get_adversarial_trainer(
            self.config['model']['d_name'], netD, criterionD)
        self.model = get_model(self.config['model'])
        # Only optimize generator weights that are currently trainable (during
        # warm-up some may be frozen).
        self.optimizer_G = self._get_optim(
            filter(lambda p: p.requires_grad, self.netG.parameters()))
        self.optimizer_D = self._get_optim(self.adv_trainer.get_params())
        self.scheduler_G = self._get_scheduler(self.optimizer_G)
        self.scheduler_D = self._get_scheduler(self.optimizer_D)
class Trainer(object):
    """Adversarial image-restoration trainer with weighted generator losses.

    The generator objective is
    ``content_coef * content + adv_coef * adversarial + feature_coef * vgg``.
    Here ``criterionG[0]`` gives the content loss, ``criterionG[1]`` the VGG
    (feature) loss and ``adv_trainer.lossG`` the adversarial loss.  Checkpoints
    are saved as ``best_<name>.h5`` / ``last_<name>.h5``.  ``<name>`` is built
    from the experiment description, generator name and the three coefficients.
    """

    def __init__(self, config):
        self.config = config
        self.train_dataset = self._get_dataset(config, 'train')
        self.val_dataset = self._get_dataset(config, 'test')
        model_cfg = config['model']
        self.experiment_name = (
            f"{config['experiment_desc']}_{model_cfg['g_name']}"
            f"_content{model_cfg['content_coef']}"
            f"_feature{model_cfg['feature_coef']}"
            f"_adv{model_cfg['adv_coef']}")
        self.metric_counter = MetricCounter(self.experiment_name)
        self.warmup_epochs = config['warmup_num']

    def train(self):
        """Run ``config['num_epochs']`` epochs of training + validation with checkpointing."""
        self._init_params()
        for epoch in range(self.config['num_epochs']):
            # When warm-up ends, unfreeze the generator and rebuild its optimizer
            # and scheduler so the newly trainable weights are optimized too.
            if self.warmup_epochs != 0 and epoch == self.warmup_epochs:
                self.netG.module.unfreeze()
                self.optimizer_G = self._get_optim(self.netG.parameters())
                self.scheduler_G = self._get_scheduler(self.optimizer_G)
            self._run_epoch(epoch)
            self._validate(epoch)
            # NOTE(review): ReduceLROnPlateau ('plateau') requires a metric
            # argument in step(); these calls only suit 'sgdr' / 'linear'.
            self.scheduler_G.step()
            self.scheduler_D.step()

            if self.metric_counter.update_best_model():
                torch.save({'model': self.netG.state_dict()},
                           'best_{}.h5'.format(self.experiment_name))
            torch.save({'model': self.netG.state_dict()},
                       'last_{}.h5'.format(self.experiment_name))
            print(self.metric_counter.loss_message())
            logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s",
                          self.experiment_name, epoch, self.metric_counter.loss_message())

    def _generator_loss(self, outputs, targets):
        """Return ``(loss_G, loss_content, loss_vgg)`` for one batch."""
        model_cfg = self.config['model']
        loss_content = self.criterionG[0](outputs, targets)
        loss_vgg = self.criterionG[1](outputs, targets)
        loss_adv = self.adv_trainer.lossG(outputs, targets)
        loss_G = (model_cfg['content_coef'] * loss_content
                  + model_cfg['adv_coef'] * loss_adv
                  + model_cfg['feature_coef'] * loss_vgg)
        return loss_G, loss_content, loss_vgg

    def _run_epoch(self, epoch):
        """Train the generator (and discriminator) over one pass of the train loader."""
        self.metric_counter.clear()
        lr = self.optimizer_G.param_groups[-1]['lr']
        tq = tqdm.tqdm(self.train_dataset.dataloader)
        tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
        for i, data in enumerate(tq, start=1):
            inputs, targets = self.model.get_input(data)
            outputs = self.netG(inputs)
            loss_D = self._update_d(outputs, targets)
            self.optimizer_G.zero_grad()
            loss_G, loss_content, _ = self._generator_loss(outputs, targets)
            loss_G.backward()
            self.optimizer_G.step()
            self.metric_counter.add_losses(loss_G.item(), loss_content.item(), loss_D)
            curr_psnr, curr_ssim = self.model.get_acc(outputs, targets)
            self.metric_counter.add_metrics(curr_psnr, curr_ssim)
            tq.set_postfix(loss=self.metric_counter.loss_message())
            # Log one sample (input, target, output) from the third batch.
            if i == 3:
                self.metric_counter.images_to_tensorboard(
                    [inputs[0], targets[0], outputs[0]], epoch)
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch)

    def _validate(self, epoch):
        """Evaluate the generator over the validation loader and log the metrics."""
        self.metric_counter.clear()
        tq = tqdm.tqdm(self.val_dataset.dataloader)
        tq.set_description('Validation')
        # No backward pass happens here, so don't build autograd graphs.
        with torch.no_grad():
            for data in tq:
                inputs, targets = self.model.get_input(data)
                outputs = self.netG(inputs)
                loss_G, loss_content, loss_vgg = self._generator_loss(outputs, targets)
                # NOTE(review): training passes loss_D as the third value while
                # validation passes loss_vgg -- confirm what add_losses expects.
                self.metric_counter.add_losses(
                    loss_G.item(), loss_content.item(), loss_vgg.item())
                curr_psnr, curr_ssim = self.model.get_acc(outputs, targets)
                self.metric_counter.add_metrics(curr_psnr, curr_ssim)
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch, validation=True)

    def _get_dataset(self, config, filename):
        """Wrap the ``filename`` split ('train' / 'test') in a CustomDataLoader."""
        return CustomDataLoader(config, filename)

    def _update_d(self, outputs, targets):
        """Take one discriminator step; return its scaled loss as a float (0 without a GAN)."""
        if self.config['model']['d_name'] == 'no_gan':
            return 0
        self.optimizer_D.zero_grad()
        loss_D = self.config['model']['adv_coef'] * self.adv_trainer.lossD(outputs, targets)
        # Keep the graph through `outputs`: the generator loss is backpropagated
        # through it again right after this call in _run_epoch.
        loss_D.backward(retain_graph=True)
        self.optimizer_D.step()
        return loss_D.item()

    def _get_optim(self, params):
        """Build the optimizer named by ``config['optimizer']['name']``.

        Raises:
            ValueError: if the name is not 'adam', 'sgd' or 'adadelta'.
        """
        optimizers = {'adam': optim.Adam, 'sgd': optim.SGD, 'adadelta': optim.Adadelta}
        name = self.config['optimizer']['name']
        if name not in optimizers:
            raise ValueError("Optimizer [%s] not recognized." % name)
        return optimizers[name](params, lr=self.config['optimizer']['lr'])

    def _get_scheduler(self, optimizer):
        """Build the LR scheduler named by ``config['scheduler']['name']``.

        Raises:
            ValueError: if the name is not 'plateau', 'sgdr' or 'linear'.
        """
        sched_cfg = self.config['scheduler']
        name = sched_cfg['name']
        if name == 'plateau':
            return optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='min',
                patience=sched_cfg['patience'],
                factor=sched_cfg['factor'],
                min_lr=sched_cfg['min_lr'])
        # The scheduler (not optimizer) name selects SGDR; the old optimizer-name
        # check made this branch unreachable.
        if name == 'sgdr':
            return WarmRestart(optimizer)
        if name == 'linear':
            return LinearDecay(
                optimizer,
                min_lr=sched_cfg['min_lr'],
                num_epochs=self.config['num_epochs'],
                start_epoch=sched_cfg['start_epoch'])
        raise ValueError("Scheduler [%s] not recognized." % name)

    def _get_adversarial_trainer(self, D_name, netD, criterionD):
        """Map a discriminator name to the matching AdversarialTrainerFactory model.

        Raises:
            ValueError: for an unknown discriminator name.
        """
        if D_name == 'no_gan':
            return AdversarialTrainerFactory.createModel('NoAdversarialTrainer')
        if D_name in ('patch_gan', 'multi_scale'):
            return AdversarialTrainerFactory.createModel(
                'SingleAdversarialTrainer', netD, criterionD)
        if D_name == 'double_gan':
            return AdversarialTrainerFactory.createModel(
                'DoubleAdversarialTrainer', netD, criterionD)
        raise ValueError("Discriminator Network [%s] not recognized." % D_name)

    def _init_params(self):
        """Create losses, networks (generator on CUDA), optimizers and schedulers."""
        self.criterionG, criterionD = get_loss(self.config['model'])
        self.netG, netD = get_nets(self.config['model'])
        self.netG.cuda()
        self.adv_trainer = self._get_adversarial_trainer(
            self.config['model']['d_name'], netD, criterionD)
        self.model = get_model(self.config['model'])
        # Only optimize generator weights that are currently trainable (during
        # warm-up some may be frozen).
        self.optimizer_G = self._get_optim(
            filter(lambda p: p.requires_grad, self.netG.parameters()))
        self.optimizer_D = self._get_optim(self.adv_trainer.get_params())
        self.scheduler_G = self._get_scheduler(self.optimizer_G)
        self.scheduler_D = self._get_scheduler(self.optimizer_D)
class Trainer(object):
    """Train and evaluate a character-level sequence model selected by ``config``.

    ``config['model']['model_n']`` chooses between ``"semichar_rnn"`` (loader
    batches of ``(X, y)``) and ``"seq2seq+attention"`` (loader batches of
    ``(lenX, X, leny, y, y1hot)``).  The dataset is randomly split 80/10/10 into
    train/validation/test subsets when the trainer is constructed.
    """

    def __init__(self, config):
        self.config = config
        self.dataset = self._get_dataset(config["dataroot"], config["seq_length"])
        dlen = len(self.dataset)
        # 80% train, 10% validation, and the remainder for test so the three
        # sizes always sum to len(dataset).
        n_train, n_val = int(0.8 * dlen), int(0.1 * dlen)
        splitlen = [n_train, n_val, dlen - n_train - n_val]
        # NOTE(review): the split is random per instance, so a fresh Trainer
        # used only for test() sees a different test subset than the one held
        # out during training unless the torch RNG is seeded.
        self.train_dataset, self.val_dataset, self.test_dataset = data.random_split(
            self.dataset, splitlen)
        self.experiment_name = f"{config['experiment_desc']}_{config['model']['model_n']}"
        self.metric_counter = MetricCounter(self.experiment_name, self.config["print_every"])

    def _checkpoint_path(self):
        """Return the path the best checkpoint is saved to and loaded from."""
        return f'model/pretrained/best_{self.config["experiment_desc"]}.pth'

    def test(self):
        """Load the best checkpoint and report metrics on the test split.

        NOTE(review): this goes through ``_validate``, which unpacks ``(X, y)``
        batches, so it only fits the ``"semichar_rnn"`` model.
        """
        # A Trainer used only for testing has not built its model/loaders yet.
        if not hasattr(self, 'model'):
            self._init_params()
        self.model.load_state_dict(torch.load(self._checkpoint_path())['model'])
        # _validate moves every batch to CUDA, so the model must live there too.
        self.model.cuda()
        self.model.eval()
        self._validate(-1, True)

    def train(self):
        """Train for ``num_epochs`` epochs, saving a checkpoint whenever the model improves.

        Raises:
            ValueError: if ``config['model']['model_n']`` is not a supported model.
        """
        model_n = self.config['model']['model_n']
        if model_n == "semichar_rnn":
            run_epoch, validate = self._run_epoch, self._validate
        elif model_n == "seq2seq+attention":
            run_epoch, validate = self._run_epoch_seq2seq, self._validate_seq2seq
        else:
            raise ValueError(f"Model {model_n} not recognized.")
        self._init_params()
        self.model.cuda()
        for epoch in range(self.epochs):
            run_epoch(epoch)
            validate(epoch)
            if self.metric_counter.update_best_model():
                torch.save(
                    {
                        'model': self.model.state_dict(),
                        'optimizer': self.optimizer.state_dict(),
                        'dataset_params': self.dataset.get_params(),
                        'n_hidden': self.config['model']['n_hidden'],
                        'n_layers': self.config['model']['n_layers'],
                    },
                    self._checkpoint_path())
            print(self.metric_counter.loss_message())
            logging.debug("Experiment Name: %s, Epoch: %s, Loss: %s",
                          self.config['experiment_desc'], epoch,
                          self.metric_counter.loss_message())

    def _log_step(self, epoch, phase, counter):
        """Print running losses every ``print_every`` steps; ``phase`` is 'Train' or 'Valid'."""
        if counter % self.config['print_every'] == 0:
            print(f"Epoch: {epoch}; {phase} Step: {counter}: {self.metric_counter.loss_message()}")

    @staticmethod
    def _prepare_seq2seq_batch(lenX, X, leny, y, y1hot):
        """Sort a seq2seq batch by descending source length and move it to CUDA.

        Returns ``(lenX, X, leny, y, y1hot)`` with ``X`` and ``y`` cast to float.
        """
        # Presumably the model packs source sequences, which needs lengths in
        # descending order -- TODO confirm against the model definition.
        lenX, perm_idx = lenX.sort(0, descending=True)
        X, y, leny, y1hot = X[perm_idx], y[perm_idx], leny[perm_idx], y1hot[perm_idx]
        return lenX.cuda(), X.float().cuda(), leny.cuda(), y.float().cuda(), y1hot.cuda()

    def _seq2seq_forward(self, batch):
        """Run the seq2seq model on one loader batch; return flattened ``(scores, targets)``."""
        lenX, X, leny, y, y1hot = self._prepare_seq2seq_batch(*batch)
        out = self.model(None, lenX, X, leny, y)
        # Truncate targets to the decoded length and flatten both sides so each
        # row of `cur` lines up with one target index.
        tar = y1hot[:, :out.size(1)].contiguous().view(-1)
        cur = out.contiguous().view(tar.nelement(), -1)
        return cur, tar

    def _run_epoch_seq2seq(self, epoch):
        """Train the seq2seq model for one pass over the training loader."""
        self.metric_counter.clear()
        for counter, batch in enumerate(self.loader_train):
            self.optimizer.zero_grad()
            cur, tar = self._seq2seq_forward(batch)
            loss = self.loss_fn(cur, tar)
            loss.backward()
            acc = self.calc_acc(cur, tar)
            nn.utils.clip_grad_norm_(self.model.parameters(), self.config['clip'])
            self.optimizer.step()
            # Detach so a loss kept by the metric counter does not pin the graph.
            self.metric_counter.add_losses(loss.detach())
            self.metric_counter.add_acc(acc)
            self._log_step(epoch, "Train", counter)
        self.metric_counter.write_to_tensorboard(epoch)

    def _validate_seq2seq(self, epoch):
        """Evaluate the seq2seq model on the validation loader."""
        self.metric_counter.clear()
        self.model.eval()
        try:
            # Evaluation only: no autograd graphs needed.
            with torch.no_grad():
                for counter, batch in enumerate(self.loader_val):
                    cur, tar = self._seq2seq_forward(batch)
                    loss = self.loss_fn(cur, tar)
                    acc = self.calc_acc(cur, tar)
                    self.metric_counter.add_losses(loss)
                    self.metric_counter.add_acc(acc)
                    self._log_step(epoch, "Valid", counter)
            self.metric_counter.write_to_tensorboard(epoch, validation=True)
            print(self.metric_counter.loss_message())
        finally:
            self.model.train()

    def _run_epoch(self, epoch):
        """Train the semichar RNN for one epoch, carrying hidden state across batches."""
        self.metric_counter.clear()
        h = None
        for counter, (X, y) in enumerate(self.loader_train):
            X, y = X.cuda(), y.cuda()
            self.optimizer.zero_grad()
            if h is not None:
                # Trim the carried hidden state to this batch's size (the last
                # batch may be smaller).
                h = tuple(state[:, :X.size(0)].contiguous() for state in h)
            output, h = self.model(X, h)
            target = y.view(y.nelement())
            loss = self.loss_fn(output, target.long())
            loss.backward()
            # Keep the hidden values but cut the graph between batches.
            h = tuple(state.detach_() for state in h)
            acc = self.calc_acc(output, target)
            nn.utils.clip_grad_norm_(self.model.parameters(), self.config['clip'])
            self.optimizer.step()
            # Detach so a loss kept by the metric counter does not pin the graph.
            self.metric_counter.add_losses(loss.detach())
            self.metric_counter.add_acc(acc)
            self._log_step(epoch, "Train", counter)
        self.metric_counter.write_to_tensorboard(epoch)

    def _validate(self, epoch, test=False):
        """Evaluate the semichar RNN on the validation (or, if ``test``, test) loader."""
        h = None
        self.model.eval()
        self.metric_counter.clear()
        loader = self.loader_test if test else self.loader_val
        try:
            # Without no_grad the hidden state carried between batches would
            # chain every batch's graph together for the whole pass.
            with torch.no_grad():
                for counter, (X, y) in enumerate(loader):
                    X, y = X.cuda(), y.cuda()
                    if h is not None:
                        h = tuple(state[:, :X.size(0)].contiguous() for state in h)
                    output, h = self.model(X, h)
                    target = y.view(y.nelement())
                    loss = self.loss_fn(output, target.long())
                    acc = self.calc_acc(output, target)
                    self.metric_counter.add_losses(loss)
                    self.metric_counter.add_acc(acc)
                    self._log_step(epoch, "Valid", counter)
            if not test:
                self.metric_counter.write_to_tensorboard(epoch, validation=True)
            else:
                print(self.metric_counter.loss_message())
        finally:
            self.model.train()

    def calc_acc(self, y_pred, y):
        """Return the fraction of rows whose argmax over dim 1 of ``y_pred`` equals ``y``."""
        # Take the argmax on the raw scores: casting to an integer type first
        # truncates them (e.g. 0.3 and 0.9 both become 0) and picks wrong classes.
        pred = y_pred.detach().argmax(1).cpu().numpy()
        target = y.detach().cpu().numpy()
        return np.equal(pred, target).sum() / pred.shape[0]

    def _get_dataset(self, dataroot, seq_length):
        """Build the dataset for ``config['dataset']`` / ``config['model']['model_n']``.

        Raises:
            ValueError: if no dataset matches the configuration.
        """
        if self.config['dataset'] == "autocomplete":
            return AutoCompleteDataset(dataroot, seq_length)
        model_n = self.config['model']['model_n']
        if model_n == "semichar_rnn":
            return SemicharDataset(dataroot, seq_length)
        if model_n == "seq2seq+attention":
            # The literals 50 and 1000 are passed as-is; their meaning is defined
            # by seq2seqDataset.
            return seq2seqDataset(dataroot, seq_length, 50, 1000)
        raise ValueError(
            f"No dataset for dataset {self.config['dataset']!r} / model {model_n!r}.")

    def _get_optim(self, params):
        """Build the optimizer named by ``config['optimizer']['name']``.

        Raises:
            ValueError: if the name is not 'adam' or 'sgd'.
        """
        optimizers = {'adam': optim.Adam, 'sgd': optim.SGD}
        name = self.config['optimizer']['name']
        if name not in optimizers:
            raise ValueError(f"Optimizer {name} not recognized.")
        return optimizers[name](params, lr=self.config['optimizer']['lr'])

    def _init_params(self):
        """Create the data loaders, model, epoch count, optimizer and loss function."""
        batch_size = self.config['batch_size']
        self.loader_train = data.DataLoader(
            self.train_dataset, batch_size=batch_size, shuffle=True)
        self.loader_val = data.DataLoader(
            self.val_dataset, batch_size=batch_size, shuffle=False)
        self.loader_test = data.DataLoader(
            self.test_dataset, batch_size=batch_size, shuffle=False)
        model_cfg = self.config['model']
        self.model = get_model(model_cfg['model_n'], self.dataset.get_params(),
                               model_cfg['n_hidden'], model_cfg['n_layers'],
                               model_cfg['drop_prob'], model_cfg['grad_clip'],
                               model_cfg['attn_model'], model_cfg['dim'], model_cfg['vs'])
        self.epochs = self.config['num_epochs']
        self.optimizer = self._get_optim(self.model.parameters())
        self.loss_fn = get_loss(model_cfg['loss'])