def unit_train(self, data):
    """Run one optimization step on a single batch.

    Returns the scalar loss value, or None when the model signals the batch
    should be skipped. Exceptions are logged with the batch's filenames and
    lengths, then re-raised.
    """
    xs, ys, frame_lens, label_lens, filenames, _ = data
    try:
        if self.use_cuda:
            xs, ys = xs.cuda(non_blocking=True), ys.cuda(non_blocking=True)
        # the model also returns (possibly rewritten) ys; a None ys_hat marks a skip
        ys_hat, ys_hat_lens, ys = self.model(xs, frame_lens, ys, label_lens)
        if ys_hat is None:
            logger.debug("the batch includes a data with label_lens > max_seq_lens, so skipped")
            return None
        if self.fp16:
            # compute the loss in fp32 for numerical stability
            ys_hat = ys_hat.float()
        # transpose(1, 2): presumably moves the class dim for the loss — TODO confirm shape
        loss = self.loss(ys_hat.transpose(1, 2), ys.long())
        loss_value = loss.item()
        self.optimizer.zero_grad()
        if self.fp16:
            #self.optimizer.backward(loss)
            #self.optimizer.clip_master_grads(self.max_norm)
            # apex amp: scale the loss so fp16 gradients do not underflow
            with self.optimizer.scale_loss(loss) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
        if is_distributed():
            # manual gradient averaging across workers before the step
            self.average_gradients()
        self.optimizer.step()
        if self.use_cuda:
            torch.cuda.synchronize()
        # free the graph before returning
        del loss
        return loss_value
    except Exception as e:
        # dump the offending files and lengths for debugging, then propagate
        print(e)
        print(filenames, frame_lens, label_lens)
        raise
def save(self, file_path, **kwargs):
    """Persist the training states (model, optimizer, scheduler, epoch and any
    extra keyword entries) to ``file_path``, creating parent dirs as needed."""
    Path(file_path).parent.mkdir(mode=0o755, parents=True, exist_ok=True)
    logger.debug(f"saving the model to {file_path}")
    if self.states is None:
        self.states = dict()
    self.states.update(kwargs)
    self.states["epoch"] = self.epoch
    self.states["opt_type"] = self.opt_type
    if is_distributed():
        # DDP-wrapped models carry a "module." key prefix ("module.1." when the
        # fp16 wrapper adds another layer); strip it so the saved dict loads
        # into a plain, unwrapped model.
        prefix_len = 9 if self.fp16 else 7
        self.states["model"] = {
            key[prefix_len:]: tensor
            for key, tensor in self.model.state_dict().items()
        }
    else:
        self.states["model"] = self.model.state_dict()
    self.states["optimizer"] = self.optimizer.state_dict()
    if self.lr_scheduler is not None:
        self.states["lr_scheduler"] = self.lr_scheduler.state_dict()
    # let subclasses stash extra entries before writing
    self.save_hook()
    torch.save(self.states, file_path)
def load(self, file_path):
    """Restore the training states from a checkpoint file into ``self.states``.

    Exits the process when the file does not exist; tensors are mapped onto
    the current CUDA device or the CPU according to ``self.use_cuda``.
    """
    if isinstance(file_path, str):
        file_path = Path(file_path)
    if not file_path.exists():
        # a missing checkpoint is fatal: bail out rather than train from scratch
        logger.error(f"no such file {file_path} exists")
        sys.exit(1)
    logger.debug(f"loading the model from {file_path}")
    if self.use_cuda:
        map_loc = f"cuda:{torch.cuda.current_device()}"
    else:
        map_loc = "cpu"
    self.states = torch.load(file_path, map_location=map_loc)
