Example No. 1
def train(
    model,
    epochs=110,
    batch_size=128,
    train_index_path=TRAIN_PATH,
    dev_index_path=DEV_PATH,
    labels_path=LABEL_PATH,
    learning_rate=0.6,
    momentum=0.8,
    max_grad_norm=0.2,
    weight_decay=0,
):
    train_dataset = data.MASRDataset(train_index_path, labels_path)
    batchs = (len(train_dataset) + batch_size - 1) // batch_size
    dev_dataset = data.MASRDataset(dev_index_path, labels_path)
    train_dataloader = data.MASRDataLoader(train_dataset,
                                           batch_size=batch_size,
                                           num_workers=0)
    train_dataloader_shuffle = data.MASRDataLoader(train_dataset,
                                                   batch_size=batch_size,
                                                   num_workers=0,
                                                   shuffle=True)
    dev_dataloader = data.MASRDataLoader(dev_dataset,
                                         batch_size=batch_size,
                                         num_workers=0)
    parameters = model.parameters()
    optimizer = torch.optim.SGD(
        parameters,
        lr=learning_rate,
        momentum=momentum,
        nesterov=True,
        weight_decay=weight_decay,
    )
    ctcloss = CTCLoss(size_average=True)
    # lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.985)

    gstep = 0
    for epoch in range(epochs):
        epoch_loss = 0
        # SortaGrad-style curriculum: the first epoch keeps the index order
        # (typically sorted by utterance duration); later epochs use the shuffled loader
        if epoch > 0:
            train_dataloader = train_dataloader_shuffle
        # lr_sched.step()
        for i, (x, y, x_lens, y_lens) in enumerate(train_dataloader):
            x = x.cuda()
            out, out_lens = model(x, x_lens)
            # (batch, classes, time) -> (time, batch, classes), the layout the CTC loss expects
            out = out.transpose(0, 1).transpose(0, 2)
            loss = ctcloss(out, y, out_lens, y_lens)
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            epoch_loss += loss.item()
            gstep += 1
            print("[{}/{}][{}/{}]\tLoss = {}".format(epoch + 1, epochs, i,
                                                     int(batchs), loss.item()))
        epoch_loss = epoch_loss / batchs
        cer = eval(model, dev_dataloader)
        print("Epoch {}: Loss= {}, CER = {}".format(epoch, epoch_loss, cer))
        if (epoch + 1) % 5 == 0:
            torch.save(model, "pretrained/model_{}.pth".format(epoch))
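Below is a minimal usage sketch for the train function above; the GatedConv class is taken from Example No. 3, but its module path, constructor signature, and the JSON layout of the label file are assumptions rather than something these examples confirm.

import json
import torch
from models.conv import GatedConv   # assumed module path for the GatedConv used in Example No. 3

with open(LABEL_PATH) as f:              # LABEL_PATH / TRAIN_PATH / DEV_PATH as referenced above
    vocabulary = "".join(json.load(f))   # hypothetical: labels stored as a JSON list of characters

model = GatedConv(vocabulary)            # assumed constructor signature
model.cuda()                             # the loop above only moves the inputs, so the model must already be on the GPU
train(model, epochs=110, batch_size=128)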
Example No. 2
 def fit(
     self,
     train_index,
     dev_index,
     epochs=100,
     train_batch_size=64,
     lr=0.6,
     momentum=0.8,
     grad_clip=0.2,
     dev_batch_size=64,
     sorta_grad=True,
     tensorboard=True,
     quiet=False,
 ):
     self.to("cuda")
     self.train()
     if tensorboard:
         writer = SummaryWriter()
     optimizer = optim.SGD(self.parameters(), lr, momentum, nesterov=True)
     train_dataset = data.MASRDataset(train_index, self.vocabulary)
     train_loader_shuffle = data.MASRDataLoader(
         train_dataset, train_batch_size, shuffle=True, num_workers=16
     )
     if sorta_grad:
         train_loader_sort = data.MASRDataLoader(
             train_dataset, train_batch_size, shuffle=False, num_workers=16
         )
     train_steps = len(train_loader_shuffle)
     gstep = 0
     for epoch in range(epochs):
         avg_loss = 0
         if epoch == 0 and sorta_grad:
             train_loader = train_loader_sort
         else:
             train_loader = train_loader_shuffle
         for step, (x, y, x_lens, y_lens) in enumerate(train_loader):
             x = x.to("cuda")
             gstep += 1
             outputs = self.forward(x, x_lens)
             loss = self.loss(outputs, (y, y_lens))
             optimizer.zero_grad()
             loss.backward()
             nn.utils.clip_grad_norm_(self.parameters(), grad_clip)
             optimizer.step()
             avg_loss += loss.item()
             if not quiet:
                 print("[{}/{}][{}/{}]\tLoss = {}".format(epoch + 1, epochs, step + 1, train_steps, loss.item()))
             if tensorboard:
                 writer.add_scalar("loss/step", loss.item(), gstep)
         cer = self.test(dev_index, dev_batch_size)
         avg_loss /= train_steps
         if not quiet:
             print("Epoch {}\t CER = {}\t".format(epoch + 1, cer))
         if tensorboard:
             writer.add_scalar("cer/epoch", cer, epoch + 1)
             writer.add_scalar("loss/epoch", loss, epoch + 1)
         self.save("pretrained/{}_epoch_{}.pth".format(self.name, epoch + 1))
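The loop above leans on a self.loss helper that is not shown in any of these examples. A minimal sketch of what such a CTC wrapper could look like, assuming forward returns (logits, output lengths) with the logits laid out as (batch, classes, time), as the transposes in Examples No. 1 and 5 suggest:

import torch.nn as nn

def loss(self, outputs, targets):
    # hypothetical wrapper around nn.CTCLoss; the shapes and blank index are assumptions
    out, out_lens = outputs                           # out: (batch, classes, time)
    y, y_lens = targets                               # y: flat 1-D tensor of label indices
    log_probs = out.permute(2, 0, 1).log_softmax(2)   # -> (time, batch, classes)
    return nn.CTCLoss(blank=0)(log_probs, y, out_lens, y_lens)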
Example No. 3
def model_setup(args=None):

    test_dataset = data.MASRDataset(args.test_index_path,
                                    args.labels_path,
                                    args.mode,
                                    config=args)
    dataloader = data.MASRDataLoader(test_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.num_workers)

    model = GatedConv.load(args.pretrained_path)

    global decoder
    decoder = BeamCTCDecoder(
        dataloader.dataset.labels_str,
        alpha=0.8,
        beta=0.3,
        lm_path="/root/lm/zh_giga.no_cna_cmn.prune01244.klm",
        cutoff_top_n=40,
        cutoff_prob=1.0,
        beam_width=100,
        num_processes=args.num_workers,
        blank_index=0,
    )

    return model, dataloader
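A hedged sketch of how the returned model and dataloader (plus the global decoder) might be consumed, assuming the loader yields the same (x, y, x_lens, y_lens) batches as the other examples and that BeamCTCDecoder.decode returns (transcripts, offsets) like the GreedyDecoder in Example No. 7; the argparse field values are placeholders.

import argparse
import torch
import torch.nn.functional as F

args = argparse.Namespace(
    test_index_path="test.index",            # placeholder values
    labels_path="labels.json",
    mode="test",
    batch_size=32,
    num_workers=4,
    pretrained_path="pretrained/model.pth",
)
model, dataloader = model_setup(args)
model.cuda().eval()
with torch.no_grad():
    for x, y, x_lens, y_lens in dataloader:
        out, out_lens = model(x.cuda(), x_lens)
        probs = F.softmax(out, dim=1).transpose(1, 2)   # (batch, classes, time) -> (batch, time, classes)
        transcripts, _ = decoder.decode(probs.cpu(), out_lens)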
Example No. 4
 def test(self, test_index, batch_size=64):  # -> cer: float
     self.eval()
     test_dataset = data.MASRDataset(test_index, self.vocabulary)
     test_loader = data.MASRDataLoader(
         test_dataset, batch_size, shuffle=False, num_workers=16
     )
     test_steps = len(test_loader)
     cer = 0
     for inputs, targets in tqdm(test_loader, total=test_steps):
         x, x_lens = inputs
         x = x.to("cuda")
         outputs = self.forward(x, x_lens)
         texts = self.decode(*outputs)
         cer += self.cer(texts, *targets)
     cer /= test_steps
     self.train()
     return cer
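The self.cer call above is also undefined here; a purely illustrative sketch, assuming both arguments are lists of decoded and reference strings and that CER is edit distance normalised by reference length:

def _edit_distance(a, b):
    # classic dynamic-programming Levenshtein distance between two strings
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                 # deletion
                           cur[j - 1] + 1,              # insertion
                           prev[j - 1] + (ca != cb)))   # substitution / match
        prev = cur
    return prev[-1]

def cer(predictions, references):
    # hypothetical helper: mean per-utterance character error rate over the batch
    total = sum(_edit_distance(p, r) / max(len(r), 1)
                for p, r in zip(predictions, references))
    return total / max(len(predictions), 1)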
Example No. 5
def train(
    model,
    epochs=100,
    batch_size=64,
    train_index_path="train.index",
    dev_index_path="dev.index",
    labels_path="labels.json",
    learning_rate=0.6,
    momentum=0.8,
    max_grad_norm=0.2,
    weight_decay=0,
):
    train_dataset = data.MASRDataset(train_index_path, labels_path)
    batchs = (len(train_dataset) + batch_size - 1) // batch_size
    dev_dataset = data.MASRDataset(dev_index_path, labels_path)
    train_dataloader = data.MASRDataLoader(train_dataset,
                                           batch_size=batch_size,
                                           num_workers=8)
    train_dataloader_shuffle = data.MASRDataLoader(train_dataset,
                                                   batch_size=batch_size,
                                                   num_workers=8,
                                                   shuffle=True)
    dev_dataloader = data.MASRDataLoader(dev_dataset,
                                         batch_size=batch_size,
                                         num_workers=8)
    parameters = model.parameters()
    optimizer = torch.optim.SGD(
        parameters,
        lr=learning_rate,
        momentum=momentum,
        nesterov=True,
        weight_decay=weight_decay,
    )
    ctcloss = CTCLoss(size_average=True)
    # lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.985)
    writer = tensorboard.SummaryWriter()
    gstep = 0
    for epoch in range(epochs):
        epoch_loss = 0
        if epoch > 0:
            train_dataloader = train_dataloader_shuffle
        # lr_sched.step()
        lr = get_lr(optimizer)
        writer.add_scalar("lr/epoch", lr, epoch)
        for i, (x, y, x_lens, y_lens) in enumerate(train_dataloader):
            x = x.to("cuda")
            out, out_lens = model(x, x_lens)
            out = out.transpose(0, 1).transpose(0, 2)
            loss = ctcloss(out, y, out_lens, y_lens)
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            epoch_loss += loss.item()
            writer.add_scalar("loss/step", loss.item(), gstep)
            gstep += 1
            print("[{}/{}][{}/{}]\tLoss = {}".format(epoch + 1, epochs, i,
                                                     int(batchs), loss.item()))
        epoch_loss = epoch_loss / batchs
        cer = eval(model, dev_dataloader)
        writer.add_scalar("loss/epoch", epoch_loss, epoch)
        writer.add_scalar("cer/epoch", cer, epoch)
        print("Epoch {}: Loss= {}, CER = {}".format(epoch, epoch_loss, cer))
        torch.save(model, "pretrained/model_{}.pth".format(epoch))
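Examples No. 5, 7 and 8 all call a get_lr helper that is never shown; a plausible one-liner, assuming the learning rate of the first parameter group is the one worth logging:

def get_lr(optimizer):
    # assumed helper: current learning rate of the optimizer's first param group
    return optimizer.param_groups[0]["lr"]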
Example No. 6
            a = 0.3409090909090909  #0.341
            b = 0.8571428571428571  #0.857
        else:
            a = 0.11363636363636363  #0.3
            b = 0
        print(a, b)
        # results, targets = eval(model, dataloader, device, args.save_output, a, b)

        if "tg" in args.test_index_path:
            args.test_index_path = args.test_index_path.replace("tg", "ts")
            # args.save_output = args.save_output.replace("tg", "ts")
            a = 0.05  #0.11363636363636363
            b = 0
            print(args.test_index_path, args.save_output)
        else:
            args.test_index_path = args.test_index_path.replace("ts", "tg")
            # args.save_output = args.save_output.replace("ts", "tg")
            a = 0.05  #0.11363636363636363
            b = 0
            print(args.test_index_path, args.save_output)
        print(a, b)
        test_dataset = data.MASRDataset(args.test_index_path,
                                        args.labels_path,
                                        args.mode,
                                        config=args)
        dataloader = data.MASRDataLoader(test_dataset,
                                         batch_size=args.batch_size,
                                         num_workers=args.num_workers)
        results, targets = eval(model, dataloader, device, args.save_output, a,
                                b)
Example No. 7
def train(model,
          epochs=1000,
          batch_size=64,
          train_index_path="./dataset/train_index.json",
          dev_index_path="./dataset/dev_index.json",
          labels_path="./dataset/labels.json",
          learning_rate=0.6,
          momentum=0.8,
          max_grad_norm=0.2,
          weight_decay=0,
          config=None):

    train_dataset = data.MASRDataset(train_index_path,
                                     labels_path,
                                     config=config)
    batchs = (len(train_dataset) + batch_size - 1) // batch_size
    train_dataloader = data.MASRDataLoader(train_dataset,
                                           batch_size=batch_size,
                                           num_workers=cpu_num)
    train_dataloader_shuffle = data.MASRDataLoader(train_dataset,
                                                   batch_size=batch_size,
                                                   num_workers=cpu_num,
                                                   shuffle=True,
                                                   pin_memory=True)

    dev_datasets, dev_dataloaders = [], []
    for _item in ["IOS", "Android", "Recorder"]:
        dev_datasets.append(
            data.MASRDataset(dev_index_path,
                             labels_path,
                             mode="dev",
                             config=config,
                             device_type=_item))
        dev_dataloaders.append(
            data.MASRDataLoader(dev_datasets[-1],
                                batch_size=batch_size,
                                num_workers=cpu_num,
                                pin_memory=True))

    if config.optim == "sgd":
        print("choose sgd.")
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=learning_rate,
            momentum=momentum,
            nesterov=True,
            weight_decay=weight_decay,
        )
    else:
        print("choose adamwr.")
        optimizer = AdamW(model.parameters(),
                          lr=learning_rate,
                          weight_decay=1e-4)
        if config.fp16:
            # Allow Amp to perform casts as required by the opt_level
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=config.fp16,
                                              loss_scale=args.loss_scale)
            scheduler = CyclicLRWithRestarts(
                optimizer,
                config.batch_size,
                epoch_size=len(train_dataloader.dataset),
                restart_period=5,
                t_mult=1.2,
                eta_on_restart_cb=ReduceMaxLROnRestart(ratio=config.wr_ratio),
                policy="cosine")
        else:
            scheduler = CyclicLRWithRestarts(
                optimizer,
                batch_size,
                epoch_size=len(train_dataloader.dataset),
                restart_period=5,
                t_mult=1.2,
                eta_on_restart_cb=ReduceMaxLROnRestart(ratio=config.wr_ratio),
                policy="cosine")

    ctcloss = CTCLoss()  #size_average=True
    decoder = GreedyDecoder(train_dataloader.dataset.labels_str)
    # lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.985)
    writer = tensorboard.SummaryWriter('./logs-O2/')
    gstep = 0
    start_epoch = 0
    best_cer = 1

    # optionally resume from a checkpoint
    resume = args.resume
    if resume is not None:
        if os.path.isfile(resume):
            print_log("=> loading checkpoint '{}'".format(resume))
            checkpoint = torch.load(resume, map_location="cpu")
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            gstep = checkpoint['gstep']
            best_cer = checkpoint['best_cer']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                resume, start_epoch))
        else:
            print("=> no checkpoint found at '{}'".format(resume))
    else:
        print("=> did not use any checkpoint")

    for epoch in range(start_epoch, epochs):
        epoch_loss = 0
        cer_tr = 0
        if epoch > 0:
            train_dataloader = train_dataloader_shuffle
        # lr_sched.step()
        lr = get_lr(optimizer)
        writer.add_scalar("lr/epoch", lr, epoch)
        if config.optim == "adamwr": scheduler.step()
        for i, (x, y, x_lens, y_lens) in enumerate(train_dataloader):
            x = x.to(device)
            out, out_lens = model(x, x_lens)
            outs = out.transpose(0, 1).transpose(0, 2)
            loss = ctcloss(outs, y, out_lens, y_lens).to(device)
            loss = loss / x.size(0)
            optimizer.zero_grad()
            # mixed-precision (AMP) speed-up
            if config.fp16 is not None:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            nn.utils.clip_grad_norm_(model.parameters(),
                                     max_grad_norm)  #if config.optim == "sgd":
            optimizer.step()
            if config.optim == "adamwr": scheduler.batch_step()
            #loss_value = loss.item()

            # greedy-decode the batch to track a rough training CER
            outs = F.softmax(out, 1)
            outs = outs.transpose(1, 2)
            ys = []
            offset = 0
            for y_len in y_lens:
                ys.append(y[offset:offset + y_len])
                offset += y_len
            out_strings, out_offsets = decoder.decode(outs, out_lens)
            y_strings = decoder.convert_to_strings(ys)
            for pred, truth in zip(out_strings, y_strings):
                trans, ref = pred[0], truth[0]
                cer_tr += decoder.cer(trans, ref) / float(len(ref))
            # loss
            epoch_loss += loss.item()
            writer.add_scalar("loss/step", loss.item(), gstep)
            writer.add_scalar("cer_tr/step", cer_tr / (batch_size * (i + 1)),
                              gstep)
            gstep += 1
            # display
            if i % 5 == 0:
                print("[{}/{}][{}/{}]\tLoss = {:.4f},\tCer = {:.4f}".format(
                    epoch + 1, epochs, i, int(batchs), loss.item(),
                    cer_tr / (batch_size * (i + 1))),
                      flush=True)

                if args.debug: break

        cer_tr /= len(train_dataloader.dataset)
        epoch_loss = epoch_loss / batchs

        cer_devs, loss_devs = [], []
        for dev_dataloader in dev_dataloaders:
            cer_dev, loss_dev = eval(model, dev_dataloader)
            cer_devs.append(cer_dev)
            loss_devs.append(loss_dev)

            if args.debug:
                cer_devs.extend([0, 0])
                loss_devs.extend([0, 0])
                break

        cer_dev = sum(cer_devs) / 3
        loss_dev = sum(loss_devs) / 3

        writer.add_scalar("loss/epoch", epoch_loss, epoch)
        # writer.add_scalar("loss_dev/epoch", loss_dev, epoch)
        writer.add_scalars(
            "loss_dev", {
                "loss_dev": loss_dev,
                "loss_ios": loss_devs[0],
                "loss_recorder": loss_devs[2],
                "loss_android": loss_devs[1]
            }, epoch)
        writer.add_scalar("cer_tr/epoch", cer_tr, epoch)
        writer.add_scalars(
            "cer_dev", {
                "cer_dev": cer_dev,
                "cer_ios": cer_devs[0],
                "cer_recorder": cer_devs[2],
                "cer_android": cer_devs[1]
            }, epoch)
        print(
            "Epoch {}: Loss= {:.4f}, Loss_dev= {:.4f}, CER_tr = {:.4f}, CER_dev = {:.4f}, CER_ios = {:.4f}"
            .format(epoch, epoch_loss, loss_dev, cer_tr, cer_dev, cer_devs[0]))
        if cer_dev <= best_cer:
            best_cer = cer_dev
            save_path = "pretrained/model_bestCer_{}_{:.4f}_{}.pth".format(
                config.optim, epoch + cer_dev,
                time.strftime("%m%d%H%M", time.localtime()))
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'config': {
                        "vocabulary": model.vocabulary
                    },
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'gstep': gstep,
                    'best_cer': best_cer,
                }, save_path)
            # torch.save(model, )
            with open(
                    "{}_{:.4f}_{:.4f}_{:.4f}.info".format(
                        epoch, epoch_loss, cer_tr, cer_dev), "w") as _fw:
                _fw.write("")

        if args.debug: break
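save_checkpoint (and the print_log used during resume) are not defined in these examples either; a minimal sketch of the former, assuming it simply serialises the state dictionary built in the loop above to the given path:

import os
import torch

def save_checkpoint(state, save_path):
    # hypothetical helper: persist the resumable training state assembled above
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(state, save_path)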
Example No. 8
def train(
    model,
    epochs=1000,
    batch_size=64,
    train_index_path="../data_aishell/train-sort.manifest",
    dev_index_path="../data_aishell/dev.manifest",
    labels_path="../data_aishell/labels.json",
    learning_rate=0.6,
    momentum=0.8,
    max_grad_norm=0.2,
    weight_decay=0,
):
    hvd.init()
    torch.manual_seed(1024)
    torch.cuda.set_device(hvd.local_rank())
    torch.cuda.manual_seed(1024)

    # dataset loader
    train_dataset = data.MASRDataset(train_index_path, labels_path)
    batchs = (len(train_dataset) + batch_size - 1) // batch_size

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
    train_dataloader = data.MASRDataLoader(train_dataset,
                                           batch_size=batch_size,
                                           num_workers=4,
                                           sampler=train_sampler)

    dev_dataset = data.MASRDataset(dev_index_path, labels_path)
    dev_sampler = torch.utils.data.distributed.DistributedSampler(
        dev_dataset, num_replicas=hvd.size(), rank=hvd.rank())
    dev_dataloader = data.MASRDataLoader(dev_dataset,
                                         batch_size=batch_size,
                                         num_workers=1,
                                         sampler=dev_sampler)

    # optimizer
    parameters = model.parameters()
    optimizer = torch.optim.SGD(
        parameters,
        lr=learning_rate * hvd.size(),
        momentum=momentum,
        nesterov=True,
        weight_decay=weight_decay,
    )

    # Horovod: broadcast parameters & optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # Horovod: (optional) compression algorithm.
    compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

    # Horovod: wrap optimizer with DistributedOptimizer.
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        compression=compression)

    ctcloss = nn.CTCLoss()

    # lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.985)
    writer = tensorboard.SummaryWriter()
    gstep = 0
    for epoch in range(epochs):
        epoch_loss = 0
        # lr_sched.step()
        lr = get_lr(optimizer)
        if hvd.rank() == 0:
            writer.add_scalar("lr/epoch", lr, epoch)
        for i, (x, y, x_lens, y_lens) in enumerate(train_dataloader):
            x = x.to(device)
            out, out_lens = model(x, x_lens)
            out = out.transpose(0, 1).transpose(0, 2).log_softmax(2)
            loss = ctcloss(out, y, out_lens, y_lens)
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            epoch_loss += loss.item()

            if hvd.rank() == 0:
                writer.add_scalar("loss/step", loss.item(), gstep)
                gstep += 1
                print("[{}/{}][{}/{}]\tLoss = {}".format(
                    epoch + 1, epochs, i, int(batchs), loss.item()))

        epoch_loss = epoch_loss / batchs
        cer = eval(model, dev_dataloader)
        writer.add_scalar("loss/epoch", epoch_loss, epoch)
        writer.add_scalar("cer/epoch", cer, epoch)
        print("Epoch {}: Loss= {}, CER = {}".format(epoch, epoch_loss, cer))
        torch.save(model.state_dict(), "pretrained/model_{}.pth".format(epoch))
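Because the dev set is also wrapped in a DistributedSampler, eval() in this variant only sees each rank's shard, so the CER computed above is a local value. A hedged extension (not present in the source) that averages it across workers before logging:

import torch
import horovod.torch as hvd

def global_cer(local_cer):
    # assumed helper: mean of the shard-local CER over all Horovod ranks
    return hvd.allreduce(torch.tensor(local_cer), name="dev_cer").item()

It could be called right after cer = eval(model, dev_dataloader) so that the value written to TensorBoard reflects the whole dev set rather than a single shard.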