Example #1
 def unit_train(self, data):
     xs, ys, frame_lens, label_lens, filenames, _ = data
     try:
         if self.use_cuda:
             xs, ys = xs.cuda(non_blocking=True), ys.cuda(non_blocking=True)
         ys_hat, ys_hat_lens, ys = self.model(xs, frame_lens, ys, label_lens)
         if ys_hat is None:
             logger.debug("the batch includes a data with label_lens > max_seq_lens, so skipped")
             return None
         if self.fp16:
             ys_hat = ys_hat.float()
         loss = self.loss(ys_hat.transpose(1, 2), ys.long())
         loss_value = loss.item()
         self.optimizer.zero_grad()
         if self.fp16:
             #self.optimizer.backward(loss)
             #self.optimizer.clip_master_grads(self.max_norm)
             with self.optimizer.scale_loss(loss) as scaled_loss:
                 scaled_loss.backward()
         else:
             loss.backward()
             nn.utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
         if is_distributed():
             self.average_gradients()
         self.optimizer.step()
         if self.use_cuda:
             torch.cuda.synchronize()
         del loss
         return loss_value
     except Exception as e:
         print(e)
         print(filenames, frame_lens, label_lens)
         raise
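
The `ys_hat.transpose(1, 2)` before the loss matches the layout PyTorch's frame-wise classification losses expect: the class dimension must be dim 1. A minimal, self-contained sketch of that shape convention, assuming a `nn.CrossEntropyLoss`-style criterion (the actual `self.loss` is configured elsewhere in the trainer):

import torch
import torch.nn as nn

# hypothetical shapes: batch N=2, frames T=5, classes C=10
ys_hat = torch.randn(2, 5, 10)       # model output as N x T x C
ys = torch.randint(0, 10, (2, 5))    # integer frame targets as N x T

criterion = nn.CrossEntropyLoss()    # assumption; self.loss may differ
# CrossEntropyLoss wants N x C x T for sequence targets,
# hence the transpose(1, 2) in unit_train above
loss = criterion(ys_hat.transpose(1, 2), ys.long())
print(loss.item())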
Example #2
    def save(self, file_path, **kwargs):
        Path(file_path).parent.mkdir(mode=0o755, parents=True, exist_ok=True)
        logger.debug(f"saving the model to {file_path}")

        if self.states is None:
            self.states = dict()
        self.states.update(kwargs)
        self.states["epoch"] = self.epoch
        self.states["opt_type"] = self.opt_type
        if is_distributed():
            model_state_dict = self.model.state_dict()
            strip_prefix = 9 if self.fp16 else 7
            # strip the data-parallel wrapper prefix from keys:
            # "module.1." (9 chars) under fp16, "module." (7 chars) otherwise
            self.states["model"] = {
                k[strip_prefix:]: v
                for k, v in model_state_dict.items()
            }
        else:
            self.states["model"] = self.model.state_dict()
        self.states["optimizer"] = self.optimizer.state_dict()
        if self.lr_scheduler is not None:
            self.states["lr_scheduler"] = self.lr_scheduler.state_dict()

        self.save_hook()
        torch.save(self.states, file_path)
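
The prefix stripping in `save` exists because `DistributedDataParallel` registers all parameters under a `module.` prefix (the fp16 path apparently nests one level deeper, hence `module.1.`). A minimal stand-alone illustration of what the dictionary comprehension achieves:

import torch.nn as nn

model = nn.Linear(4, 2)
# simulate the "module." prefix that nn.parallel.DistributedDataParallel adds
wrapped_state = {"module." + k: v for k, v in model.state_dict().items()}

strip_prefix = 7    # len("module."); 9 would strip "module.1." in the fp16 case
stripped = {k[strip_prefix:]: v for k, v in wrapped_state.items()}
model.load_state_dict(stripped)    # loads cleanly into the unwrapped model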
Example #3
 def load(self, file_path):
     if isinstance(file_path, str):
         file_path = Path(file_path)
     if not file_path.exists():
         logger.error(f"no such file {file_path} exists")
         sys.exit(1)
     logger.debug(f"loading the model from {file_path}")
     to_device = f"cuda:{torch.cuda.current_device()}" if self.use_cuda else "cpu"
     self.states = torch.load(file_path, map_location=to_device)
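
`map_location` is what lets a checkpoint written on a GPU machine be restored onto whatever device the current process uses. A minimal CPU-only illustration with a hypothetical file path:

import torch

states = {"epoch": 3, "model": {"weight": torch.randn(2, 2)}}
torch.save(states, "/tmp/example_ckpt.pth")    # hypothetical path

# remap all tensors onto the CPU; on GPU the method above builds
# the string f"cuda:{torch.cuda.current_device()}" instead
restored = torch.load("/tmp/example_ckpt.pth", map_location="cpu")
print(restored["epoch"])    # -> 3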
Example #4
 def unit_train(self, data):
     xs, ys, frame_lens, label_lens, filenames, _ = data
     try:
         batch_size = xs.size(0)
         if self.use_cuda:
             xs = xs.cuda(non_blocking=True)
         ys_hat, frame_lens = self.model(xs, frame_lens)
         if frame_lens.lt(2 * label_lens).nonzero().numel():
             logger.debug(
                 "the batch includes a data with frame_lens < 2*label_lens, so skipped"
             )
             return None
         if self.fp16:
             ys_hat = ys_hat.float()
         ys_hat = ys_hat.transpose(0, 1).contiguous()  # TxNxH
         #torch.set_printoptions(threshold=5000000)
         #print(ys_hat.shape, frame_lens, ys.shape, label_lens)
         #print(onehot2int(ys_hat).squeeze(), ys)
         d = frame_lens.float()
         #d = frame_lens.sum().float()
         if self.use_cuda:
             d = d.cuda()
         loss = (self.loss(ys_hat, ys, frame_lens, label_lens) / d).mean()
         #loss = self.loss(ys_hat, ys, frame_lens, label_lens).div_(d)
         #loss = self.loss(ys_hat, ys, frame_lens, label_lens)
         if torch.isnan(loss) or loss.item() == float(
                 "inf") or loss.item() == -float("inf"):
             logger.warning(
                 "received an nan/inf loss: probably frame_lens < label_lens or the learning rate is too high"
             )
             #loss.mul_(0.)
             return None
         loss_value = loss.item()
         self.optimizer.zero_grad()
         if self.fp16:
             #self.optimizer.backward(loss)
             #self.optimizer.clip_master_grads(self.max_norm)
             with self.optimizer.scale_loss(loss) as scaled_loss:
                 scaled_loss.backward()
         else:
             loss.backward()
             nn.utils.clip_grad_norm_(self.model.parameters(),
                                      self.max_norm)
         #if is_distributed():
         #    self.average_gradients()
         self.optimizer.step()
         if self.use_cuda:
             torch.cuda.synchronize()
         del loss
         return loss_value
     except Exception as e:
         print(e)
         print(filenames, frame_lens, label_lens)
         raise
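
This variant is a CTC trainer: the network output is transposed to T x N x H and the per-utterance losses are normalized by `frame_lens` before averaging. The constructors below configure `self.loss` as a warp-ctc style CTC loss (with `nn.CTCLoss` left commented out), so the following is only an illustration of the same shape convention using the built-in `nn.CTCLoss`:

import torch
import torch.nn as nn

T, N, C = 50, 4, 20                                   # frames, batch, labels (0 = blank)
ys_hat = torch.randn(T, N, C).log_softmax(2)          # T x N x C log-probabilities
frame_lens = torch.full((N,), T, dtype=torch.long)
label_lens = torch.randint(5, 15, (N,))
ys = torch.randint(1, C, (int(label_lens.sum()),))    # flattened label sequences

ctc = nn.CTCLoss(blank=0, reduction='none')           # per-utterance losses
loss = (ctc(ys_hat, ys, frame_lens, label_lens) / frame_lens.float()).mean()
print(loss.item())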
Example #5
    def __init__(self,
                 model,
                 use_cuda=False,
                 continue_from=None,
                 verbose=False,
                 *args,
                 **kwargs):
        assert continue_from is not None
        self.use_cuda = use_cuda
        self.verbose = verbose

        # load from args
        self.model = model
        if self.use_cuda:
            logger.debug("using cuda")
            self.model.cuda()

        self.load(continue_from)

        # prepare kaldi latgen decoder
        self.decoder = LatGenCTCDecoder()
Example #6
    def train_epoch(self, data_loader):
        self.model.train()
        meter_loss = tnt.meter.MovingAverageValueMeter(
            len(data_loader) // 100 + 1)

        #meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
        #meter_confusion = tnt.meter.ConfusionMeter(p.NUM_CTC_LABELS, normalized=True)

        def plot_scalar(i, loss, title="train"):
            #if self.lr_scheduler is not None:
            #    self.lr_scheduler.step()
            x = self.epoch + i / len(data_loader)
            if logger.visdom is not None:
                opts = {
                    'xlabel': 'epoch',
                    'ylabel': 'loss',
                }
                logger.visdom.add_point(title=title, x=x, y=loss, **opts)
            if logger.tensorboard is not None:
                # NOTE: xs and ys_hat are not defined in this scope, so this
                # tensorboard branch would raise a NameError if enabled
                # (the equivalent lines are commented out in the later train_epoch)
                logger.tensorboard.add_graph(self.model, xs)
                xs_img = tvu.make_grid(xs[0, 0],
                                       normalize=True,
                                       scale_each=True)
                logger.tensorboard.add_image('xs', x, xs_img)
                ys_hat_img = tvu.make_grid(ys_hat[0].transpose(0, 1),
                                           normalize=True,
                                           scale_each=True)
                logger.tensorboard.add_image('ys_hat', x, ys_hat_img)
                logger.tensorboard.add_scalars(title, x, {
                    'loss': loss,
                })

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        logger.debug(
            f"current lr = {self.optimizer.param_groups[0]['lr']:.3e}")
        if is_distributed() and data_loader.sampler is not None:
            data_loader.sampler.set_epoch(self.epoch)
        ckpts = iter(len(data_loader) * np.arange(0.1, 1.1, 0.1))
        ckpt = next(ckpts)
        self.train_loop_before_hook()
        # count the number of supervised batches seen in this epoch
        t = tqdm(enumerate(data_loader),
                 total=len(data_loader),
                 desc="training",
                 ncols=p.NCOLS)
        for i, (data) in t:
            loss_value = self.unit_train(data)
            if loss_value is not None:
                meter_loss.add(loss_value)
            t.set_description(f"training (loss: {meter_loss.value()[0]:.3f})")
            t.refresh()
            #self.meter_accuracy.add(ys_int, ys)
            #self.meter_confusion.add(ys_int, ys)
            if i > ckpt:
                plot_scalar(i, meter_loss.value()[0])
                if self.checkpoint:
                    logger.info(
                        f"training loss at epoch_{self.epoch:03d}_ckpt_{i:07d}: "
                        f"{meter_loss.value()[0]:5.3f}")
                    if not is_distributed() or (is_distributed()
                                                and dist.get_rank() == 0):
                        self.save(
                            self.__get_model_name(
                                f"epoch_{self.epoch:03d}_ckpt_{i:07d}"))
                ckpt = next(ckpts)
            #input("press key to continue")

        plot_scalar(i, meter_loss.value()[0])
        self.epoch += 1
        logger.info(f"epoch {self.epoch:03d}: "
                    f"training loss {meter_loss.value()[0]:5.3f} ")
        #f"training accuracy {meter_accuracy.value()[0]:6.3f}")
        if not is_distributed() or (is_distributed() and dist.get_rank() == 0):
            self.save(self.__get_model_name(f"epoch_{self.epoch:03d}"))
            self.__remove_ckpt_files(self.epoch - 1)
        self.train_loop_after_hook()
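
The `ckpts` iterator above schedules plotting (and optional checkpoint saving) at roughly every 10% of an epoch. A small stand-alone demonstration of the mechanism, assuming a loader of 1000 batches:

import numpy as np

num_batches = 1000    # assumption: stands in for len(data_loader)
ckpts = iter(num_batches * np.arange(0.1, 1.1, 0.1))
ckpt = next(ckpts)

fired = []
for i in range(num_batches):
    if i > ckpt:
        fired.append(i)    # plot_scalar/save would run here
        ckpt = next(ckpts)
print(fired)    # roughly every 100 batches; the last threshold equals num_batches and never fires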
Example #7
    def __init__(self,
                 model,
                 amp_handle=None,
                 init_lr=1e-2,
                 max_norm=100,
                 use_cuda=False,
                 fp16=False,
                 log_dir='logs',
                 model_prefix='model',
                 checkpoint=False,
                 continue_from=None,
                 opt_type=None,
                 *args,
                 **kwargs):
        if fp16:
            import apex.parallel
            from apex import amp
            if not use_cuda:
                raise RuntimeError("fp16 training requires use_cuda=True")
        self.amp_handle = amp_handle

        # training parameters
        self.init_lr = init_lr
        self.max_norm = max_norm
        self.use_cuda = use_cuda
        self.fp16 = fp16
        self.log_dir = log_dir
        self.model_prefix = model_prefix
        self.checkpoint = checkpoint
        self.opt_type = opt_type
        self.epoch = 0
        self.states = None

        # load from pre-trained model if needed
        if continue_from is not None:
            self.load(continue_from)

        # setup model
        self.model = model
        if self.use_cuda:
            logger.debug("using cuda")
            self.model.cuda()

        # setup loss
        #self.loss = nn.CTCLoss(blank=0, reduction='none')
        self.loss = wp.CTCLoss(blank=0, length_average=True)

        # setup optimizer
        if opt_type is None:
            # for test only
            self.optimizer = None
            self.lr_scheduler = None
        else:
            assert opt_type in OPTIMIZER_TYPES
            parameters = self.model.parameters()
            if opt_type == "sgdr":
                logger.debug("using SGDR")
                self.optimizer = torch.optim.SGD(parameters,
                                                 lr=self.init_lr,
                                                 momentum=0.9,
                                                 weight_decay=5e-4)
                #self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=0.5)
                self.lr_scheduler = CosineAnnealingWithRestartsLR(
                    self.optimizer, T_max=2, T_mult=2)
            elif opt_type == "adamwr":
                logger.debug("using AdamWR")
                self.optimizer = torch.optim.Adam(parameters,
                                                  lr=self.init_lr,
                                                  betas=(0.9, 0.999),
                                                  eps=1e-8,
                                                  weight_decay=5e-4)
                self.lr_scheduler = CosineAnnealingWithRestartsLR(
                    self.optimizer, T_max=2, T_mult=2)
            elif opt_type == "adam":
                logger.debug("using Adam")
                self.optimizer = torch.optim.Adam(parameters,
                                                  lr=self.init_lr,
                                                  betas=(0.9, 0.999),
                                                  eps=1e-8,
                                                  weight_decay=5e-4)
                self.lr_scheduler = None
            elif opt_type == "rmsprop":
                logger.debug("using RMSprop")
                self.optimizer = torch.optim.RMSprop(parameters,
                                                     lr=self.init_lr,
                                                     alpha=0.95,
                                                     eps=1e-8,
                                                     weight_decay=5e-4,
                                                     centered=True)
                self.lr_scheduler = None

        # setup decoder for test
        self.decoder = LatGenCTCDecoder()
        self.labeler = self.decoder.labeler

        # FP16 and distributed after load
        if self.fp16:
            #self.model = network_to_half(self.model)
            #self.optimizer = FP16_Optimizer(self.optimizer, static_loss_scale=128.)
            self.optimizer = self.amp_handle.wrap_optimizer(self.optimizer)

        if is_distributed():
            if self.use_cuda:
                local_rank = torch.cuda.current_device()
                if fp16:
                    self.model = apex.parallel.DistributedDataParallel(
                        self.model)
                else:
                    self.model = nn.parallel.DistributedDataParallel(
                        self.model,
                        device_ids=[local_rank],
                        output_device=local_rank)
            else:
                self.model = nn.parallel.DistributedDataParallel(self.model)

        if self.states is not None:
            self.restore_state()
Example #8
 def train_loop_before_hook(self):
     self.tfr_scheduler.step()
     logger.debug(f"current tfr = {self.model.tfr:.3e}")
Example #9
    def train_epoch(self, data_loader):
        self.model.train()
        meter_loss = tnt.meter.MovingAverageValueMeter(
            len(data_loader) // 100 + 1)
        #meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
        #meter_confusion = tnt.meter.ConfusionMeter(p.NUM_CTC_LABELS, normalized=True)

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
            logger.debug(
                f"current lr = {self.optimizer.param_groups[0]['lr']:.3e}")
        if is_distributed() and data_loader.sampler is not None:
            data_loader.sampler.set_epoch(self.epoch)

        ckpt_step = 0.1
        ckpts = iter(
            len(data_loader) * np.arange(ckpt_step, 1 + ckpt_step, ckpt_step))

        def plot_graphs(loss, data_iter=0, title="train", stats=False):
            #if self.lr_scheduler is not None:
            #    self.lr_scheduler.step()
            x = self.epoch + data_iter / len(data_loader)
            self.global_step = int(x / ckpt_step)
            if logger.visdom is not None:
                opts = {
                    'xlabel': 'epoch',
                    'ylabel': 'loss',
                }
                logger.visdom.add_point(title=title, x=x, y=loss, **opts)
            if logger.tensorboard is not None:
                #logger.tensorboard.add_graph(self.model, xs)
                #xs_img = tvu.make_grid(xs[0, 0], normalize=True, scale_each=True)
                #logger.tensorboard.add_image('xs', self.global_step, xs_img)
                #ys_hat_img = tvu.make_grid(ys_hat[0].transpose(0, 1), normalize=True, scale_each=True)
                #logger.tensorboard.add_image('ys_hat', self.global_step, ys_hat_img)
                logger.tensorboard.add_scalars(title, self.global_step, {
                    'loss': loss,
                })
                if stats:
                    for name, param in self.model.named_parameters():
                        logger.tensorboard.add_histogram(
                            name, self.global_step,
                            param.clone().cpu().data.numpy())

        self.train_loop_before_hook()
        ckpt = next(ckpts)
        t = tqdm(enumerate(data_loader),
                 total=len(data_loader),
                 desc="training",
                 ncols=p.NCOLS)
        for i, (data) in t:
            loss_value = self.unit_train(data)
            if loss_value is not None:
                meter_loss.add(loss_value)
            t.set_description(f"training (loss: {meter_loss.value()[0]:.3f})")
            t.refresh()
            #self.meter_accuracy.add(ys_int, ys)
            #self.meter_confusion.add(ys_int, ys)
            if i > ckpt:
                plot_graphs(meter_loss.value()[0], i)
                if self.checkpoint:
                    logger.info(
                        f"training loss at epoch_{self.epoch:03d}_ckpt_{i:07d}: "
                        f"{meter_loss.value()[0]:5.3f}")
                    if not is_distributed() or (is_distributed()
                                                and dist.get_rank() == 0):
                        self.save(
                            self.__get_model_name(
                                f"epoch_{self.epoch:03d}_ckpt_{i:07d}"))
                    self.train_loop_checkpoint_hook()
                ckpt = next(ckpts)

        self.epoch += 1
        logger.info(f"epoch {self.epoch:03d}: "
                    f"training loss {meter_loss.value()[0]:5.3f} ")
        #f"training accuracy {meter_accuracy.value()[0]:6.3f}")
        if not is_distributed() or (is_distributed() and dist.get_rank() == 0):
            self.save(self.__get_model_name(f"epoch_{self.epoch:03d}"))
            self.__remove_ckpt_files(self.epoch - 1)
        plot_graphs(meter_loss.value()[0], stats=True)
        self.train_loop_after_hook()
Example #10
    def __init__(self,
                 model,
                 init_lr=1e-4,
                 max_norm=400,
                 use_cuda=False,
                 fp16=False,
                 log_dir='logs',
                 model_prefix='model',
                 checkpoint=False,
                 continue_from=None,
                 opt_type="sgdr",
                 *args,
                 **kwargs):
        if fp16:
            if not use_cuda:
                raise RuntimeError("fp16 training requires use_cuda=True")

        # training parameters
        self.init_lr = init_lr
        self.max_norm = max_norm
        self.use_cuda = use_cuda
        self.fp16 = fp16
        self.log_dir = log_dir
        self.model_prefix = model_prefix
        self.checkpoint = checkpoint
        self.epoch = 0

        # prepare visdom
        if logger.visdom is not None:
            logger.visdom.add_plot(title=f'train',
                                   xlabel='epoch',
                                   ylabel='loss')
            logger.visdom.add_plot(title=f'validate',
                                   xlabel='epoch',
                                   ylabel='LER')

        # setup model
        self.model = model
        if self.use_cuda:
            logger.debug("using cuda")
            self.model.cuda()

        # setup loss
        self.loss = CTCLoss(blank=0, size_average=True, length_average=True)

        # setup optimizer
        assert opt_type in OPTIMIZER_TYPES
        parameters = self.model.parameters()
        if opt_type == "sgd":
            logger.debug("using SGD")
            self.optimizer = torch.optim.SGD(parameters,
                                             lr=self.init_lr,
                                             momentum=0.9)
            self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=5)
        elif opt_type == "sgdr":
            logger.debug("using SGDR")
            self.optimizer = torch.optim.SGD(parameters,
                                             lr=self.init_lr,
                                             momentum=0.9)
            #self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=0.5)
            self.lr_scheduler = CosineAnnealingWithRestartsLR(self.optimizer,
                                                              T_max=5,
                                                              T_mult=2)
        elif opt_type == "adam":
            logger.debug("using AdamW")
            self.optimizer = torch.optim.Adam(parameters,
                                              lr=self.init_lr,
                                              betas=(0.9, 0.999),
                                              eps=1e-8,
                                              weight_decay=0.0005,
                                              l2_reg=False)
            self.lr_scheduler = None

        # setup decoder for test
        self.decoder = LatGenCTCDecoder()

        # load from pre-trained model if needed
        if continue_from is not None:
            self.load(continue_from)

        # FP16 and distributed after load
        if self.fp16:
            self.model = network_to_half(self.model)
            self.optimizer = FP16_Optimizer(self.optimizer,
                                            static_loss_scale=128.)

        if is_distributed():
            if self.use_cuda:
                local_rank = torch.cuda.current_device()
                if fp16:
                    self.model = apex.parallel.DistributedDataParallel(
                        self.model)
                else:
                    self.model = nn.parallel.DistributedDataParallel(
                        self.model,
                        device_ids=[local_rank],
                        output_device=local_rank)
            else:
                self.model = nn.parallel.DistributedDataParallel(self.model)
Example #11
    def train_epoch(self, data_loader):
        self.model.train()
        num_ckpt = int(np.ceil(len(data_loader) / 10))
        meter_loss = tnt.meter.MovingAverageValueMeter(
            len(data_loader) // 100 + 1)
        #meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
        #meter_confusion = tnt.meter.ConfusionMeter(p.NUM_CTC_LABELS, normalized=True)
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
            logger.debug(f"current lr = {self.lr_scheduler.get_lr()}")
        if is_distributed() and data_loader.sampler is not None:
            data_loader.sampler.set_epoch(self.epoch)

        # count the number of supervised batches seen in this epoch
        t = tqdm(enumerate(data_loader),
                 total=len(data_loader),
                 desc="training")
        for i, (data) in t:
            loss_value = self.unit_train(data)
            meter_loss.add(loss_value)
            t.set_description(f"training (loss: {meter_loss.value()[0]:.3f})")
            t.refresh()
            #self.meter_accuracy.add(ys_int, ys)
            #self.meter_confusion.add(ys_int, ys)

            if 0 < i < len(data_loader) and i % num_ckpt == 0:
                if not is_distributed() or (is_distributed()
                                            and dist.get_rank() == 0):
                    title = "train"
                    x = self.epoch + i / len(data_loader)
                    if logger.visdom is not None:
                        logger.visdom.add_point(title=title,
                                                x=x,
                                                y=meter_loss.value()[0])
                    if logger.tensorboard is not None:
                        # NOTE: xs and ys_hat are not defined in this scope, so
                        # this tensorboard branch would raise a NameError if enabled
                        logger.tensorboard.add_graph(self.model, xs)
                        xs_img = tvu.make_grid(xs[0, 0],
                                               normalize=True,
                                               scale_each=True)
                        logger.tensorboard.add_image('xs', x, xs_img)
                        ys_hat_img = tvu.make_grid(ys_hat[0].transpose(0, 1),
                                                   normalize=True,
                                                   scale_each=True)
                        logger.tensorboard.add_image('ys_hat', x, ys_hat_img)
                        logger.tensorboard.add_scalars(
                            title, x, {
                                'loss': meter_loss.value()[0],
                            })
                if self.checkpoint:
                    logger.info(
                        f"training loss at epoch_{self.epoch:03d}_ckpt_{i:07d}: "
                        f"{meter_loss.value()[0]:5.3f}")
                    if not is_distributed() or (is_distributed()
                                                and dist.get_rank() == 0):
                        self.save(
                            self.__get_model_name(
                                f"epoch_{self.epoch:03d}_ckpt_{i:07d}"))
            #input("press key to continue")

        self.epoch += 1
        logger.info(f"epoch {self.epoch:03d}: "
                    f"training loss {meter_loss.value()[0]:5.3f} ")
        #f"training accuracy {meter_accuracy.value()[0]:6.3f}")
        if not is_distributed() or (is_distributed() and dist.get_rank() == 0):
            self.save(self.__get_model_name(f"epoch_{self.epoch:03d}"))
            self.__remove_ckpt_files(self.epoch - 1)