def unit_train(self, data):
    """Run one CTC optimization step on a single batch.

    Returns the scalar loss, or None when the batch is skipped (frames too
    short for the labels, or a nan/inf loss). Exceptions are logged with the
    batch's filenames and lengths, then re-raised.
    """
    xs, ys, frame_lens, label_lens, filenames, _ = data
    try:
        batch_size = xs.size(0)  # NOTE(review): unused below — kept as-is
        if self.use_cuda:
            # ys stays on CPU — presumably the CTC loss expects CPU labels; confirm
            xs = xs.cuda(non_blocking=True)
        # model may downsample, so it returns the updated frame_lens
        ys_hat, frame_lens = self.model(xs, frame_lens)
        # CTC needs roughly 2 frames per label (blanks between repeats)
        if frame_lens.lt(2 * label_lens).nonzero().numel():
            logger.debug(
                "the batch includes a data with frame_lens < 2*label_lens, so skipped"
            )
            return None
        if self.fp16:
            # compute the loss in fp32 for numerical stability
            ys_hat = ys_hat.float()
        ys_hat = ys_hat.transpose(0, 1).contiguous()  # TxNxH
        #torch.set_printoptions(threshold=5000000)
        #print(ys_hat.shape, frame_lens, ys.shape, label_lens)
        #print(onehot2int(ys_hat).squeeze(), ys)
        # normalize the per-sample loss by its frame length before averaging
        d = frame_lens.float()
        #d = frame_lens.sum().float()
        if self.use_cuda:
            d = d.cuda()
        loss = (self.loss(ys_hat, ys, frame_lens, label_lens) / d).mean()
        #loss = self.loss(ys_hat, ys, frame_lens, label_lens).div_(d)
        #loss = self.loss(ys_hat, ys, frame_lens, label_lens)
        if torch.isnan(loss) or loss.item() == float(
                "inf") or loss.item() == -float("inf"):
            logger.warning(
                "received an nan/inf loss: probably frame_lens < label_lens or the learning rate is too high"
            )
            #loss.mul_(0.)
            return None
        loss_value = loss.item()
        self.optimizer.zero_grad()
        if self.fp16:
            #self.optimizer.backward(loss)
            #self.optimizer.clip_master_grads(self.max_norm)
            # apex amp: scale the loss so fp16 gradients do not underflow
            with self.optimizer.scale_loss(loss) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
        #if is_distributed():
        #    self.average_gradients()
        self.optimizer.step()
        if self.use_cuda:
            torch.cuda.synchronize()
        # free the graph before returning
        del loss
        return loss_value
    except Exception as e:
        # dump the offending files and lengths for debugging, then propagate
        print(e)
        print(filenames, frame_lens, label_lens)
        raise
def __init__(self, model, use_cuda=False, continue_from=None, verbose=False, *args, **kwargs):
    """Wrap a trained model for decoding; a checkpoint to resume from is mandatory."""
    assert continue_from is not None
    self.use_cuda = use_cuda
    self.verbose = verbose
    self.model = model
    if use_cuda:
        logger.debug("using cuda")
        self.model.cuda()
    # restore the saved states before decoding
    self.load(continue_from)
    # kaldi lattice-generating CTC decoder used for evaluation
    self.decoder = LatGenCTCDecoder()
