Example #1
import sys
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torchnet as tnt
import torchvision.utils as tvu
import Levenshtein as Lev
from tqdm import tqdm

# project-local names assumed by this example: logger, p (params),
# OPTIMIZER_TYPES, CosineAnnealingWithRestartsLR, LatGenCTCDecoder,
# get_model_file_path, is_distributed, int2onehot, and wp (a warp-ctc binding)


class Trainer:
    def __init__(self,
                 model,
                 amp_handle=None,
                 init_lr=1e-2,
                 max_norm=100,
                 use_cuda=False,
                 fp16=False,
                 log_dir='logs',
                 model_prefix='model',
                 checkpoint=False,
                 continue_from=None,
                 opt_type=None,
                 *args,
                 **kwargs):
        if fp16:
            # apex is required only for mixed-precision training
            import apex.parallel
            from apex import amp
            if not use_cuda:
                raise RuntimeError("fp16 training requires use_cuda=True")
        self.amp_handle = amp_handle

        # training parameters
        self.init_lr = init_lr
        self.max_norm = max_norm
        self.use_cuda = use_cuda
        self.fp16 = fp16
        self.log_dir = log_dir
        self.model_prefix = model_prefix
        self.checkpoint = checkpoint
        self.opt_type = opt_type
        self.epoch = 0
        self.states = None

        # load from pre-trained model if needed
        if continue_from is not None:
            self.load(continue_from)

        # setup model
        self.model = model
        if self.use_cuda:
            logger.debug("using cuda")
            self.model.cuda()

        # setup loss
        #self.loss = nn.CTCLoss(blank=0, reduction='none')
        self.loss = wp.CTCLoss(blank=0, length_average=True)

        # setup optimizer
        if opt_type is None:
            # for test only
            self.optimizer = None
            self.lr_scheduler = None
        else:
            assert opt_type in OPTIMIZER_TYPES
            parameters = self.model.parameters()
            if opt_type == "sgdr":
                logger.debug("using SGDR")
                self.optimizer = torch.optim.SGD(parameters,
                                                 lr=self.init_lr,
                                                 momentum=0.9,
                                                 weight_decay=5e-4)
                #self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=0.5)
                self.lr_scheduler = CosineAnnealingWithRestartsLR(
                    self.optimizer, T_max=2, T_mult=2)
            elif opt_type == "adamwr":
                logger.debug("using AdamWR")
                self.optimizer = torch.optim.Adam(parameters,
                                                  lr=self.init_lr,
                                                  betas=(0.9, 0.999),
                                                  eps=1e-8,
                                                  weight_decay=5e-4)
                self.lr_scheduler = CosineAnnealingWithRestartsLR(
                    self.optimizer, T_max=2, T_mult=2)
            elif opt_type == "adam":
                logger.debug("using Adam")
                self.optimizer = torch.optim.Adam(parameters,
                                                  lr=self.init_lr,
                                                  betas=(0.9, 0.999),
                                                  eps=1e-8,
                                                  weight_decay=5e-4)
                self.lr_scheduler = None
            elif opt_type == "rmsprop":
                logger.debug("using RMSprop")
                self.optimizer = torch.optim.RMSprop(parameters,
                                                     lr=self.init_lr,
                                                     alpha=0.95,
                                                     eps=1e-8,
                                                     weight_decay=5e-4,
                                                     centered=True)
                self.lr_scheduler = None

        # setup decoder for test
        self.decoder = LatGenCTCDecoder()
        self.labeler = self.decoder.labeler

        # FP16 and distributed after load
        if self.fp16:
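            # assumes amp_handle was obtained from apex's old amp.init() API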
            #self.model = network_to_half(self.model)
            #self.optimizer = FP16_Optimizer(self.optimizer, static_loss_scale=128.)
            self.optimizer = self.amp_handle.wrap_optimizer(self.optimizer)

        if is_distributed():
            if self.use_cuda:
                local_rank = torch.cuda.current_device()
                if fp16:
                    self.model = apex.parallel.DistributedDataParallel(
                        self.model)
                else:
                    self.model = nn.parallel.DistributedDataParallel(
                        self.model,
                        device_ids=[local_rank],
                        output_device=local_rank)
            else:
                self.model = nn.parallel.DistributedDataParallel(self.model)

        if self.states is not None:
            self.restore_state()

    def __get_model_name(self, desc):
        return str(get_model_file_path(self.log_dir, self.model_prefix, desc))

    def __remove_ckpt_files(self, epoch):
        for ckpt in Path(self.log_dir).rglob(f"*_epoch_{epoch:03d}_ckpt_*"):
            ckpt.unlink()

    def train_loop_before_hook(self):
        pass

    def train_loop_after_hook(self):
        pass

    def unit_train(self, data):
        raise NotImplementedError

    #def average_gradients(self):
    #    if not is_distributed():
    #        return
    #    size = float(dist.get_world_size())
    #    for param in self.model.parameters():
    #        dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM, async_op=True)
    #        param.grad.data /= size

    def train_epoch(self, data_loader):
        self.model.train()
        meter_loss = tnt.meter.MovingAverageValueMeter(
            len(data_loader) // 100 + 1)

        #meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
        #meter_confusion = tnt.meter.ConfusionMeter(p.NUM_CTC_LABELS, normalized=True)

        def plot_scalar(i, loss, title="train"):
            #if self.lr_scheduler is not None:
            #    self.lr_scheduler.step()
            x = self.epoch + i / len(data_loader)
            if logger.visdom is not None:
                opts = {
                    'xlabel': 'epoch',
                    'ylabel': 'loss',
                }
                logger.visdom.add_point(title=title, x=x, y=loss, **opts)
            if logger.tensorboard is not None:
                # the graph/image logging below referenced xs and ys_hat,
                # which are not in scope inside this closure; it is kept
                # commented out to avoid a NameError
                #logger.tensorboard.add_graph(self.model, xs)
                #xs_img = tvu.make_grid(xs[0, 0],
                #                       normalize=True,
                #                       scale_each=True)
                #logger.tensorboard.add_image('xs', x, xs_img)
                #ys_hat_img = tvu.make_grid(ys_hat[0].transpose(0, 1),
                #                           normalize=True,
                #                           scale_each=True)
                #logger.tensorboard.add_image('ys_hat', x, ys_hat_img)
                logger.tensorboard.add_scalars(title, x, {
                    'loss': loss,
                })

        if self.lr_scheduler is not None:
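            # pre-PyTorch-1.1 convention: the scheduler is stepped once per
            # epoch, before the optimizer runs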
            self.lr_scheduler.step()
        logger.debug(
            f"current lr = {self.optimizer.param_groups[0]['lr']:.3e}")
        if is_distributed() and data_loader.sampler is not None:
            data_loader.sampler.set_epoch(self.epoch)
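        # plot (and optionally checkpoint) at every 10% of the epoch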
        ckpts = iter(len(data_loader) * np.arange(0.1, 1.1, 0.1))
        ckpt = next(ckpts)
        self.train_loop_before_hook()
        # count the number of supervised batches seen in this epoch
        t = tqdm(enumerate(data_loader),
                 total=len(data_loader),
                 desc="training",
                 ncols=p.NCOLS)
        for i, (data) in t:
            loss_value = self.unit_train(data)
            if loss_value is not None:
                meter_loss.add(loss_value)
            t.set_description(f"training (loss: {meter_loss.value()[0]:.3f})")
            t.refresh()
            #self.meter_accuracy.add(ys_int, ys)
            #self.meter_confusion.add(ys_int, ys)
            if i > ckpt:
                plot_scalar(i, meter_loss.value()[0])
                if self.checkpoint:
                    logger.info(
                        f"training loss at epoch_{self.epoch:03d}_ckpt_{i:07d}: "
                        f"{meter_loss.value()[0]:5.3f}")
                    if not is_distributed() or (is_distributed()
                                                and dist.get_rank() == 0):
                        self.save(
                            self.__get_model_name(
                                f"epoch_{self.epoch:03d}_ckpt_{i:07d}"))
                ckpt = next(ckpts)
            #input("press key to continue")

        plot_scalar(i, meter_loss.value()[0])
        self.epoch += 1
        logger.info(f"epoch {self.epoch:03d}: "
                    f"training loss {meter_loss.value()[0]:5.3f} ")
        #f"training accuracy {meter_accuracy.value()[0]:6.3f}")
        if not is_distributed() or (is_distributed() and dist.get_rank() == 0):
            self.save(self.__get_model_name(f"epoch_{self.epoch:03d}"))
            self.__remove_ckpt_files(self.epoch - 1)
        self.train_loop_after_hook()

    def unit_validate(self, data):
        raise NotImplementedError

    def validate(self, data_loader):
        "validate with label error rate by the edit distance between hyps and refs"
        self.model.eval()
        with torch.no_grad():
            N, D = 0, 0
            t = tqdm(enumerate(data_loader),
                     total=len(data_loader),
                     desc="validating",
                     ncols=p.NCOLS)
            for i, (data) in t:
                hyps, refs = self.unit_validate(data)
                # calculate ler
                N += self.edit_distance(refs, hyps)
                D += sum(len(r) for r in refs)
                ler = N * 100. / D
                t.set_description(f"validating (LER: {ler:.2f} %)")
                t.refresh()
            logger.info(
                f"validating at epoch {self.epoch:03d}: LER {ler:.2f} %")

            title = f"validate"
            x = self.epoch - 1 + i / len(data_loader)
            if logger.visdom is not None:
                opts = {
                    'xlabel': 'epoch',
                    'ylabel': 'LER',
                }
                logger.visdom.add_point(title=title, x=x, y=ler, **opts)
            if logger.tensorboard is not None:
                logger.tensorboard.add_scalars(title, x, {
                    'LER': ler,
                })

    def unit_test(self, data):
        raise NotImplementedError

    def test(self, data_loader):
        "test with word error rate by the edit distance between hyps and refs"
        self.model.eval()
        with torch.no_grad():
            N, D = 0, 0
            t = tqdm(enumerate(data_loader),
                     total=len(data_loader),
                     desc="testing",
                     ncols=p.NCOLS)
            for i, (data) in t:
                hyps, refs = self.unit_test(data)
                # calculate wer
                N += self.edit_distance(refs, hyps)
                D += sum(len(r) for r in refs)
                wer = N * 100. / D
                t.set_description(f"testing (WER: {wer:.2f} %)")
                t.refresh()
            logger.info(f"testing at epoch {self.epoch:03d}: WER {wer:.2f} %")

    def edit_distance(self, refs, hyps):
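        # map each integer label to a unicode code point so that
        # Lev.distance computes the edit distance between label sequences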
        assert len(refs) == len(hyps)
        n = 0
        for ref, hyp in zip(refs, hyps):
            r = [chr(c) for c in ref]
            h = [chr(c) for c in hyp]
            n += Lev.distance(''.join(r), ''.join(h))
        return n

    def target_to_loglikes(self, ys, label_lens):
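        # split the concatenated targets into per-utterance sequences, pad
        # them to a common length behind a leading blank, convert to a
        # smoothed one-hot encoding, and return its log as pseudo
        # log-likelihoods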
        max_len = max(label_lens.tolist())
        num_classes = self.labeler.get_num_labels()
        ys_hat = [
            torch.cat((torch.zeros(1).int(), ys[s:s + l],
                       torch.zeros(max_len - l).int()))
            for s, l in zip([0] + label_lens[:-1].cumsum(0).tolist(),
                            label_lens.tolist())
        ]
        ys_hat = [
            int2onehot(torch.IntTensor(z), num_classes, floor=1e-3)
            for z in ys_hat
        ]
        ys_hat = torch.stack(ys_hat)
        ys_hat = torch.log(ys_hat)
        return ys_hat

    def save_hook(self):
        pass

    def save(self, file_path, **kwargs):
        Path(file_path).parent.mkdir(mode=0o755, parents=True, exist_ok=True)
        logger.debug(f"saving the model to {file_path}")

        if self.states is None:
            self.states = dict()
        self.states.update(kwargs)
        self.states["epoch"] = self.epoch
        self.states["opt_type"] = self.opt_type
        if is_distributed():
            model_state_dict = self.model.state_dict()
            strip_prefix = 9 if self.fp16 else 7
            # strip the DataParallel wrapper from key names: "module.1."
            # (9 chars) when fp16 wraps the net, else "module." (7 chars)
            self.states["model"] = {
                k[strip_prefix:]: v
                for k, v in model_state_dict.items()
            }
        else:
            self.states["model"] = self.model.state_dict()
        self.states["optimizer"] = self.optimizer.state_dict()
        if self.lr_scheduler is not None:
            self.states["lr_scheduler"] = self.lr_scheduler.state_dict()

        self.save_hook()
        torch.save(self.states, file_path)

    def load(self, file_path):
        if isinstance(file_path, str):
            file_path = Path(file_path)
        if not file_path.exists():
            logger.error(f"no such file {file_path} exists")
            sys.exit(1)
        logger.debug(f"loading the model from {file_path}")
        to_device = f"cuda:{torch.cuda.current_device()}" if self.use_cuda else "cpu"
        self.states = torch.load(file_path, map_location=to_device)

    def restore_state(self):
        self.epoch = self.states["epoch"]
        if is_distributed():
            self.model.load_state_dict(
                {f"module.{k}": v
                 for k, v in self.states["model"].items()})
        else:
            self.model.load_state_dict(self.states["model"])
        if "opt_type" in self.states and self.opt_type == self.states[
                "opt_type"]:
            self.optimizer.load_state_dict(self.states["optimizer"])
        if self.lr_scheduler is not None and "lr_scheduler" in self.states:
            self.lr_scheduler.load_state_dict(self.states["lr_scheduler"])
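
The class above is abstract: unit_train, unit_validate, and unit_test must be supplied by a subclass. Below is a minimal sketch of how it might be driven. The subclass body, the batch layout (xs, ys, frame_lens, label_lens), and the model/loader names are illustrative assumptions, not part of the original code.

class CTCTrainer(Trainer):

    def unit_train(self, data):
        xs, ys, frame_lens, label_lens = data      # assumed batch layout
        if self.use_cuda:
            xs = xs.cuda()
        ys_hat = self.model(xs)                    # (batch, time, num_labels)
        # warp-ctc expects activations shaped (time, batch, num_labels)
        ys_hat = ys_hat.transpose(0, 1).contiguous()
        loss = self.loss(ys_hat, ys, frame_lens, label_lens)
        self.optimizer.zero_grad()
        loss.backward()
        # max_norm comes from the Trainer constructor
        nn.utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
        self.optimizer.step()
        return loss.item()

    # unit_validate and unit_test would decode ys_hat with self.decoder and
    # return (hyps, refs) as lists of integer label sequences


# model, train_loader, valid_loader, and num_epochs are assumed to exist
trainer = CTCTrainer(model, use_cuda=True, opt_type="adamwr", checkpoint=True)
for _ in range(num_epochs):
    trainer.train_epoch(train_loader)
    trainer.validate(valid_loader)
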
Example #2
class Trainer:
    def __init__(self,
                 model,
                 init_lr=1e-4,
                 max_norm=400,
                 use_cuda=False,
                 fp16=False,
                 log_dir='logs',
                 model_prefix='model',
                 checkpoint=False,
                 continue_from=None,
                 opt_type="sgdr",
                 *args,
                 **kwargs):
        if fp16:
            if not use_cuda:
                raise RuntimeError("fp16 training requires use_cuda=True")

        # training parameters
        self.init_lr = init_lr
        self.max_norm = max_norm
        self.use_cuda = use_cuda
        self.fp16 = fp16
        self.log_dir = log_dir
        self.model_prefix = model_prefix
        self.checkpoint = checkpoint
        self.epoch = 0

        # prepare visdom
        if logger.visdom is not None:
            logger.visdom.add_plot(title='train', xlabel='epoch', ylabel='loss')
            logger.visdom.add_plot(title='validate', xlabel='epoch', ylabel='LER')

        # setup model
        self.model = model
        if self.use_cuda:
            logger.debug("using cuda")
            self.model.cuda()

        # setup loss (CTCLoss here is assumed to be a warp-ctc binding)
        self.loss = CTCLoss(blank=0, size_average=True, length_average=True)

        # setup optimizer
        assert opt_type in OPTIMIZER_TYPES
        parameters = self.model.parameters()
        if opt_type == "sgd":
            logger.debug("using SGD")
            self.optimizer = torch.optim.SGD(parameters,
                                             lr=self.init_lr,
                                             momentum=0.9)
            self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=5)
        elif opt_type == "sgdr":
            logger.debug("using SGDR")
            self.optimizer = torch.optim.SGD(parameters,
                                             lr=self.init_lr,
                                             momentum=0.9)
            #self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=0.5)
            self.lr_scheduler = CosineAnnealingWithRestartsLR(self.optimizer,
                                                              T_max=5,
                                                              T_mult=2)
        elif opt_type == "adam":
            logger.debug("using AdamW")
            self.optimizer = torch.optim.Adam(parameters,
                                              lr=self.init_lr,
                                              betas=(0.9, 0.999),
                                              eps=1e-8,
                                              weight_decay=0.0005,
                                              l2_reg=False)
            self.lr_scheduler = None

        # setup decoder for test
        self.decoder = LatGenCTCDecoder()

        # load from pre-trained model if needed
        if continue_from is not None:
            self.load(continue_from)

        # FP16 and distributed after load
        if self.fp16:
            self.model = network_to_half(self.model)
            self.optimizer = FP16_Optimizer(self.optimizer,
                                            static_loss_scale=128.)

        if is_distributed():
            if self.use_cuda:
                local_rank = torch.cuda.current_device()
                if fp16:
                    self.model = apex.parallel.DistributedDataParallel(
                        self.model)
                else:
                    self.model = nn.parallel.DistributedDataParallel(
                        self.model,
                        device_ids=[local_rank],
                        output_device=local_rank)
            else:
                self.model = nn.parallel.DistributedDataParallel(self.model)

    def __get_model_name(self, desc):
        return str(get_model_file_path(self.log_dir, self.model_prefix, desc))

    def __remove_ckpt_files(self, epoch):
        for ckpt in Path(self.log_dir).rglob(f"*_epoch_{epoch:03d}_ckpt_*"):
            ckpt.unlink()

    def unit_train(self, data):
        raise NotImplementedError

    def train_epoch(self, data_loader):
        self.model.train()
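        # emit plots/checkpoints roughly every 10% of the epoch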
        num_ckpt = int(np.ceil(len(data_loader) / 10))
        meter_loss = tnt.meter.MovingAverageValueMeter(
            len(data_loader) // 100 + 1)
        #meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
        #meter_confusion = tnt.meter.ConfusionMeter(p.NUM_CTC_LABELS, normalized=True)
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
            logger.debug(f"current lr = {self.lr_scheduler.get_lr()}")
        if is_distributed() and data_loader.sampler is not None:
            data_loader.sampler.set_epoch(self.epoch)

        # count the number of supervised batches seen in this epoch
        t = tqdm(enumerate(data_loader),
                 total=len(data_loader),
                 desc="training")
        for i, (data) in t:
            loss_value = self.unit_train(data)
            meter_loss.add(loss_value)
            t.set_description(f"training (loss: {meter_loss.value()[0]:.3f})")
            t.refresh()
            #self.meter_accuracy.add(ys_int, ys)
            #self.meter_confusion.add(ys_int, ys)

            if 0 < i < len(data_loader) and i % num_ckpt == 0:
                if not is_distributed() or (is_distributed()
                                            and dist.get_rank() == 0):
                    title = "train"
                    x = self.epoch + i / len(data_loader)
                    if logger.visdom is not None:
                        logger.visdom.add_point(title=title,
                                                x=x,
                                                y=meter_loss.value()[0])
                    if logger.tensorboard is not None:
                        # graph/image logging referenced xs and ys_hat, which
                        # are not in scope here; kept commented out to avoid
                        # a NameError
                        #logger.tensorboard.add_graph(self.model, xs)
                        #xs_img = tvu.make_grid(xs[0, 0],
                        #                       normalize=True,
                        #                       scale_each=True)
                        #logger.tensorboard.add_image('xs', x, xs_img)
                        #ys_hat_img = tvu.make_grid(ys_hat[0].transpose(0, 1),
                        #                           normalize=True,
                        #                           scale_each=True)
                        #logger.tensorboard.add_image('ys_hat', x, ys_hat_img)
                        logger.tensorboard.add_scalars(
                            title, x, {
                                'loss': meter_loss.value()[0],
                            })
                if self.checkpoint:
                    logger.info(
                        f"training loss at epoch_{self.epoch:03d}_ckpt_{i:07d}: "
                        f"{meter_loss.value()[0]:5.3f}")
                    if not is_distributed() or (is_distributed()
                                                and dist.get_rank() == 0):
                        self.save(
                            self.__get_model_name(
                                f"epoch_{self.epoch:03d}_ckpt_{i:07d}"))
            #input("press key to continue")

        self.epoch += 1
        logger.info(f"epoch {self.epoch:03d}: "
                    f"training loss {meter_loss.value()[0]:5.3f} ")
        #f"training accuracy {meter_accuracy.value()[0]:6.3f}")
        if not is_distributed() or (is_distributed() and dist.get_rank() == 0):
            self.save(self.__get_model_name(f"epoch_{self.epoch:03d}"))
            self.__remove_ckpt_files(self.epoch - 1)

    def unit_validate(self, data):
        raise NotImplementedError

    def validate(self, data_loader):
        "validate with label error rate by the edit distance between hyps and refs"
        self.model.eval()
        with torch.no_grad():
            N, D = 0, 0
            t = tqdm(enumerate(data_loader),
                     total=len(data_loader),
                     desc="validating")
            for i, (data) in t:
                hyps, refs = self.unit_validate(data)
                # calculate ler
                N += self.edit_distance(refs, hyps)
                D += sum(len(r) for r in refs)
                ler = N * 100. / D
                t.set_description(f"validating (LER: {ler:.2f} %)")
                t.refresh()
            logger.info(
                f"validating at epoch {self.epoch:03d}: LER {ler:.2f} %")

            if not is_distributed() or (is_distributed()
                                        and dist.get_rank() == 0):
                title = f"validate"
                x = self.epoch - 1 + i / len(data_loader)
                if logger.visdom is not None:
                    logger.visdom.add_point(title=title, x=x, y=ler)
                if logger.tensorboard is not None:
                    logger.tensorboard.add_scalars(title, x, {
                        'LER': ler,
                    })

    def unit_test(self, data):
        raise NotImplementedError

    def test(self, data_loader):
        "test with word error rate by the edit distance between hyps and refs"
        self.model.eval()
        with torch.no_grad():
            N, D = 0, 0
            t = tqdm(enumerate(data_loader),
                     total=len(data_loader),
                     desc="testing")
            for i, (data) in t:
                hyps, refs = self.unit_test(data)
                # calculate wer
                N += self.edit_distance(refs, hyps)
                D += sum(len(r) for r in refs)
                wer = N * 100. / D
                t.set_description(f"testing (WER: {wer:.2f} %)")
                t.refresh()
            logger.info(f"testing at epoch {self.epoch:03d}: WER {wer:.2f} %")

    def edit_distance(self, refs, hyps):
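        # as in the first example: labels become unicode code points so
        # Lev.distance works on whole sequences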
        assert len(refs) == len(hyps)
        n = 0
        for ref, hyp in zip(refs, hyps):
            r = [chr(c) for c in ref]
            h = [chr(c) for c in hyp]
            n += Lev.distance(''.join(r), ''.join(h))
        return n

    def save(self, file_path, **kwargs):
        Path(file_path).parent.mkdir(mode=0o755, parents=True, exist_ok=True)
        logger.info(f"saving the model to {file_path}")
        states = kwargs
        states["epoch"] = self.epoch
        if is_distributed():
            model_state_dict = self.model.state_dict()
            strip_prefix = 9 if self.fp16 else 7
            # strip the DataParallel wrapper from key names: "module.1."
            # (9 chars) when fp16 wraps the net, else "module." (7 chars)
            states["model"] = {
                k[strip_prefix:]: v
                for k, v in model_state_dict.items()
            }
        else:
            states["model"] = self.model.state_dict()
        states["optimizer"] = self.optimizer.state_dict()
        states["lr_scheduler"] = self.lr_scheduler.state_dict()
        torch.save(states, file_path)

    def load(self, file_path):
        if isinstance(file_path, str):
            file_path = Path(file_path)
        if not file_path.exists():
            logger.error(f"no such file {file_path} exists")
            sys.exit(1)
        logger.info(f"loading the model from {file_path}")
        to_device = f"cuda:{torch.cuda.current_device()}" if self.use_cuda else "cpu"
        states = torch.load(file_path, map_location=to_device)
        self.epoch = states["epoch"]
        self.model.load_state_dict(states["model"])
        self.optimizer.load_state_dict(states["optimizer"])
        self.lr_scheduler.load_state_dict(states["lr_scheduler"])