class Trainer:

    def __init__(self, model, amp_handle=None, init_lr=1e-2, max_norm=100,
                 use_cuda=False, fp16=False, log_dir='logs',
                 model_prefix='model', checkpoint=False, continue_from=None,
                 opt_type=None, *args, **kwargs):
        if fp16:
            import apex.parallel
            from apex import amp
            if not use_cuda:
                raise RuntimeError
        self.amp_handle = amp_handle

        # training parameters
        self.init_lr = init_lr
        self.max_norm = max_norm
        self.use_cuda = use_cuda
        self.fp16 = fp16
        self.log_dir = log_dir
        self.model_prefix = model_prefix
        self.checkpoint = checkpoint
        self.opt_type = opt_type
        self.epoch = 0
        self.states = None

        # load from pre-trained model if needed
        if continue_from is not None:
            self.load(continue_from)

        # setup model
        self.model = model
        if self.use_cuda:
            logger.debug("using cuda")
            self.model.cuda()

        # setup loss
        #self.loss = nn.CTCLoss(blank=0, reduction='none')
        self.loss = wp.CTCLoss(blank=0, length_average=True)

        # setup optimizer
        if opt_type is None:
            # for test only
            self.optimizer = None
            self.lr_scheduler = None
        else:
            assert opt_type in OPTIMIZER_TYPES
            parameters = self.model.parameters()
            if opt_type == "sgdr":
                logger.debug("using SGDR")
                self.optimizer = torch.optim.SGD(parameters, lr=self.init_lr,
                                                 momentum=0.9, weight_decay=5e-4)
                #self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=0.5)
                self.lr_scheduler = CosineAnnealingWithRestartsLR(
                    self.optimizer, T_max=2, T_mult=2)
            elif opt_type == "adamwr":
                logger.debug("using AdamWR")
                self.optimizer = torch.optim.Adam(parameters, lr=self.init_lr,
                                                  betas=(0.9, 0.999), eps=1e-8,
                                                  weight_decay=5e-4)
                self.lr_scheduler = CosineAnnealingWithRestartsLR(
                    self.optimizer, T_max=2, T_mult=2)
            elif opt_type == "adam":
                logger.debug("using Adam")
                self.optimizer = torch.optim.Adam(parameters, lr=self.init_lr,
                                                  betas=(0.9, 0.999), eps=1e-8,
                                                  weight_decay=5e-4)
                self.lr_scheduler = None
            elif opt_type == "rmsprop":
                logger.debug("using RMSprop")
                self.optimizer = torch.optim.RMSprop(parameters, lr=self.init_lr,
                                                     alpha=0.95, eps=1e-8,
                                                     weight_decay=5e-4,
                                                     centered=True)
                self.lr_scheduler = None

        # setup decoder for test
        self.decoder = LatGenCTCDecoder()
        self.labeler = self.decoder.labeler

        # FP16 and distributed after load
        if self.fp16:
            #self.model = network_to_half(self.model)
            #self.optimizer = FP16_Optimizer(self.optimizer, static_loss_scale=128.)
            self.optimizer = self.amp_handle.wrap_optimizer(self.optimizer)

        if is_distributed():
            if self.use_cuda:
                local_rank = torch.cuda.current_device()
                if fp16:
                    self.model = apex.parallel.DistributedDataParallel(self.model)
                else:
                    self.model = nn.parallel.DistributedDataParallel(
                        self.model, device_ids=[local_rank],
                        output_device=local_rank)
            else:
                self.model = nn.parallel.DistributedDataParallel(self.model)

        if self.states is not None:
            self.restore_state()

    def __get_model_name(self, desc):
        return str(get_model_file_path(self.log_dir, self.model_prefix, desc))

    def __remove_ckpt_files(self, epoch):
        for ckpt in Path(self.log_dir).rglob(f"*_epoch_{epoch:03d}_ckpt_*"):
            ckpt.unlink()

    def train_loop_before_hook(self):
        pass

    def train_loop_after_hook(self):
        pass

    def unit_train(self, data):
        raise NotImplementedError

    #def average_gradients(self):
    #    if not is_distributed():
    #        return
    #    size = float(dist.get_world_size())
    #    for param in self.model.parameters():
    #        dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM, async_op=True)
    #        param.grad.data /= size

    def train_epoch(self, data_loader):
        self.model.train()
        meter_loss = tnt.meter.MovingAverageValueMeter(
            len(data_loader) // 100 + 1)
        #meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
        #meter_confusion = tnt.meter.ConfusionMeter(p.NUM_CTC_LABELS, normalized=True)

        def plot_scalar(i, loss, title="train"):
            #if self.lr_scheduler is not None:
            #    self.lr_scheduler.step()
            x = self.epoch + i / len(data_loader)
            if logger.visdom is not None:
                opts = {
                    'xlabel': 'epoch',
                    'ylabel': 'loss',
                }
                logger.visdom.add_point(title=title, x=x, y=loss, **opts)
            if logger.tensorboard is not None:
                logger.tensorboard.add_graph(self.model, xs)
                xs_img = tvu.make_grid(xs[0, 0], normalize=True, scale_each=True)
                logger.tensorboard.add_image('xs', x, xs_img)
                ys_hat_img = tvu.make_grid(ys_hat[0].transpose(0, 1),
                                           normalize=True, scale_each=True)
                logger.tensorboard.add_image('ys_hat', x, ys_hat_img)
                logger.tensorboard.add_scalars(title, x, {
                    'loss': loss,
                })

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
            logger.debug(
                f"current lr = {self.optimizer.param_groups[0]['lr']:.3e}")
        if is_distributed() and data_loader.sampler is not None:
            data_loader.sampler.set_epoch(self.epoch)

        ckpts = iter(len(data_loader) * np.arange(0.1, 1.1, 0.1))
        ckpt = next(ckpts)

        self.train_loop_before_hook()

        # count the number of supervised batches seen in this epoch
        t = tqdm(enumerate(data_loader), total=len(data_loader),
                 desc="training", ncols=p.NCOLS)
        for i, (data) in t:
            loss_value = self.unit_train(data)
            if loss_value is not None:
                meter_loss.add(loss_value)
            t.set_description(f"training (loss: {meter_loss.value()[0]:.3f})")
            t.refresh()
            #self.meter_accuracy.add(ys_int, ys)
            #self.meter_confusion.add(ys_int, ys)

            if i > ckpt:
                plot_scalar(i, meter_loss.value()[0])
                if self.checkpoint:
                    logger.info(
                        f"training loss at epoch_{self.epoch:03d}_ckpt_{i:07d}: "
                        f"{meter_loss.value()[0]:5.3f}")
                    if not is_distributed() or (is_distributed()
                                                and dist.get_rank() == 0):
                        self.save(
                            self.__get_model_name(
                                f"epoch_{self.epoch:03d}_ckpt_{i:07d}"))
                ckpt = next(ckpts)
            #input("press key to continue")

        plot_scalar(i, meter_loss.value()[0])
        self.epoch += 1
        logger.info(f"epoch {self.epoch:03d}: "
                    f"training loss {meter_loss.value()[0]:5.3f} ")
                    #f"training accuracy {meter_accuracy.value()[0]:6.3f}")
        if not is_distributed() or (is_distributed() and dist.get_rank() == 0):
            self.save(self.__get_model_name(f"epoch_{self.epoch:03d}"))
            self.__remove_ckpt_files(self.epoch - 1)

        self.train_loop_after_hook()

    def unit_validate(self, data):
        raise NotImplementedError

    def validate(self, data_loader):
        "validate with label error rate by the edit distance between hyps and refs"
        self.model.eval()
        with torch.no_grad():
            N, D = 0, 0
            t = tqdm(enumerate(data_loader), total=len(data_loader),
                     desc="validating", ncols=p.NCOLS)
            for i, (data) in t:
                hyps, refs = self.unit_validate(data)
                # calculate ler
                N += self.edit_distance(refs, hyps)
                D += sum(len(r) for r in refs)
                ler = N * 100. / D
                t.set_description(f"validating (LER: {ler:.2f} %)")
                t.refresh()
            logger.info(
                f"validating at epoch {self.epoch:03d}: LER {ler:.2f} %")

            title = f"validate"
            x = self.epoch - 1 + i / len(data_loader)
            if logger.visdom is not None:
                opts = {
                    'xlabel': 'epoch',
                    'ylabel': 'LER',
                }
                logger.visdom.add_point(title=title, x=x, y=ler, **opts)
            if logger.tensorboard is not None:
                logger.tensorboard.add_scalars(title, x, {
                    'LER': ler,
                })

    def unit_test(self, data):
        raise NotImplementedError

    def test(self, data_loader):
        "test with word error rate by the edit distance between hyps and refs"
        self.model.eval()
        with torch.no_grad():
            N, D = 0, 0
            t = tqdm(enumerate(data_loader), total=len(data_loader),
                     desc="testing", ncols=p.NCOLS)
            for i, (data) in t:
                hyps, refs = self.unit_test(data)
                # calculate wer
                N += self.edit_distance(refs, hyps)
                D += sum(len(r) for r in refs)
                wer = N * 100. / D
                t.set_description(f"testing (WER: {wer:.2f} %)")
                t.refresh()
            logger.info(f"testing at epoch {self.epoch:03d}: WER {wer:.2f} %")

    def edit_distance(self, refs, hyps):
        assert len(refs) == len(hyps)
        n = 0
        for ref, hyp in zip(refs, hyps):
            r = [chr(c) for c in ref]
            h = [chr(c) for c in hyp]
            n += Lev.distance(''.join(r), ''.join(h))
        return n

    def target_to_loglikes(self, ys, label_lens):
        max_len = max(label_lens.tolist())
        num_classes = self.labeler.get_num_labels()
        ys_hat = [
            torch.cat((torch.zeros(1).int(), ys[s:s + l],
                       torch.zeros(max_len - l).int()))
            for s, l in zip([0] + label_lens[:-1].cumsum(0).tolist(),
                            label_lens.tolist())
        ]
        ys_hat = [
            int2onehot(torch.IntTensor(z), num_classes, floor=1e-3)
            for z in ys_hat
        ]
        ys_hat = torch.stack(ys_hat)
        ys_hat = torch.log(ys_hat)
        return ys_hat

    def save_hook(self):
        pass

    def save(self, file_path, **kwargs):
        Path(file_path).parent.mkdir(mode=0o755, parents=True, exist_ok=True)
        logger.debug(f"saving the model to {file_path}")
        if self.states is None:
            self.states = dict()
        self.states.update(kwargs)
        self.states["epoch"] = self.epoch
        self.states["opt_type"] = self.opt_type
        if is_distributed():
            model_state_dict = self.model.state_dict()
            strip_prefix = 9 if self.fp16 else 7  # remove "module.1." prefix from keys
            self.states["model"] = {
                k[strip_prefix:]: v
                for k, v in model_state_dict.items()
            }
        else:
            self.states["model"] = self.model.state_dict()
        self.states["optimizer"] = self.optimizer.state_dict()
        if self.lr_scheduler is not None:
            self.states["lr_scheduler"] = self.lr_scheduler.state_dict()
        self.save_hook()
        torch.save(self.states, file_path)

    def load(self, file_path):
        if isinstance(file_path, str):
            file_path = Path(file_path)
        if not file_path.exists():
            logger.error(f"no such file {file_path} exists")
            sys.exit(1)
        logger.debug(f"loading the model from {file_path}")
        to_device = f"cuda:{torch.cuda.current_device()}" if self.use_cuda else "cpu"
        self.states = torch.load(file_path, map_location=to_device)

    def restore_state(self):
        self.epoch = self.states["epoch"]
        if is_distributed():
            self.model.load_state_dict(
                {f"module.{k}": v for k, v in self.states["model"].items()})
        else:
            self.model.load_state_dict(self.states["model"])
        if "opt_type" in self.states and self.opt_type == self.states["opt_type"]:
            self.optimizer.load_state_dict(self.states["optimizer"])
        if self.lr_scheduler is not None and "lr_scheduler" in self.states:
            self.lr_scheduler.load_state_dict(self.states["lr_scheduler"])
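

# ---------------------------------------------------------------------------
# Illustrative only: a minimal sketch of how the abstract hooks above might be
# filled in by a concrete trainer. The class name ExampleCTCTrainer, the batch
# layout (xs, frame_lens, ys, label_lens), the model output convention, and
# the decoder call are assumptions for illustration, not part of this project.
# fp16/amp handling is omitted for brevity.
# ---------------------------------------------------------------------------
class ExampleCTCTrainer(Trainer):

    def unit_train(self, data):
        # assumed batch layout: padded features, frame lengths, flat targets, label lengths
        xs, frame_lens, ys, label_lens = data
        if self.use_cuda:
            xs = xs.cuda()
        ys_hat = self.model(xs)  # assumed shape: T x N x C activations for CTC
        loss = self.loss(ys_hat, ys, frame_lens, label_lens)
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
        self.optimizer.step()
        return loss.item()

    def unit_validate(self, data):
        # returns (hyps, refs) as sequences of label indices, as expected by edit_distance()
        xs, frame_lens, ys, label_lens = data
        if self.use_cuda:
            xs = xs.cuda()
        ys_hat = self.model(xs)
        hyps = self.decoder.decode(ys_hat, frame_lens)  # hypothetical decoder API
        refs = [ys[s:s + l].tolist()
                for s, l in zip([0] + label_lens[:-1].cumsum(0).tolist(),
                                label_lens.tolist())]
        return hyps, refs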


class Trainer:

    def __init__(self, model, init_lr=1e-4, max_norm=400, use_cuda=False,
                 fp16=False, log_dir='logs', model_prefix='model',
                 checkpoint=False, continue_from=None, opt_type="sgdr",
                 *args, **kwargs):
        if fp16:
            if not use_cuda:
                raise RuntimeError

        # training parameters
        self.init_lr = init_lr
        self.max_norm = max_norm
        self.use_cuda = use_cuda
        self.fp16 = fp16
        self.log_dir = log_dir
        self.model_prefix = model_prefix
        self.checkpoint = checkpoint
        self.epoch = 0

        # prepare visdom
        if logger.visdom is not None:
            logger.visdom.add_plot(title=f'train', xlabel='epoch', ylabel='loss')
            logger.visdom.add_plot(title=f'validate', xlabel='epoch', ylabel='LER')

        # setup model
        self.model = model
        if self.use_cuda:
            logger.debug("using cuda")
            self.model.cuda()

        # setup loss
        self.loss = CTCLoss(blank=0, size_average=True, length_average=True)

        # setup optimizer
        assert opt_type in OPTIMIZER_TYPES
        parameters = self.model.parameters()
        if opt_type == "sgd":
            logger.debug("using SGD")
            self.optimizer = torch.optim.SGD(parameters, lr=self.init_lr,
                                             momentum=0.9)
            self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=5)
        elif opt_type == "sgdr":
            logger.debug("using SGDR")
            self.optimizer = torch.optim.SGD(parameters, lr=self.init_lr,
                                             momentum=0.9)
            #self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=0.5)
            self.lr_scheduler = CosineAnnealingWithRestartsLR(self.optimizer,
                                                              T_max=5, T_mult=2)
        elif opt_type == "adam":
            logger.debug("using AdamW")
            self.optimizer = torch.optim.Adam(parameters, lr=self.init_lr,
                                              betas=(0.9, 0.999), eps=1e-8,
                                              weight_decay=0.0005, l2_reg=False)
            self.lr_scheduler = None

        # setup decoder for test
        self.decoder = LatGenCTCDecoder()

        # load from pre-trained model if needed
        if continue_from is not None:
            self.load(continue_from)

        # FP16 and distributed after load
        if self.fp16:
            self.model = network_to_half(self.model)
            self.optimizer = FP16_Optimizer(self.optimizer,
                                            static_loss_scale=128.)

        if is_distributed():
            if self.use_cuda:
                local_rank = torch.cuda.current_device()
                if fp16:
                    self.model = apex.parallel.DistributedDataParallel(self.model)
                else:
                    self.model = nn.parallel.DistributedDataParallel(
                        self.model, device_ids=[local_rank],
                        output_device=local_rank)
            else:
                self.model = nn.parallel.DistributedDataParallel(self.model)

    def __get_model_name(self, desc):
        return str(get_model_file_path(self.log_dir, self.model_prefix, desc))

    def __remove_ckpt_files(self, epoch):
        for ckpt in Path(self.log_dir).rglob(f"*_epoch_{epoch:03d}_ckpt_*"):
            ckpt.unlink()

    def unit_train(self, data):
        raise NotImplementedError

    def train_epoch(self, data_loader):
        self.model.train()
        num_ckpt = int(np.ceil(len(data_loader) / 10))
        meter_loss = tnt.meter.MovingAverageValueMeter(
            len(data_loader) // 100 + 1)
        #meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
        #meter_confusion = tnt.meter.ConfusionMeter(p.NUM_CTC_LABELS, normalized=True)
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
            logger.debug(f"current lr = {self.lr_scheduler.get_lr()}")
        if is_distributed() and data_loader.sampler is not None:
            data_loader.sampler.set_epoch(self.epoch)

        # count the number of supervised batches seen in this epoch
        t = tqdm(enumerate(data_loader), total=len(data_loader), desc="training")
        for i, (data) in t:
            loss_value = self.unit_train(data)
            meter_loss.add(loss_value)
            t.set_description(f"training (loss: {meter_loss.value()[0]:.3f})")
            t.refresh()
            #self.meter_accuracy.add(ys_int, ys)
            #self.meter_confusion.add(ys_int, ys)

            if 0 < i < len(data_loader) and i % num_ckpt == 0:
                if not is_distributed() or (is_distributed()
                                            and dist.get_rank() == 0):
                    title = "train"
                    x = self.epoch + i / len(data_loader)
                    if logger.visdom is not None:
                        logger.visdom.add_point(title=title, x=x,
                                                y=meter_loss.value()[0])
                    if logger.tensorboard is not None:
                        logger.tensorboard.add_graph(self.model, xs)
                        xs_img = tvu.make_grid(xs[0, 0], normalize=True,
                                               scale_each=True)
                        logger.tensorboard.add_image('xs', x, xs_img)
                        ys_hat_img = tvu.make_grid(ys_hat[0].transpose(0, 1),
                                                   normalize=True,
                                                   scale_each=True)
                        logger.tensorboard.add_image('ys_hat', x, ys_hat_img)
                        logger.tensorboard.add_scalars(title, x, {
                            'loss': meter_loss.value()[0],
                        })
                if self.checkpoint:
                    logger.info(
                        f"training loss at epoch_{self.epoch:03d}_ckpt_{i:07d}: "
                        f"{meter_loss.value()[0]:5.3f}")
                    if not is_distributed() or (is_distributed()
                                                and dist.get_rank() == 0):
                        self.save(
                            self.__get_model_name(
                                f"epoch_{self.epoch:03d}_ckpt_{i:07d}"))
            #input("press key to continue")

        self.epoch += 1
        logger.info(f"epoch {self.epoch:03d}: "
                    f"training loss {meter_loss.value()[0]:5.3f} ")
                    #f"training accuracy {meter_accuracy.value()[0]:6.3f}")
        if not is_distributed() or (is_distributed() and dist.get_rank() == 0):
            self.save(self.__get_model_name(f"epoch_{self.epoch:03d}"))
            self.__remove_ckpt_files(self.epoch - 1)

    def unit_validate(self, data):
        raise NotImplementedError

    def validate(self, data_loader):
        "validate with label error rate by the edit distance between hyps and refs"
        self.model.eval()
        with torch.no_grad():
            N, D = 0, 0
            t = tqdm(enumerate(data_loader), total=len(data_loader),
                     desc="validating")
            for i, (data) in t:
                hyps, refs = self.unit_validate(data)
                # calculate ler
                N += self.edit_distance(refs, hyps)
                D += sum(len(r) for r in refs)
                ler = N * 100. / D
                t.set_description(f"validating (LER: {ler:.2f} %)")
                t.refresh()
            logger.info(
                f"validating at epoch {self.epoch:03d}: LER {ler:.2f} %")

            if not is_distributed() or (is_distributed() and dist.get_rank() == 0):
                title = f"validate"
                x = self.epoch - 1 + i / len(data_loader)
                if logger.visdom is not None:
                    logger.visdom.add_point(title=title, x=x, y=ler)
                if logger.tensorboard is not None:
                    logger.tensorboard.add_scalars(title, x, {
                        'LER': ler,
                    })

    def unit_test(self, data):
        raise NotImplementedError

    def test(self, data_loader):
        "test with word error rate by the edit distance between hyps and refs"
        self.model.eval()
        with torch.no_grad():
            N, D = 0, 0
            t = tqdm(enumerate(data_loader), total=len(data_loader),
                     desc="testing")
            for i, (data) in t:
                hyps, refs = self.unit_test(data)
                # calculate wer
                N += self.edit_distance(refs, hyps)
                D += sum(len(r) for r in refs)
                wer = N * 100. / D
                t.set_description(f"testing (WER: {wer:.2f} %)")
                t.refresh()
            logger.info(f"testing at epoch {self.epoch:03d}: WER {wer:.2f} %")

    def edit_distance(self, refs, hyps):
        assert len(refs) == len(hyps)
        n = 0
        for ref, hyp in zip(refs, hyps):
            r = [chr(c) for c in ref]
            h = [chr(c) for c in hyp]
            n += Lev.distance(''.join(r), ''.join(h))
        return n

    def save(self, file_path, **kwargs):
        Path(file_path).parent.mkdir(mode=0o755, parents=True, exist_ok=True)
        logger.info(f"saving the model to {file_path}")
        states = kwargs
        states["epoch"] = self.epoch
        if is_distributed():
            model_state_dict = self.model.state_dict()
            strip_prefix = 9 if self.fp16 else 7  # remove "module.1." prefix from keys
            states["model"] = {
                k[strip_prefix:]: v
                for k, v in model_state_dict.items()
            }
        else:
            states["model"] = self.model.state_dict()
        states["optimizer"] = self.optimizer.state_dict()
        states["lr_scheduler"] = self.lr_scheduler.state_dict()
        torch.save(states, file_path)

    def load(self, file_path):
        if isinstance(file_path, str):
            file_path = Path(file_path)
        if not file_path.exists():
            logger.error(f"no such file {file_path} exists")
            sys.exit(1)
        logger.info(f"loading the model from {file_path}")
        to_device = f"cuda:{torch.cuda.current_device()}" if self.use_cuda else "cpu"
        states = torch.load(file_path, map_location=to_device)
        self.epoch = states["epoch"]
        self.model.load_state_dict(states["model"])
        self.optimizer.load_state_dict(states["optimizer"])
        self.lr_scheduler.load_state_dict(states["lr_scheduler"])
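

# ---------------------------------------------------------------------------
# Illustrative only: a minimal driver loop, assuming a concrete subclass such
# as the ExampleCTCTrainer sketched earlier and pre-built DataLoader objects.
# The function name and its arguments are hypothetical, not part of this file.
# ---------------------------------------------------------------------------
def example_training_run(model, train_loader, valid_loader, num_epochs=10):
    trainer = ExampleCTCTrainer(model, use_cuda=torch.cuda.is_available(),
                                opt_type="sgdr", checkpoint=True)
    for _ in range(num_epochs):
        trainer.train_epoch(train_loader)   # saves an "epoch_XXX" snapshot per epoch
        trainer.validate(valid_loader)      # logs LER on the validation set
    # trainer.test(...) would additionally require unit_test() to be implemented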