def train_epoch(self, data_loader):
    """Train the model for one epoch over ``data_loader``.

    Plots the moving-average loss at ~10% intervals, optionally saves
    intermediate checkpoints, and bumps ``self.epoch`` at the end.
    """
    self.model.train()
    # moving average over roughly 1% of the epoch
    meter_loss = tnt.meter.MovingAverageValueMeter(
        len(data_loader) // 100 + 1)
    #meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
    #meter_confusion = tnt.meter.ConfusionMeter(p.NUM_CTC_LABELS, normalized=True)

    def plot_scalar(i, loss, title="train"):
        # log the loss point to visdom/tensorboard at fractional-epoch x
        #if self.lr_scheduler is not None:
        #    self.lr_scheduler.step()
        x = self.epoch + i / len(data_loader)
        if logger.visdom is not None:
            opts = {
                'xlabel': 'epoch',
                'ylabel': 'loss',
            }
            logger.visdom.add_point(title=title, x=x, y=loss, **opts)
        if logger.tensorboard is not None:
            # NOTE(review): xs and ys_hat are not defined in this scope — this
            # branch would raise NameError if tensorboard logging is enabled;
            # confirm against the sibling train_epoch that comments these out.
            logger.tensorboard.add_graph(self.model, xs)
            xs_img = tvu.make_grid(xs[0, 0], normalize=True, scale_each=True)
            logger.tensorboard.add_image('xs', x, xs_img)
            ys_hat_img = tvu.make_grid(ys_hat[0].transpose(0, 1),
                                       normalize=True, scale_each=True)
            logger.tensorboard.add_image('ys_hat', x, ys_hat_img)
            logger.tensorboard.add_scalars(title, x, {
                'loss': loss,
            })

    if self.lr_scheduler is not None:
        # epoch-level LR schedule step
        self.lr_scheduler.step()
        logger.debug(
            f"current lr = {self.optimizer.param_groups[0]['lr']:.3e}")
    if is_distributed() and data_loader.sampler is not None:
        # reshuffle the distributed sampler deterministically per epoch
        data_loader.sampler.set_epoch(self.epoch)
    # checkpoint thresholds at 10%, 20%, ..., 100% of the epoch
    ckpts = iter(len(data_loader) * np.arange(0.1, 1.1, 0.1))
    ckpt = next(ckpts)
    self.train_loop_before_hook()
    # count the number of supervised batches seen in this epoch
    t = tqdm(enumerate(data_loader), total=len(data_loader),
             desc="training", ncols=p.NCOLS)
    for i, (data) in t:
        loss_value = self.unit_train(data)
        # unit_train returns None for skipped batches
        if loss_value is not None:
            meter_loss.add(loss_value)
        t.set_description(f"training (loss: {meter_loss.value()[0]:.3f})")
        t.refresh()
        #self.meter_accuracy.add(ys_int, ys)
        #self.meter_confusion.add(ys_int, ys)
        if i > ckpt:
            plot_scalar(i, meter_loss.value()[0])
            if self.checkpoint:
                logger.info(
                    f"training loss at epoch_{self.epoch:03d}_ckpt_{i:07d}: "
                    f"{meter_loss.value()[0]:5.3f}")
                # only rank 0 writes checkpoints in distributed mode
                if not is_distributed() or (is_distributed()
                                            and dist.get_rank() == 0):
                    self.save(
                        self.__get_model_name(
                            f"epoch_{self.epoch:03d}_ckpt_{i:07d}"))
            ckpt = next(ckpts)
        #input("press key to continue")
    # final plot for the epoch
    plot_scalar(i, meter_loss.value()[0])
    self.epoch += 1
    logger.info(f"epoch {self.epoch:03d}: "
                f"training loss {meter_loss.value()[0]:5.3f} ")
    #f"training accuracy {meter_accuracy.value()[0]:6.3f}")
    if not is_distributed() or (is_distributed() and dist.get_rank() == 0):
        self.save(self.__get_model_name(f"epoch_{self.epoch:03d}"))
        # drop the intermediate per-ckpt files of the finished epoch
        self.__remove_ckpt_files(self.epoch - 1)
    self.train_loop_after_hook()
def __init__(self, model, amp_handle=None, init_lr=1e-2, max_norm=100,
             use_cuda=False, fp16=False, log_dir='logs', model_prefix='model',
             checkpoint=False, continue_from=None, opt_type=None,
             *args, **kwargs):
    """Set up the trainer: hyperparameters, optional checkpoint restore,
    model placement, CTC loss, optimizer/scheduler selection, decoder, and
    fp16/distributed wrapping (which must happen after the restore).

    Args:
        model: the network to train.
        amp_handle: apex amp handle used to wrap the optimizer when fp16.
        init_lr: initial learning rate.
        max_norm: gradient clipping norm.
        use_cuda: move model/tensors to GPU (required when fp16).
        fp16: enable apex mixed-precision training.
        log_dir, model_prefix: checkpoint naming.
        checkpoint: save intermediate per-epoch checkpoints.
        continue_from: checkpoint path to resume from, or None.
        opt_type: one of OPTIMIZER_TYPES, or None for test-only usage.
    """
    if fp16:
        # apex is imported lazily so non-fp16 runs need not have it installed
        import apex.parallel
        from apex import amp
        if not use_cuda:
            # fp16 is GPU-only
            raise RuntimeError
    self.amp_handle = amp_handle
    # training parameters
    self.init_lr = init_lr
    self.max_norm = max_norm
    self.use_cuda = use_cuda
    self.fp16 = fp16
    self.log_dir = log_dir
    self.model_prefix = model_prefix
    self.checkpoint = checkpoint
    self.opt_type = opt_type
    self.epoch = 0
    self.states = None
    # load from pre-trained model if needed
    if continue_from is not None:
        self.load(continue_from)
    # setup model
    self.model = model
    if self.use_cuda:
        logger.debug("using cuda")
        self.model.cuda()
    # setup loss
    #self.loss = nn.CTCLoss(blank=0, reduction='none')
    self.loss = wp.CTCLoss(blank=0, length_average=True)
    # setup optimizer
    if opt_type is None:
        # for test only
        self.optimizer = None
        self.lr_scheduler = None
    else:
        assert opt_type in OPTIMIZER_TYPES
        parameters = self.model.parameters()
        if opt_type == "sgdr":
            logger.debug("using SGDR")
            self.optimizer = torch.optim.SGD(parameters, lr=self.init_lr,
                                             momentum=0.9, weight_decay=5e-4)
            #self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=0.5)
            self.lr_scheduler = CosineAnnealingWithRestartsLR(
                self.optimizer, T_max=2, T_mult=2)
        elif opt_type == "adamwr":
            logger.debug("using AdamWR")
            self.optimizer = torch.optim.Adam(parameters, lr=self.init_lr,
                                              betas=(0.9, 0.999), eps=1e-8,
                                              weight_decay=5e-4)
            self.lr_scheduler = CosineAnnealingWithRestartsLR(
                self.optimizer, T_max=2, T_mult=2)
        elif opt_type == "adam":
            logger.debug("using Adam")
            self.optimizer = torch.optim.Adam(parameters, lr=self.init_lr,
                                              betas=(0.9, 0.999), eps=1e-8,
                                              weight_decay=5e-4)
            self.lr_scheduler = None
        elif opt_type == "rmsprop":
            logger.debug("using RMSprop")
            self.optimizer = torch.optim.RMSprop(parameters, lr=self.init_lr,
                                                 alpha=0.95, eps=1e-8,
                                                 weight_decay=5e-4,
                                                 centered=True)
            self.lr_scheduler = None
    # setup decoder for test
    self.decoder = LatGenCTCDecoder()
    self.labeler = self.decoder.labeler
    # FP16 and distributed after load
    if self.fp16:
        #self.model = network_to_half(self.model)
        #self.optimizer = FP16_Optimizer(self.optimizer, static_loss_scale=128.)
        self.optimizer = self.amp_handle.wrap_optimizer(self.optimizer)
    if is_distributed():
        if self.use_cuda:
            local_rank = torch.cuda.current_device()
            if fp16:
                # apex's DDP handles fp16 parameter flattening itself
                self.model = apex.parallel.DistributedDataParallel(
                    self.model)
            else:
                self.model = nn.parallel.DistributedDataParallel(
                    self.model, device_ids=[local_rank],
                    output_device=local_rank)
        else:
            self.model = nn.parallel.DistributedDataParallel(self.model)
    # re-apply any loaded checkpoint states to the wrapped model/optimizer
    if self.states is not None:
        self.restore_state()
def train_loop_before_hook(self):
    """Advance the tfr scheduler once before each training epoch.

    tfr is presumably the teacher-forcing rate of the model — confirm
    against the model/scheduler definitions.
    """
    self.tfr_scheduler.step()
    logger.debug(f"current tfr = {self.model.tfr:.3e}")
def train_epoch(self, data_loader):
    """Train the model for one epoch over ``data_loader``.

    Plots the moving-average loss (and, at epoch end, parameter histograms)
    at ``ckpt_step`` fractions of the epoch, optionally saves intermediate
    checkpoints, and bumps ``self.epoch`` at the end.
    """
    self.model.train()
    # moving average over roughly 1% of the epoch
    meter_loss = tnt.meter.MovingAverageValueMeter(
        len(data_loader) // 100 + 1)
    #meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
    #meter_confusion = tnt.meter.ConfusionMeter(p.NUM_CTC_LABELS, normalized=True)
    if self.lr_scheduler is not None:
        # epoch-level LR schedule step
        self.lr_scheduler.step()
        logger.debug(
            f"current lr = {self.optimizer.param_groups[0]['lr']:.3e}")
    if is_distributed() and data_loader.sampler is not None:
        # reshuffle the distributed sampler deterministically per epoch
        data_loader.sampler.set_epoch(self.epoch)
    # checkpoint thresholds at every ckpt_step fraction of the epoch
    ckpt_step = 0.1
    ckpts = iter(
        len(data_loader) * np.arange(ckpt_step, 1 + ckpt_step, ckpt_step))

    def plot_graphs(loss, data_iter=0, title="train", stats=False):
        # log the loss (and optionally parameter histograms) at fractional-epoch x
        #if self.lr_scheduler is not None:
        #    self.lr_scheduler.step()
        x = self.epoch + data_iter / len(data_loader)
        self.global_step = int(x / ckpt_step)
        if logger.visdom is not None:
            opts = {
                'xlabel': 'epoch',
                'ylabel': 'loss',
            }
            logger.visdom.add_point(title=title, x=x, y=loss, **opts)
        if logger.tensorboard is not None:
            #logger.tensorboard.add_graph(self.model, xs)
            #xs_img = tvu.make_grid(xs[0, 0], normalize=True, scale_each=True)
            #logger.tensorboard.add_image('xs', self.global_step, xs_img)
            #ys_hat_img = tvu.make_grid(ys_hat[0].transpose(0, 1), normalize=True, scale_each=True)
            #logger.tensorboard.add_image('ys_hat', self.global_step, ys_hat_img)
            logger.tensorboard.add_scalars(title, self.global_step, {
                'loss': loss,
            })
            if stats:
                # histogram of every parameter — only requested at epoch end
                for name, param in self.model.named_parameters():
                    logger.tensorboard.add_histogram(
                        name, self.global_step,
                        param.clone().cpu().data.numpy())

    self.train_loop_before_hook()
    ckpt = next(ckpts)
    t = tqdm(enumerate(data_loader), total=len(data_loader),
             desc="training", ncols=p.NCOLS)
    for i, (data) in t:
        loss_value = self.unit_train(data)
        # unit_train returns None for skipped batches
        if loss_value is not None:
            meter_loss.add(loss_value)
        t.set_description(f"training (loss: {meter_loss.value()[0]:.3f})")
        t.refresh()
        #self.meter_accuracy.add(ys_int, ys)
        #self.meter_confusion.add(ys_int, ys)
        if i > ckpt:
            plot_graphs(meter_loss.value()[0], i)
            if self.checkpoint:
                logger.info(
                    f"training loss at epoch_{self.epoch:03d}_ckpt_{i:07d}: "
                    f"{meter_loss.value()[0]:5.3f}")
                # only rank 0 writes checkpoints in distributed mode
                if not is_distributed() or (is_distributed()
                                            and dist.get_rank() == 0):
                    self.save(
                        self.__get_model_name(
                            f"epoch_{self.epoch:03d}_ckpt_{i:07d}"))
                self.train_loop_checkpoint_hook()
            ckpt = next(ckpts)
    self.epoch += 1
    logger.info(f"epoch {self.epoch:03d}: "
                f"training loss {meter_loss.value()[0]:5.3f} ")
    #f"training accuracy {meter_accuracy.value()[0]:6.3f}")
    if not is_distributed() or (is_distributed() and dist.get_rank() == 0):
        self.save(self.__get_model_name(f"epoch_{self.epoch:03d}"))
        # drop the intermediate per-ckpt files of the finished epoch
        self.__remove_ckpt_files(self.epoch - 1)
    # final plot including parameter statistics
    plot_graphs(meter_loss.value()[0], stats=True)
    self.train_loop_after_hook()
def __init__(self, model, init_lr=1e-4, max_norm=400, use_cuda=False,
             fp16=False, log_dir='logs', model_prefix='model',
             checkpoint=False, continue_from=None, opt_type="sgdr",
             *args, **kwargs):
    """Set up a CTC trainer: hyperparameters, visdom plots, model placement,
    loss, optimizer/scheduler selection, decoder, optional checkpoint restore,
    and fp16/distributed wrapping (which must happen after the restore).

    Fixes relative to the previous revision (interface unchanged):
      * ``torch.optim.Adam`` has no ``l2_reg`` keyword, so passing
        ``l2_reg=False`` raised ``TypeError`` whenever ``opt_type == "adam"``;
        it was removed (``weight_decay`` already provides L2 regularization).
      * the bare ``raise RuntimeError`` now carries an explanatory message
        (same exception type, so existing callers are unaffected).

    Args:
        model: the network to train.
        init_lr: initial learning rate.
        max_norm: gradient clipping norm.
        use_cuda: move the model to GPU (required when fp16).
        fp16: enable half-precision training via network_to_half/FP16_Optimizer.
        log_dir, model_prefix: checkpoint naming.
        checkpoint: save intermediate per-epoch checkpoints.
        continue_from: checkpoint path to resume from, or None.
        opt_type: one of OPTIMIZER_TYPES ("sgd", "sgdr", "adam").
    """
    if fp16:
        if not use_cuda:
            # half-precision training is GPU-only
            raise RuntimeError("fp16 training requires use_cuda=True")
    # training parameters
    self.init_lr = init_lr
    self.max_norm = max_norm
    self.use_cuda = use_cuda
    self.fp16 = fp16
    self.log_dir = log_dir
    self.model_prefix = model_prefix
    self.checkpoint = checkpoint
    self.epoch = 0
    # prepare visdom plots for training loss and validation LER
    if logger.visdom is not None:
        logger.visdom.add_plot(title=f'train', xlabel='epoch', ylabel='loss')
        logger.visdom.add_plot(title=f'validate', xlabel='epoch', ylabel='LER')
    # setup model
    self.model = model
    if self.use_cuda:
        logger.debug("using cuda")
        self.model.cuda()
    # setup loss (CTC with blank label index 0)
    self.loss = CTCLoss(blank=0, size_average=True, length_average=True)
    # setup optimizer
    assert opt_type in OPTIMIZER_TYPES
    parameters = self.model.parameters()
    if opt_type == "sgd":
        logger.debug("using SGD")
        self.optimizer = torch.optim.SGD(parameters, lr=self.init_lr,
                                         momentum=0.9)
        self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=5)
    elif opt_type == "sgdr":
        logger.debug("using SGDR")
        self.optimizer = torch.optim.SGD(parameters, lr=self.init_lr,
                                         momentum=0.9)
        #self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=0.5)
        self.lr_scheduler = CosineAnnealingWithRestartsLR(self.optimizer,
                                                          T_max=5, T_mult=2)
    elif opt_type == "adam":
        logger.debug("using AdamW")
        # NOTE: the former ``l2_reg=False`` argument was dropped here —
        # torch.optim.Adam does not accept it and raised TypeError;
        # weight_decay supplies the L2 term instead.
        self.optimizer = torch.optim.Adam(parameters, lr=self.init_lr,
                                          betas=(0.9, 0.999), eps=1e-8,
                                          weight_decay=0.0005)
        self.lr_scheduler = None
    # setup decoder for test
    self.decoder = LatGenCTCDecoder()
    # load from pre-trained model if needed
    if continue_from is not None:
        self.load(continue_from)
    # FP16 and distributed wrapping must happen after the checkpoint load
    if self.fp16:
        self.model = network_to_half(self.model)
        self.optimizer = FP16_Optimizer(self.optimizer,
                                        static_loss_scale=128.)
    if is_distributed():
        if self.use_cuda:
            local_rank = torch.cuda.current_device()
            if fp16:
                # apex's DDP handles fp16 parameter flattening itself
                self.model = apex.parallel.DistributedDataParallel(
                    self.model)
            else:
                self.model = nn.parallel.DistributedDataParallel(
                    self.model, device_ids=[local_rank],
                    output_device=local_rank)
        else:
            self.model = nn.parallel.DistributedDataParallel(self.model)
def train_epoch(self, data_loader):
    """Train the model for one epoch over ``data_loader``.

    Plots the moving-average loss every tenth of the epoch, optionally saves
    intermediate checkpoints, and bumps ``self.epoch`` at the end.

    Fix (interface unchanged): ``unit_train`` returns None for skipped
    batches, but this variant fed that None straight into ``meter_loss.add``
    and crashed; it is now guarded exactly like the sibling ``train_epoch``
    implementations in this file.
    """
    self.model.train()
    # checkpoint roughly every tenth of the epoch
    num_ckpt = int(np.ceil(len(data_loader) / 10))
    # moving average over roughly 1% of the epoch
    meter_loss = tnt.meter.MovingAverageValueMeter(
        len(data_loader) // 100 + 1)
    #meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
    #meter_confusion = tnt.meter.ConfusionMeter(p.NUM_CTC_LABELS, normalized=True)
    if self.lr_scheduler is not None:
        # epoch-level LR schedule step
        self.lr_scheduler.step()
        logger.debug(f"current lr = {self.lr_scheduler.get_lr()}")
    if is_distributed() and data_loader.sampler is not None:
        # reshuffle the distributed sampler deterministically per epoch
        data_loader.sampler.set_epoch(self.epoch)
    # count the number of supervised batches seen in this epoch
    t = tqdm(enumerate(data_loader), total=len(data_loader),
             desc="training")
    for i, (data) in t:
        loss_value = self.unit_train(data)
        # unit_train returns None for skipped batches; the meter cannot
        # accept None (guard consistent with the other trainers here)
        if loss_value is not None:
            meter_loss.add(loss_value)
        t.set_description(f"training (loss: {meter_loss.value()[0]:.3f})")
        t.refresh()
        #self.meter_accuracy.add(ys_int, ys)
        #self.meter_confusion.add(ys_int, ys)
        if 0 < i < len(data_loader) and i % num_ckpt == 0:
            if not is_distributed() or (is_distributed()
                                        and dist.get_rank() == 0):
                title = "train"
                x = self.epoch + i / len(data_loader)
                if logger.visdom is not None:
                    logger.visdom.add_point(title=title, x=x,
                                            y=meter_loss.value()[0])
                if logger.tensorboard is not None:
                    # NOTE(review): xs and ys_hat are not defined in this
                    # scope — this branch raises NameError if tensorboard
                    # logging is enabled; confirm against the sibling
                    # train_epoch that comments these lines out.
                    logger.tensorboard.add_graph(self.model, xs)
                    xs_img = tvu.make_grid(xs[0, 0], normalize=True,
                                           scale_each=True)
                    logger.tensorboard.add_image('xs', x, xs_img)
                    ys_hat_img = tvu.make_grid(ys_hat[0].transpose(0, 1),
                                               normalize=True,
                                               scale_each=True)
                    logger.tensorboard.add_image('ys_hat', x, ys_hat_img)
                    logger.tensorboard.add_scalars(
                        title, x, {
                            'loss': meter_loss.value()[0],
                        })
            if self.checkpoint:
                logger.info(
                    f"training loss at epoch_{self.epoch:03d}_ckpt_{i:07d}: "
                    f"{meter_loss.value()[0]:5.3f}")
                # only rank 0 writes checkpoints in distributed mode
                if not is_distributed() or (is_distributed()
                                            and dist.get_rank() == 0):
                    self.save(
                        self.__get_model_name(
                            f"epoch_{self.epoch:03d}_ckpt_{i:07d}"))
            #input("press key to continue")
    self.epoch += 1
    logger.info(f"epoch {self.epoch:03d}: "
                f"training loss {meter_loss.value()[0]:5.3f} ")
    #f"training accuracy {meter_accuracy.value()[0]:6.3f}")
    if not is_distributed() or (is_distributed() and dist.get_rank() == 0):
        self.save(self.__get_model_name(f"epoch_{self.epoch:03d}"))
        # drop the intermediate per-ckpt files of the finished epoch
        self.__remove_ckpt_files(self.epoch - 1)