Beispiel #1
0
def main():
    """Entry point: copy CLI args onto the global cfg, build the VCTK
    dataloaders and the WaveNet model, then hand control to train()."""
    args = parse_args()

    # Mirror parsed arguments onto the global config object.
    cfg.resume = args.resume
    cfg.exp_name = args.exp
    cfg.work_root = '/zhzhao/code/wavenet_torch/torch_lyuan/exp_result/'
    cfg.workdir = cfg.work_root + args.exp + '/debug'
    cfg.sparse_mode = args.sparse_mode
    cfg.batch_size = args.batch_size
    cfg.lr = args.lr
    cfg.load_from = args.load_from
    cfg.save_excel = args.save_excel

    # Make sure the checkpoint directory exists before training starts.
    check_and_mkdir(os.path.join(cfg.workdir, 'weights'))

    print('initial training...')
    print(f'work_dir:{cfg.workdir}, \n\
            pretrained: {cfg.load_from},  \n\
            batch_size: {cfg.batch_size}, \n\
            lr        : {cfg.lr},         \n\
            epochs    : {cfg.epochs},     \n\
            sparse    : {cfg.sparse_mode}')
    writer = SummaryWriter(log_dir=cfg.workdir + '/runs')

    # Train/val dataloaders over the VCTK corpus (train is shuffled).
    loaders = {}
    for split, shuffle in (('train', True), ('val', False)):
        dataset = VCTK(cfg, split)
        loaders[split] = DataLoader(dataset,
                                    batch_size=cfg.batch_size,
                                    num_workers=4,
                                    shuffle=shuffle,
                                    pin_memory=True)

    # WaveNet wrapped in DataParallel, put on GPU in training mode.
    model = WaveNet(num_classes=28, channels_in=40, dilations=[1, 2, 4, 8, 16])
    model = nn.DataParallel(model)
    model.cuda()
    model.train()

    # CTC loss; class index 27 is used as the blank symbol.
    loss_fn = nn.CTCLoss(blank=27)

    # Optionally resume from the best checkpoint of a previous run.
    best_path = cfg.workdir + '/weights/best.pth'
    if cfg.resume and os.path.exists(best_path):
        model.load_state_dict(torch.load(best_path), strict=True)
        print("loading", best_path)
        cfg.load_from = best_path

    # NOTE(review): upstream names this "scheduler", but it is the Adam
    # optimizer (passed positionally into train(), so renaming is safe).
    optimizer = optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)
    train(loaders['train'], optimizer, model, loss_fn, loaders['val'], writer)
Beispiel #2
0
def main():
    """Distributed (DDP) training entry point: one process per GPU, NCCL
    backend, VCTK dataloaders sharded per rank via DistributedSampler."""
    print('initial training...')
    print(
        f'work_dir:{cfg.workdir}, pretrained:{cfg.load_from}, batch_size:{cfg.batch_size} lr:{cfg.lr}, epochs:{cfg.epochs}'
    )
    args = parse_args()
    writer = SummaryWriter(log_dir=cfg.workdir + '/runs')

    # Distributed setup: pin this process to its GPU and join the group.
    assert cfg.distributed
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group('nccl', init_method='env://')

    def make_loader(split, shuffle):
        # A sampler replaces DataLoader(shuffle=...) in distributed mode.
        dataset = VCTK(cfg, split)
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, shuffle=shuffle)
        return DataLoader(dataset,
                          batch_size=cfg.batch_size,
                          sampler=sampler,
                          num_workers=8,
                          pin_memory=True)

    train_loader = make_loader('train', True)
    val_loader = make_loader('val', False)

    # Model replicated across ranks with DistributedDataParallel.
    model = WaveNet(num_classes=28, channels_in=20).cuda()
    model = DDP(model, device_ids=[args.local_rank], broadcast_buffers=False)

    # CTC loss with the default blank index (0).
    loss_fn = nn.CTCLoss()

    # NOTE(review): upstream names this "scheduler", but it is the Adam
    # optimizer (passed positionally into train(), so renaming is safe).
    optimizer = optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)

    # train
    train(args, train_loader, optimizer, model, loss_fn, val_loader, writer)
Beispiel #3
0
def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate,
          iters_per_checkpoint, batch_size, seed, checkpoint_path):
    """Train WaveNet on DeepMels data, optionally across multiple GPUs.

    Args:
        num_gpus: number of GPUs; >1 enables the distributed code paths.
        rank: this process's rank; rank 0 owns the output directory and
            checkpoint writing.
        group_name: distributed process-group name.
        output_directory: directory where checkpoints are saved.
        epochs: total number of epochs to train.
        learning_rate: Adam learning rate.
        iters_per_checkpoint: save a checkpoint every N iterations.
        batch_size: per-process batch size.
        seed: RNG seed for reproducibility.
        checkpoint_path: checkpoint to resume from ("" = start fresh).
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    #=====END:   ADDED FOR DISTRIBUTED======

    criterion = CrossEntropyLoss()
    model = WaveNet(**wavenet_config).cuda()

    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)
    #=====END:   ADDED FOR DISTRIBUTED======

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
                                                      optimizer)
        iteration += 1  # next iteration is iteration + 1

    trainset = DeepMels(**data_config)
    # =====START: ADDED FOR DISTRIBUTED======
    train_sampler = DistributedSampler(trainset) if num_gpus > 1 else None
    # =====END:   ADDED FOR DISTRIBUTED======
    train_loader = DataLoader(trainset,
                              num_workers=1,
                              shuffle=False,
                              sampler=train_sampler,
                              batch_size=batch_size,
                              pin_memory=False,
                              drop_last=True)

    # Get shared output_directory ready (rank 0 only).
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)

    model.train()
    # Resume at the epoch implied by the restored iteration counter
    # (integer division; was int(a / b), which loses precision for huge a).
    epoch_offset = max(0, iteration // len(train_loader))
    # ================ MAIN TRAINING LOOP ===================
    for epoch in range(epoch_offset, epochs):
        total_loss = 0
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            model.zero_grad()

            x, y = batch
            x = to_gpu(x).float()
            y = to_gpu(y)
            x = (x, y)  # auto-regressive takes outputs as inputs
            y_pred = model(x)
            loss = criterion(y_pred, y)
            if num_gpus > 1:
                # BUGFIX: was reduce_tensor(...)[0]; indexing a 0-dim tensor
                # raises IndexError on PyTorch >= 0.5. .item() extracts the
                # scalar on all versions. (assumes reduce_tensor returns a
                # tensor -- TODO confirm against its definition)
                reduced_loss = reduce_tensor(loss.data, num_gpus).item()
            else:
                # BUGFIX: was loss.data[0], which fails on 0-dim tensors.
                reduced_loss = loss.item()
            loss.backward()
            optimizer.step()

            total_loss += reduced_loss

            # Rank 0 checkpoints every iters_per_checkpoint iterations.
            if (iteration % iters_per_checkpoint == 0):
                if rank == 0:
                    checkpoint_path = "{}/wavenet_{}".format(
                        output_directory, iteration)
                    save_checkpoint(model, optimizer, learning_rate, iteration,
                                    checkpoint_path)

            iteration += 1
        print("epoch:{}, total epoch loss:{}".format(epoch, total_loss))
Beispiel #4
0
class Trainer:
    """Training harness for an encoder / WaveNet-decoder / Z-discriminator
    model over several datasets (one dataset per rank when distributed).

    Each batch performs two optimizer steps: the discriminator is trained to
    classify which dataset a latent code z came from, then the
    encoder+decoder are trained to reconstruct audio while fooling the
    discriminator (see train_batch).
    """

    def __init__(self, args):
        """Build datasets, meters, models and optimizers from ``args``;
        optionally restore model/optimizer state from a checkpoint."""
        self.args = args
        self.args.n_datasets = len(self.args.data)
        # All checkpoints and the args.pth snapshot live under
        # checkpoints/<expName>.
        self.expPath = Path('checkpoints') / args.expName

        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        self.logger = create_output_dir(args, self.expPath)
        # One DatasetSet per dataset path given on the command line.
        self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]
        assert not args.distributed or len(self.data) == int(
            os.environ['WORLD_SIZE']
        ), "Number of datasets must match number of nodes"

        # Running loss meters for training ...
        self.losses_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.loss_d_right = LossMeter('d')
        self.loss_total = LossMeter('total')

        # ... and their evaluation counterparts.
        self.evals_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.eval_d_right = LossMeter('eval d')
        self.eval_total = LossMeter('eval total')

        self.encoder = Encoder(args)
        self.decoder = WaveNet(args)
        self.discriminator = ZDiscriminator(args)

        if args.checkpoint:
            # args.pth sits next to the checkpoint and stores [args, epoch]
            # (written in train()), so [-1] is the last finished epoch.
            checkpoint_args_path = os.path.dirname(
                args.checkpoint) + '/args.pth'
            checkpoint_args = torch.load(checkpoint_args_path)

            self.start_epoch = checkpoint_args[-1] + 1
            states = torch.load(args.checkpoint)

            self.encoder.load_state_dict(states['encoder_state'])
            self.decoder.load_state_dict(states['decoder_state'])
            self.discriminator.load_state_dict(states['discriminator_state'])

            self.logger.info('Loaded checkpoint parameters')
        else:
            self.start_epoch = 0

        if args.distributed:
            self.encoder.cuda()
            self.encoder = torch.nn.parallel.DistributedDataParallel(
                self.encoder)
            self.discriminator.cuda()
            self.discriminator = torch.nn.parallel.DistributedDataParallel(
                self.discriminator)
            self.logger.info('Created DistributedDataParallel')
        else:
            self.encoder = torch.nn.DataParallel(self.encoder).cuda()
            self.discriminator = torch.nn.DataParallel(
                self.discriminator).cuda()
        # The decoder is always wrapped in DataParallel, even when the
        # encoder/discriminator use DistributedDataParallel.
        self.decoder = torch.nn.DataParallel(self.decoder).cuda()

        # Encoder+decoder share one optimizer; the discriminator has its own.
        self.model_optimizer = optim.Adam(chain(self.encoder.parameters(),
                                                self.decoder.parameters()),
                                          lr=args.lr)
        self.d_optimizer = optim.Adam(self.discriminator.parameters(),
                                      lr=args.lr)

        if args.checkpoint and args.load_optimizer:
            self.model_optimizer.load_state_dict(
                states['model_optimizer_state'])
            self.d_optimizer.load_state_dict(states['d_optimizer_state'])

        # Exponential LR decay, fast-forwarded to the resume epoch.
        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
            self.model_optimizer, args.lr_decay)
        self.lr_manager.last_epoch = self.start_epoch
        self.lr_manager.step()

    def eval_batch(self, x, x_aug, dset_num):
        """Evaluate one batch from dataset ``dset_num``; updates the eval
        meters and returns the combined (recon + d_lambda * D) loss."""
        x, x_aug = x.float(), x_aug.float()

        z = self.encoder(x)
        y = self.decoder(x, z)
        z_logits = self.discriminator(z)

        # Discriminator's predicted dataset id per sample.
        z_classification = torch.max(z_logits, dim=1)[1]

        z_accuracy = (z_classification == dset_num).float().mean()

        self.eval_d_right.add(z_accuracy.data.item())

        # discriminator_right = F.cross_entropy(z_logits, dset_num).mean()
        # Cross-entropy against the true dataset id for every sample.
        discriminator_right = F.cross_entropy(
            z_logits,
            torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()
        recon_loss = cross_entropy_loss(y, x)

        self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        total_loss = discriminator_right.data.item() * self.args.d_lambda + \
                     recon_loss.mean().data.item()

        self.eval_total.add(total_loss)

        return total_loss

    def train_batch(self, x, x_aug, dset_num):
        """Run one discriminator step then one generator step on a batch
        from dataset ``dset_num``; returns the generator loss as a float."""
        x, x_aug = x.float(), x_aug.float()

        # Optimize D - discriminator right
        z = self.encoder(x)
        z_logits = self.discriminator(z)
        discriminator_right = F.cross_entropy(
            z_logits,
            torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()
        loss = discriminator_right * self.args.d_lambda
        self.d_optimizer.zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.discriminator.parameters(),
                             self.args.grad_clip)

        self.d_optimizer.step()

        # optimize G - reconstructs well, discriminator wrong
        # Note: the generator path encodes the augmented input x_aug but
        # reconstructs the clean x.
        z = self.encoder(x_aug)
        y = self.decoder(x, z)
        z_logits = self.discriminator(z)
        discriminator_wrong = -F.cross_entropy(
            z_logits,
            torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()

        # Log diagnostics when the D loss blows up.
        if not (-100 < discriminator_right.data.item() < 100):
            self.logger.debug(f'z_logits: {z_logits.detach().cpu().numpy()}')
            self.logger.debug(f'dset_num: {dset_num}')

        recon_loss = cross_entropy_loss(y, x)
        self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        loss = (recon_loss.mean() + self.args.d_lambda * discriminator_wrong)

        self.model_optimizer.zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.encoder.parameters(), self.args.grad_clip)
            clip_grad_value_(self.decoder.parameters(), self.args.grad_clip)
        self.model_optimizer.step()

        self.loss_total.add(loss.data.item())

        return loss.data.item()

    def train_epoch(self, epoch):
        """Train for args.epoch_len batches, cycling through the datasets
        (or using only this rank's dataset when distributed)."""
        for meter in self.losses_recon:
            meter.reset()
        self.loss_d_right.reset()
        self.loss_total.reset()

        self.encoder.train()
        self.decoder.train()
        self.discriminator.train()

        n_batches = self.args.epoch_len

        with tqdm(total=n_batches,
                  desc='Train epoch %d' % epoch) as train_enum:
            for batch_num in range(n_batches):
                # --short is a smoke-test mode: stop after a few batches.
                if self.args.short and batch_num == 3:
                    break

                if self.args.distributed:
                    assert self.args.rank < self.args.n_datasets, "No. of workers must be equal to #dataset"
                    # dset_num = (batch_num + self.args.rank) % self.args.n_datasets
                    dset_num = self.args.rank
                else:
                    dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].train_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.train_batch(x, x_aug, dset_num)

                train_enum.set_description(
                    f'Train (loss: {batch_loss:.2f}) epoch {epoch}')
                train_enum.update()

    def evaluate_epoch(self, epoch):
        """Evaluate for ~epoch_len/10 batches with gradients disabled."""
        for meter in self.evals_recon:
            meter.reset()
        self.eval_d_right.reset()
        self.eval_total.reset()

        self.encoder.eval()
        self.decoder.eval()
        self.discriminator.eval()

        n_batches = int(np.ceil(self.args.epoch_len / 10))

        with tqdm(total=n_batches) as valid_enum, \
                torch.no_grad():
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 10:
                    break

                if self.args.distributed:
                    assert self.args.rank < self.args.n_datasets, "No. of workers must be equal to #dataset"
                    dset_num = self.args.rank
                else:
                    dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].valid_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.eval_batch(x, x_aug, dset_num)

                valid_enum.set_description(
                    f'Test (loss: {batch_loss:.2f}) epoch {epoch}')
                valid_enum.update()

    @staticmethod
    def format_losses(meters):
        """Format per-meter epoch summaries as a comma-separated string."""
        losses = [meter.summarize_epoch() for meter in meters]
        return ', '.join('{:.4f}'.format(x) for x in losses)

    def train_losses(self):
        """Summarize the training meters (per-dataset recon + D)."""
        meters = [*self.losses_recon, self.loss_d_right]
        return self.format_losses(meters)

    def eval_losses(self):
        """Summarize the evaluation meters (per-dataset recon + D)."""
        meters = [*self.evals_recon, self.eval_d_right]
        return self.format_losses(meters)

    def train(self):
        """Main loop: train/evaluate each epoch, step the LR schedule, and
        checkpoint the best (by eval loss) and latest models."""
        best_eval = float('inf')

        # Begin!
        for epoch in range(self.start_epoch,
                           self.start_epoch + self.args.epochs):
            self.logger.info(
                f'Starting epoch, Rank {self.args.rank}, Dataset: {self.args.data[self.args.rank]}'
            )
            self.train_epoch(epoch)
            self.evaluate_epoch(epoch)

            self.logger.info(
                f'Epoch %s Rank {self.args.rank} - Train loss: (%s), Test loss (%s)',
                epoch, self.train_losses(), self.eval_losses())
            self.lr_manager.step()
            val_loss = self.eval_total.summarize_epoch()

            # Keep the best model seen so far (lowest eval total loss).
            if val_loss < best_eval:
                self.save_model(f'bestmodel_{self.args.rank}.pth')
                best_eval = val_loss

            # Latest model: one rolling file, or one file per epoch.
            if not self.args.per_epoch:
                self.save_model(f'lastmodel_{self.args.rank}.pth')
            else:
                self.save_model(f'lastmodel_{epoch}_rank_{self.args.rank}.pth')

            # Master rank records [args, epoch]; __init__ reads this back
            # to compute start_epoch on resume.
            if self.args.is_master:
                torch.save([self.args, epoch], '%s/args.pth' % self.expPath)

            self.logger.debug('Ended epoch')

    def save_model(self, filename):
        """Save model (unwrapped .module) and optimizer states to
        expPath/filename."""
        save_path = self.expPath / filename

        torch.save(
            {
                'encoder_state': self.encoder.module.state_dict(),
                'decoder_state': self.decoder.module.state_dict(),
                'discriminator_state': self.discriminator.module.state_dict(),
                'model_optimizer_state': self.model_optimizer.state_dict(),
                'dataset': self.args.rank,
                'd_optimizer_state': self.d_optimizer.state_dict()
            }, save_path)

        self.logger.debug(f'Saved model to {save_path}')
Beispiel #5
0
def train(model_directory, epochs, learning_rate, epochs_per_checkpoint,
          batch_size, seed):
    """Single-GPU WaveNet training loop.

    Resumes from the newest checkpoint in model_directory (if any), logs
    every iteration's loss to <model_directory>/loss_history.txt, and saves
    a checkpoint every epochs_per_checkpoint epochs.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    criterion = CrossEntropyLoss()
    model = WaveNet(**wavenet_config).cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Resume from the newest checkpoint in model_directory, if one exists.
    iteration = 0
    checkpoint_path = find_checkpoint(model_directory)
    if checkpoint_path is not None:
        model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
                                                      optimizer)
        iteration += 1  # next iteration is iteration + 1

    train_loader = DataLoader(SimpleWaveLoader(),
                              num_workers=1,
                              shuffle=False,
                              sampler=None,
                              batch_size=batch_size,
                              pin_memory=False,
                              drop_last=True)

    model.train()
    # Skip epochs already covered by the restored iteration counter.
    epoch_offset = max(0, int(iteration / len(train_loader)))
    # ================ MAIN TRAINING LOOP ===================
    for epoch in range(epoch_offset, epochs):
        print("Epoch: {}".format(epoch))
        for batch in train_loader:
            model.zero_grad()

            x, y = batch
            x = to_gpu(x).float()
            y = to_gpu(y)
            x = (x, y)  # auto-regressive takes outputs as inputs

            loss = criterion(model(x), y)
            reduced_loss = loss.data.item()
            loss.backward()
            optimizer.step()

            # Report the loss and append it to the history file.
            print("{}:\t{:.9f}".format(iteration, reduced_loss))
            with open(os.path.join(model_directory, 'loss_history.txt'),
                      'a') as f:
                f.write('%s\n' % str(reduced_loss))

            iteration += 1
            torch.cuda.empty_cache()

        # Checkpoint every epochs_per_checkpoint epochs (never at epoch 0).
        if (epoch != 0 and epoch % epochs_per_checkpoint == 0):
            checkpoint_path = os.path.join(model_directory,
                                           'checkpoint_%d' % iteration)
            save_checkpoint(model, optimizer, learning_rate, iteration,
                            checkpoint_path)
Beispiel #6
0
def main():
    """Experiment driver: copies CLI args onto the global cfg, then runs one
    of several modes — pattern search, mask/pattern visualization, accuracy
    testing, ONNX export — or falls through to sparse-aware training."""
    args = parse_args()
    # Mirror CLI arguments onto the global config object.
    cfg.resume      = args.resume
    cfg.exp_name    = args.exp
    cfg.work_root   = '/zhzhao/code/wavenet_torch/torch_lyuan/exp_result/'
    cfg.workdir     = cfg.work_root + args.exp + '/debug'
    cfg.sparse_mode = args.sparse_mode
    cfg.batch_size  = args.batch_size
    cfg.lr          = args.lr
    cfg.load_from   = args.load_from
    cfg.save_excel   = args.save_excel

    # Pattern-search mode: shape is "<h>_<w>", thresholds "<zero>_<score>".
    # Exit early when the pattern can never reach the score threshold.
    if args.find_pattern == True:
        cfg.find_pattern_num   = 16
        cfg.find_pattern_shape = [int(args.find_pattern_shape.split('_')[0]), int(args.find_pattern_shape.split('_')[1])]
        cfg.find_zero_threshold = float(args.find_pattern_para.split('_')[0])
        cfg.find_score_threshold = int(args.find_pattern_para.split('_')[1])
        if int(cfg.find_pattern_shape[0] * cfg.find_pattern_shape[1]) <= cfg.find_score_threshold:
            exit()

    # Skip experiments whose workdir already exists.
    if args.skip_exist == True:
        if os.path.exists(cfg.workdir):
            exit()

    print('initial training...')
    print(f'work_dir:{cfg.workdir}, \n\
            pretrained: {cfg.load_from},  \n\
            batch_size: {cfg.batch_size}, \n\
            lr        : {cfg.lr},         \n\
            epochs    : {cfg.epochs},     \n\
            sparse    : {cfg.sparse_mode}')
    writer = SummaryWriter(log_dir=cfg.workdir+'/runs')

    # build train data
    vctk_train = VCTK(cfg, 'train')
    train_loader = DataLoader(vctk_train, batch_size=cfg.batch_size, num_workers=4, shuffle=True, pin_memory=True)

    # train_loader = dataset.create("data/v28/train.record", cfg.batch_size, repeat=True)
    vctk_val = VCTK(cfg, 'val')
    # C-model accuracy testing runs with batch size 1.
    if args.test_acc_cmodel == True:
        val_loader = DataLoader(vctk_val, batch_size=1, num_workers=4, shuffle=False, pin_memory=True)
    else:
        val_loader = DataLoader(vctk_val, batch_size=cfg.batch_size, num_workers=4, shuffle=False, pin_memory=True)

    # build model
    model = WaveNet(num_classes=28, channels_in=40, dilations=[1,2,4,8,16])
    model = nn.DataParallel(model)
    model.cuda()

    # Re-initialize every non-BatchNorm, non-bias weight with Xavier normal.
    name_list = list()
    para_list = list()
    for name, para in model.named_parameters():
        name_list.append(name)
        para_list.append(para)

    a = model.state_dict()
    for i, name in enumerate(name_list):
        if name.split(".")[-2] != "bn" \
            and name.split(".")[-2] != "bn2" \
            and name.split(".")[-2] != "bn3" \
            and name.split(".")[-1] != "bias":
            raw_w = para_list[i]
            nn.init.xavier_normal_(raw_w, gain=1.0)
            a[name] = raw_w
    model.load_state_dict(a)

    # Create output directories for weights and visualizations.
    weights_dir = os.path.join(cfg.workdir, 'weights')
    if not os.path.exists(weights_dir):
        os.mkdir(weights_dir)
    if not os.path.exists(cfg.vis_dir):
        os.mkdir(cfg.vis_dir)
    if args.vis_pattern == True or args.vis_mask == True:
        cfg.vis_dir = os.path.join(cfg.vis_dir, cfg.exp_name)
        if not os.path.exists(cfg.vis_dir):
            os.mkdir(cfg.vis_dir)
    model.train()

    # Resume from this experiment's best checkpoint, if requested.
    if cfg.resume and os.path.exists(cfg.workdir + '/weights/best.pth'):
        model.load_state_dict(torch.load(cfg.workdir + '/weights/best.pth'), strict=True)
        print("loading", cfg.workdir + '/weights/best.pth')
        cfg.load_from = cfg.workdir + '/weights/best.pth'

    # test_acc requires a pretrained model; otherwise, loading a pretrained
    # model also triggers a one-off ONNX export of it.
    if args.test_acc == True:
        if os.path.exists(cfg.load_from):
            model.load_state_dict(torch.load(cfg.load_from), strict=True)
            print("loading", cfg.load_from)
        else:
            print("Error: model file not exists, ", cfg.load_from)
            exit()
    else:
        if os.path.exists(cfg.load_from):
            model.load_state_dict(torch.load(cfg.load_from), strict=True)
            print("loading", cfg.load_from)
            # Export the model
            print("exporting onnx ...")
            model.eval()
            batch_size = 1
            x = torch.randn(batch_size, 40, 720, requires_grad=True).cuda()
            torch.onnx.export(model.module,               # model being run
                            x,                         # model input (or a tuple for multiple inputs)
                            "wavenet.onnx",   # where to save the model (can be a file or file-like object)
                            export_params=True,        # store the trained parameter weights inside the model file
                            opset_version=10,          # the ONNX version to export the model to
                            do_constant_folding=True,  # whether to execute constant folding for optimization
                            input_names = ['input'],   # the model's input names
                            output_names = ['output'], # the model's output names
                            dynamic_axes={'input' : {0 : 'batch_size'},    # variable lenght axes
                                            'output' : {0 : 'batch_size'}})

    # Optionally merge weights stored in an HDF5 file into the model.
    if os.path.exists(args.load_from_h5):
        # model.load_state_dict(torch.load(args.load_from_h5), strict=True)
        print("loading", args.load_from_h5)
        model.train()
        model_dict = model.state_dict()
        print(model_dict.keys())
        # first convert the numpy parameter values to tensors
        pretrained_dict = dd.io.load(args.load_from_h5)
        print(pretrained_dict.keys())
        new_pre_dict = {}
        for k,v in pretrained_dict.items():
            new_pre_dict[k] = torch.Tensor(v)
        # update the state dict with the converted pretrained values
        model_dict.update(new_pre_dict)
        # load the merged state dict back into the model
        model.load_state_dict(model_dict)

    # Pattern-search mode: analyze 128x128 weight matrices for recurring
    # sparsity patterns, write the results to an Excel sheet, then exit.
    if args.find_pattern == True:

        # cfg.find_pattern_num   = 16
        # cfg.find_pattern_shape = [int(args.find_pattern_shape.split('_')[0]), int(args.find_pattern_shape.split('_')[1])]
        # cfg.find_zero_threshold = float(args.find_pattern_para.split('_')[0])
        # cfg.find_score_threshold = int(args.find_pattern_para.split('_')[1])

        # if cfg.find_pattern_shape[0] * cfg.find_pattern_shape[0] <= cfg.find_score_threshold:
        #     exit()

        name_list = list()
        para_list = list()
        for name, para in model.named_parameters():
            name_list.append(name)
            para_list.append(para)

        a = model.state_dict()
        for i, name in enumerate(name_list):
            if name.split(".")[-2] != "bn" \
                and name.split(".")[-2] != "bn2" \
                and name.split(".")[-2] != "bn3" \
                and name.split(".")[-1] != "bias":
                raw_w = para_list[i]
                if raw_w.size(0) == 128 and raw_w.size(1) == 128:
                    patterns, pattern_match_num, pattern_coo_nnz, pattern_nnz, pattern_inner_nnz \
                                    = find_pattern_by_similarity(raw_w
                                        , cfg.find_pattern_num
                                        , cfg.find_pattern_shape
                                        , cfg.find_zero_threshold
                                        , cfg.find_score_threshold)

                    pattern_num_memory_dict, pattern_num_cal_num_dict, pattern_num_coo_nnz_dict \
                                    = pattern_curve_analyse(raw_w.shape
                                        , cfg.find_pattern_shape
                                        , patterns
                                        , pattern_match_num
                                        , pattern_coo_nnz
                                        , pattern_nnz
                                        , pattern_inner_nnz)

                    write_pattern_curve_analyse(os.path.join(cfg.work_root, args.save_pattern_count_excel)
                                        , cfg.exp_name + " " + args.find_pattern_shape + " " + args.find_pattern_para
                                        , patterns, pattern_match_num, pattern_coo_nnz, pattern_nnz
                                        , pattern_inner_nnz
                                        , pattern_num_memory_dict, pattern_num_cal_num_dict, pattern_num_coo_nnz_dict)

                    # write_pattern_count(os.path.join(cfg.work_root, args.save_pattern_count_excel)
                    #                     , cfg.exp_name + " " + args.find_pattern_shape +" " + args.find_pattern_para
                    #                     , all_nnzs.values(), all_patterns.values())
                    exit()

    # Configure the selected pruning scheme; each branch parses its own
    # underscore-separated parameter string from the CLI.
    if cfg.sparse_mode == 'sparse_pruning':
        cfg.sparsity = args.sparsity
        print(f'sparse_pruning {cfg.sparsity}')

    elif cfg.sparse_mode == 'pattern_pruning':
        # pattern_para = "<num>_<h>_<w>_<nnz>"
        print(args.pattern_para)
        pattern_num   = int(args.pattern_para.split('_')[0])
        pattern_shape = [int(args.pattern_para.split('_')[1]), int(args.pattern_para.split('_')[2])]
        pattern_nnz   = int(args.pattern_para.split('_')[3])
        print(f'pattern_pruning {pattern_num} [{pattern_shape[0]}, {pattern_shape[1]}] {pattern_nnz}')
        cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)

    elif cfg.sparse_mode == 'coo_pruning':
        # coo_para = "<h>_<w>_<nnz>"
        cfg.coo_shape   = [int(args.coo_para.split('_')[0]), int(args.coo_para.split('_')[1])]
        cfg.coo_nnz   = int(args.coo_para.split('_')[2])
        # cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        print(f'coo_pruning [{cfg.coo_shape[0]}, {cfg.coo_shape[1]}] {cfg.coo_nnz}')

    elif cfg.sparse_mode == 'ptcoo_pruning':
        # ptcoo_para = "<num>_<h>_<w>_<pt_nnz>_<coo_nnz>"
        cfg.pattern_num   = int(args.pattern_para.split('_')[0])
        cfg.pattern_shape = [int(args.ptcoo_para.split('_')[1]), int(args.ptcoo_para.split('_')[2])]
        cfg.pt_nnz   = int(args.ptcoo_para.split('_')[3])
        cfg.coo_nnz   = int(args.ptcoo_para.split('_')[4])
        cfg.patterns = generate_pattern(cfg.pattern_num, cfg.pattern_shape, cfg.pt_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)
        print(f'ptcoo_pruning {cfg.pattern_num} [{cfg.pattern_shape[0]}, {cfg.pattern_shape[1]}] {cfg.pt_nnz} {cfg.coo_nnz}')

    elif cfg.sparse_mode == 'find_retrain':
        # find_retrain_para = "<num>_<h>_<w>_<nnz>_<coo>_<layer|model>"
        cfg.pattern_num   = int(args.find_retrain_para.split('_')[0])
        cfg.pattern_shape = [int(args.find_retrain_para.split('_')[1]), int(args.find_retrain_para.split('_')[2])]
        cfg.pattern_nnz   = int(args.find_retrain_para.split('_')[3])
        cfg.coo_num       = float(args.find_retrain_para.split('_')[4])
        cfg.layer_or_model_wise   = str(args.find_retrain_para.split('_')[5])
        # cfg.fd_rtn_pattern_candidates = generate_complete_pattern_set(
        #                                 cfg.pattern_shape, cfg.pattern_nnz)
        print(f'find_retrain {cfg.pattern_num} [{cfg.pattern_shape[0]}, {cfg.pattern_shape[1]}] {cfg.pattern_nnz} {cfg.coo_num} {cfg.layer_or_model_wise}')

    elif cfg.sparse_mode == 'hcgs_pruning':
        # hcgs_para = "<h>_<w>_<reserve1>_<reserve2>"
        print(args.pattern_para)
        cfg.block_shape = [int(args.hcgs_para.split('_')[0]), int(args.hcgs_para.split('_')[1])]
        cfg.reserve_num1 = int(args.hcgs_para.split('_')[2])
        cfg.reserve_num2 = int(args.hcgs_para.split('_')[3])
        print(f'hcgs_pruning {cfg.reserve_num1}/8 {cfg.reserve_num2}/16')
        cfg.hcgs_mask = generate_hcgs_mask(model, cfg.block_shape, cfg.reserve_num1, cfg.reserve_num2)

    # Visualization mode: save a 0/1 mask image per prunable weight, then exit.
    if args.vis_mask == True:
        name_list = list()
        para_list = list()
        for name, para in model.named_parameters():
            name_list.append(name)
            para_list.append(para)

        for i, name in enumerate(name_list):
            if name.split(".")[-2] != "bn" \
                and name.split(".")[-2] != "bn2" \
                and name.split(".")[-2] != "bn3" \
                and name.split(".")[-1] != "bias":
                raw_w = para_list[i]

                zero = torch.zeros_like(raw_w)
                one = torch.ones_like(raw_w)

                # 1 where the weight is nonzero, 0 where it was pruned.
                mask = torch.where(raw_w == 0, zero, one)
                vis.save_visualized_mask(mask, name)
        exit()

    # Visualization mode: count and render the 8x8 patterns in the model.
    if args.vis_pattern == True:
        pattern_count_dict = find_pattern_model(model, [8,8])
        patterns = list(pattern_count_dict.keys())
        counts = list(pattern_count_dict.values())
        print(len(patterns))
        print(counts)
        vis.save_visualized_pattern(patterns)
        exit()
    # build loss
    loss_fn = nn.CTCLoss(blank=27)
    # loss_fn = nn.CTCLoss()

    # NOTE(review): named "scheduler" but this is the Adam optimizer.
    scheduler = optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)
    # scheduler = optim.lr_scheduler.MultiStepLR(train_step, milestones=[50, 150, 250], gamma=0.5)
    # Accuracy-test modes: evaluate, write results to Excel, then exit.
    if args.test_acc == True:
        f1, val_loss, tps, preds, poses = test_acc(val_loader, model, loss_fn)
        # f1, val_loss, tps, preds, poses = test_acc(val_loader, model, loss_fn)
        write_test_acc(os.path.join(cfg.work_root, args.test_acc_excel),
                        cfg.exp_name, f1, val_loss, tps, preds, poses)
        exit()

    if args.test_acc_cmodel == True:
        f1, val_loss, tps, preds, poses = test_acc_cmodel(val_loader, model, loss_fn)
        # f1, val_loss, tps, preds, poses = test_acc(val_loader, model, loss_fn)
        write_test_acc(os.path.join(cfg.work_root, args.test_acc_excel),
                        cfg.exp_name, f1, val_loss, tps, preds, poses)
        exit()
    # train
    train(train_loader, scheduler, model, loss_fn, val_loader, writer)
Beispiel #7
0
def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate,
          iters_per_checkpoint, iters_per_eval, batch_size, seed, checkpoint_path, log_dir, ema_decay=0.9999):
    """Train a WaveNet model with EMA-tracked weights and periodic audio eval.

    Args:
        num_gpus: number of GPUs; >1 enables the distributed code paths.
        rank: rank of this process (rank 0 does logging/checkpointing).
        group_name: distributed process-group name.
        output_directory: where checkpoints are written.
        epochs: total number of epochs to run.
        learning_rate: Adam learning rate.
        iters_per_checkpoint: checkpoint every N iterations (rank 0 only).
        iters_per_eval: run audio validation every N iterations (rank 0 only).
        batch_size: training batch size.
        seed: manual seed for CPU and CUDA RNGs.
        checkpoint_path: checkpoint to resume from ("" = start fresh).
        log_dir: TensorBoard log directory.
        ema_decay: decay of the exponential moving average of the weights.

    NOTE(review): relies on module-level globals — `dist_config`,
    `wavenet_config`, `train_data_config`, `valid_data_config`,
    `audio_config`, `config`, `low_memory` — presumably parsed from a
    config file elsewhere; confirm against the full module.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    #=====END:   ADDED FOR DISTRIBUTED======

    # Variable-length (unchunked) sequences need a mask-aware loss.
    if train_data_config["no_chunks"]:
        criterion = MaskedCrossEntropyLoss()
    else:
        criterion = CrossEntropyLoss()
    model = WaveNet(**wavenet_config).cuda()
    # Shadow copy of all trainable weights, updated after each optimizer step.
    ema = ExponentialMovingAverage(ema_decay)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)

    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)
    #=====END:   ADDED FOR DISTRIBUTED======

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # Halve the LR every 200k scheduler steps (stepped once per iteration below).
    scheduler = StepLR(optimizer, step_size=200000, gamma=0.5)

    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, scheduler, iteration, ema = load_checkpoint(checkpoint_path, model,
                                                                      optimizer, scheduler, ema)
        iteration += 1  # next iteration is iteration + 1

    trainset = Mel2SampOnehot(audio_config=audio_config, verbose=True, **train_data_config)
    validset = Mel2SampOnehot(audio_config=audio_config, verbose=False, **valid_data_config)
    # =====START: ADDED FOR DISTRIBUTED======
    train_sampler = DistributedSampler(trainset) if num_gpus > 1 else None
    valid_sampler = DistributedSampler(validset) if num_gpus > 1 else None
    # =====END:   ADDED FOR DISTRIBUTED======
    print(train_data_config)
    # Unchunked batches need the custom padding collate; fixed-size chunks
    # can use the default collate.
    if train_data_config["no_chunks"]:
        collate_fn = utils.collate_fn
    else:
        collate_fn = torch.utils.data.dataloader.default_collate
    train_loader = DataLoader(trainset, num_workers=1, shuffle=False,
                              collate_fn=collate_fn,
                              sampler=train_sampler,
                              batch_size=batch_size,
                              pin_memory=True,
                              drop_last=True)
    valid_loader = DataLoader(validset, num_workers=1, shuffle=False,
                              sampler=valid_sampler, batch_size=1, pin_memory=True)
    # Get shared output_directory ready
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)

    model.train()
    # Resume mid-training at the epoch the saved iteration count implies.
    epoch_offset = max(0, int(iteration / len(train_loader)))
    writer = SummaryWriter(log_dir)
    print("Checkpoints writing to: {}".format(log_dir))
    # ================ MAIN TRAINNIG LOOP! ===================
    for epoch in range(epoch_offset, epochs):
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            if low_memory:
                torch.cuda.empty_cache()
            # NOTE(review): scheduler.step() before optimizer.step() is the
            # legacy (pre-1.1) PyTorch ordering; modern PyTorch warns here.
            scheduler.step()
            model.zero_grad()

            if train_data_config["no_chunks"]:
                x, y, seq_lens = batch
                seq_lens = to_gpu(seq_lens)
            else:
                x, y = batch
            x = to_gpu(x).float()
            y = to_gpu(y)
            x = (x, y)  # auto-regressive takes outputs as inputs
            y_pred = model(x)
            if train_data_config["no_chunks"]:
                loss = criterion(y_pred, y, seq_lens)
            else:
                loss = criterion(y_pred, y)
            if num_gpus > 1:
                reduced_loss = reduce_tensor(loss.data, num_gpus)[0]
            else:
                # NOTE(review): `loss.data[0]` is legacy PyTorch (<0.4)
                # 0-dim indexing; on current versions this raises — the
                # modern equivalent is loss.item().
                reduced_loss = loss.data[0]
            loss.backward()
            optimizer.step()

            # Fold the freshly updated weights into the EMA shadow copy.
            for name, param in model.named_parameters():
                if name in ema.shadow:
                    ema.update(name, param.data)

            print("{}:\t{:.9f}".format(iteration, reduced_loss))
            if rank == 0:
                writer.add_scalar('loss', reduced_loss, iteration)
            # Skip iteration 0 for both checkpointing and eval.
            if (iteration % iters_per_checkpoint == 0 and iteration):
                if rank == 0:
                    checkpoint_path = "{}/wavenet_{}".format(
                        output_directory, iteration)
                    save_checkpoint(model, optimizer, scheduler, learning_rate, iteration,
                                    checkpoint_path, ema, wavenet_config)
            if (iteration % iters_per_eval == 0 and iteration > 0 and not config["no_validation"]):
                if low_memory:
                    torch.cuda.empty_cache()
                if rank == 0:
                    # Export current weights into the fast nv-wavenet inference
                    # engine and log generated vs. ground-truth audio.
                    model_eval = nv_wavenet.NVWaveNet(**(model.export_weights()))
                    for j, valid_batch in enumerate(valid_loader):
                        mel, audio = valid_batch
                        mel = to_gpu(mel).float()
                        cond_input = model.get_cond_input(mel)
                        predicted_audio = model_eval.infer(cond_input, nv_wavenet.Impl.AUTO)
                        # 256-level mu-law decode back to a waveform; 22050 Hz
                        # sample rate for TensorBoard playback.
                        predicted_audio = utils.mu_law_decode_numpy(predicted_audio[0, :].cpu().numpy(), 256)
                        writer.add_audio("valid/predicted_audio_{}".format(j),
                                         predicted_audio,
                                         iteration,
                                         22050)
                        audio = utils.mu_law_decode_numpy(audio[0, :].cpu().numpy(), 256)
                        writer.add_audio("valid_true/audio_{}".format(j),
                                         audio,
                                         iteration,
                                         22050)
                        if low_memory:
                            torch.cuda.empty_cache()
            iteration += 1
Beispiel #8
0
class AutoencoderTrainer:
    """Train the autoencoder for the first step of training.

    An Encoder / WaveNet-decoder pair learns to reconstruct audio from
    several domains while a ZDiscriminator is trained adversarially to
    predict which domain a latent code came from; the autoencoder is
    penalised (weight ``args.d_weight``) when the discriminator is right.

    Fixes over the previous revision (all were AttributeError/logic bugs):
      * resume loaded ``args.checkpoint_state_file`` instead of the local
        checkpoint path built from the last epoch;
      * ``self.model_optimizer`` never existed — the autoencoder optimizer
        is ``self.autoenc_optimizer``;
      * ``self.loss_total`` / ``self.evals_recon`` / ``self.val_total``
        never existed (``total_loss`` / ``reconstruction_val`` /
        ``total_val``);
      * validation moved ``x`` twice and discarded ``x_aug``;
      * ``save_model`` assumed a ``.module`` wrapper and crashed when
        ``world_size == 1``.
    """

    def __init__(self, args):
        self.args = args
        # One Dataset per domain directory passed on the command line.
        self.data = [Dataset(args, domain_path) for domain_path in args.data]
        self.expPath = args.checkpoint / 'Autoencoder' / args.exp_name
        if not self.expPath.exists():
            self.expPath.mkdir(parents=True, exist_ok=True)
        self.logger = train_logger(self.args, self.expPath)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # seed
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        # modules
        self.encoder = Encoder(args).to(self.device)
        self.decoder = WaveNet(args).to(self.device)
        self.discriminator = ZDiscriminator(args).to(self.device)

        # distributed: wrap every module in DDP on the current device
        if args.world_size > 1:
            self.encoder = torch.nn.parallel.DistributedDataParallel(
                self.encoder,
                device_ids=[torch.cuda.current_device()],
                output_device=torch.cuda.current_device())
            self.discriminator = torch.nn.parallel.DistributedDataParallel(
                self.discriminator,
                device_ids=[torch.cuda.current_device()],
                output_device=torch.cuda.current_device())
            self.decoder = torch.nn.parallel.DistributedDataParallel(
                self.decoder,
                device_ids=[torch.cuda.current_device()],
                output_device=torch.cuda.current_device())

        # training losses: one reconstruction tracker per domain, plus
        # aggregate discriminator / total trackers
        self.reconstruction_loss = [
            LossManager(f'train reconstruction {i}')
            for i in range(len(self.data))
        ]
        self.discriminator_loss = LossManager('train discriminator')
        self.total_loss = LossManager('train total')

        # validation counterparts
        self.reconstruction_val = [
            LossManager(f'validation reconstruction {i}')
            for i in range(len(self.data))
        ]
        self.discriminator_val = LossManager('validation discriminator')
        self.total_val = LossManager('validation total')

        # optimizers: one joint optimizer for encoder+decoder, one for the
        # adversary
        self.autoenc_optimizer = optim.Adam(chain(self.encoder.parameters(),
                                                  self.decoder.parameters()),
                                            lr=args.lr)
        self.discriminator_optimizer = optim.Adam(
            self.discriminator.parameters(), lr=args.lr)

        # resume training
        if args.resume:
            checkpoint_args_file = self.expPath / 'args.pth'
            checkpoint_args = torch.load(checkpoint_args_file)

            last_epoch = checkpoint_args[-1]
            self.start_epoch = last_epoch + 1

            checkpoint_state_file = self.expPath / f'lastmodel_{last_epoch}.pth'
            # Bug fix: previously loaded torch.load(args.checkpoint_state_file);
            # the path is the local variable built just above.
            states = torch.load(checkpoint_state_file)

            self.encoder.load_state_dict(states['encoder_state'])
            self.decoder.load_state_dict(states['decoder_state'])
            self.discriminator.load_state_dict(states['discriminator_state'])

            if args.load_optimizer:
                self.autoenc_optimizer.load_state_dict(
                    states['autoenc_optimizer_state'])
                self.discriminator_optimizer.load_state_dict(
                    states['discriminator_optimizer_state'])
            self.logger.info('Loaded checkpoint parameters')
        else:
            self.start_epoch = 0

        # learning-rate decay on the autoencoder optimizer.
        # Bug fix: self.model_optimizer never existed.
        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
            self.autoenc_optimizer, args.lr_decay)
        self.lr_manager.last_epoch = self.start_epoch
        self.lr_manager.step()

    @staticmethod
    def _unwrap(module):
        """Return the raw module behind a DistributedDataParallel wrapper.

        Lets save_model work both with and without distributed training.
        """
        return module.module if hasattr(module, 'module') else module

    def train_epoch(self, epoch):
        """Run one training epoch, alternating discriminator and autoencoder steps."""
        # modules
        self.encoder.train()
        self.decoder.train()
        self.discriminator.train()

        # losses
        for lm in self.reconstruction_loss:
            lm.reset()
        self.discriminator_loss.reset()
        self.total_loss.reset()

        total_batches = self.args.epoch_length // self.args.batch_size

        with tqdm(total=total_batches,
                  desc=f'Train epoch {epoch}') as train_enum:
            for batch_num in range(total_batches):
                # Distributed: each rank owns one dataset; otherwise
                # round-robin over all datasets.
                if self.args.world_size > 1:
                    dataset_no = self.args.rank
                else:
                    dataset_no = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dataset_no].train_iter)

                x = x.to(self.device).float()
                x_aug = x_aug.to(self.device).float()

                # Domain label, replicated across the batch (target for the
                # adversarial cross-entropy in both phases).
                domain = torch.tensor([dataset_no] * x.size(0)).long().cuda()

                # --- Train discriminator: classify the domain of z ---
                z = self.encoder(x)
                z_logits = self.discriminator(z)
                discriminator_loss = F.cross_entropy(z_logits, domain).mean()
                loss = discriminator_loss * self.args.d_weight
                self.discriminator_optimizer.zero_grad()
                loss.backward()
                if self.args.grad_clip is not None:
                    clip_grad_value_(self.discriminator.parameters(),
                                     self.args.grad_clip)
                self.discriminator_optimizer.step()

                # --- Train autoencoder: reconstruct x, fool discriminator ---
                z = self.encoder(x_aug)
                y = self.decoder(x, z)
                z_logits = self.discriminator(z)
                # Negated: the autoencoder profits when the adversary is wrong.
                discriminator_loss = -F.cross_entropy(z_logits, domain).mean()
                reconstruction_loss = cross_entropy_loss(y, x)
                self.reconstruction_loss[dataset_no].add(
                    reconstruction_loss.data.cpu().numpy().mean())
                loss = (reconstruction_loss.mean() +
                        self.args.d_weight * discriminator_loss)
                # Bug fix: was self.model_optimizer (undefined attribute).
                self.autoenc_optimizer.zero_grad()
                loss.backward()

                if self.args.grad_clip is not None:
                    clip_grad_value_(self.encoder.parameters(),
                                     self.args.grad_clip)
                    clip_grad_value_(self.decoder.parameters(),
                                     self.args.grad_clip)
                self.autoenc_optimizer.step()

                # Bug fix: was self.loss_total (undefined attribute).
                self.total_loss.add(loss.data.item())

                train_enum.set_description(
                    f'Train (loss: {loss.data.item():.2f}) epoch {epoch}')
                train_enum.update()

    def validate_epoch(self, epoch):
        """Run one validation epoch without gradients (1/10 of an epoch's batches)."""
        # modules
        self.encoder.eval()
        self.decoder.eval()
        self.discriminator.eval()

        # losses
        for lm in self.reconstruction_val:
            lm.reset()
        self.discriminator_val.reset()
        self.total_val.reset()

        total_batches = self.args.epoch_length // self.args.batch_size // 10

        with tqdm(total=total_batches) as valid_enum, torch.no_grad():
            for batch_num in range(total_batches):
                if self.args.world_size > 1:
                    dataset_no = self.args.rank
                else:
                    dataset_no = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dataset_no].valid_iter)

                x = x.to(self.device).float()
                # Bug fix: was `x_aug = x.to(self.device)`, which moved x
                # twice and discarded the augmented sample.
                x_aug = x_aug.to(self.device).float()

                z = self.encoder(x)
                y = self.decoder(x, z)
                z_logits = self.discriminator(z)
                # Discriminator accuracy on the latent codes.
                z_classification = torch.max(z_logits, dim=1)[1]
                z_accuracy = (z_classification == dataset_no).float().mean()

                self.discriminator_val.add(z_accuracy.data.item())

                discriminator_right = F.cross_entropy(
                    z_logits,
                    torch.tensor([dataset_no] *
                                 x.size(0)).long().cuda()).mean()
                recon_loss = cross_entropy_loss(y, x)

                # Bug fix: was self.evals_recon (undefined attribute).
                self.reconstruction_val[dataset_no].add(
                    recon_loss.data.cpu().numpy().mean())

                # NOTE(review): validation weights the adversarial term with
                # args.d_lambda while training uses args.d_weight — confirm
                # both exist and are intended to differ.
                total_loss = discriminator_right.data.item(
                ) * self.args.d_lambda + recon_loss.mean().data.item()

                self.total_val.add(total_loss)

                valid_enum.set_description(
                    f'Test (loss: {total_loss:.2f}) epoch {epoch}')
                valid_enum.update()

    def train(self):
        """Full loop: train+validate each epoch, track best model, decay LR."""
        best_loss = float('inf')
        for epoch in range(self.start_epoch, self.args.epochs):
            self.logger.info(
                f'Starting epoch, Rank {self.args.rank}, Dataset: {self.args.data[self.args.rank]}'
            )
            self.train_epoch(epoch)
            self.validate_epoch(epoch)

            train_losses = [self.reconstruction_loss, self.discriminator_loss]
            val_losses = [self.reconstruction_val, self.discriminator_val]

            self.logger.info(
                f'Epoch %s Rank {self.args.rank} - Train loss: (%s), Validation loss (%s)',
                epoch, train_losses, val_losses)

            # Bug fix: was self.val_total (undefined attribute).
            mean_loss = self.total_val.epoch_mean()
            if mean_loss < best_loss:
                self.save_model(f'bestmodel_{self.args.rank}.pth')
                best_loss = mean_loss

            if self.args.save_model:
                self.save_model(f'lastmodel_{epoch}_rank_{self.args.rank}.pth')
            else:
                self.save_model(f'lastmodel_{self.args.rank}.pth')

            self.lr_manager.step()
            self.logger.debug('Ended epoch')

    def save_model(self, filename):
        """Save module and optimizer state dicts to `self.expPath / filename`."""
        save_to = self.expPath / filename
        torch.save(
            {
                # _unwrap: works whether or not the modules are DDP-wrapped.
                'encoder_state': self._unwrap(self.encoder).state_dict(),
                'decoder_state': self._unwrap(self.decoder).state_dict(),
                'discriminator_state': self._unwrap(self.discriminator).state_dict(),
                'autoenc_optimizer_state': self.autoenc_optimizer.state_dict(),
                'd_optimizer_state': self.discriminator_optimizer.state_dict(),
                'dataset': self.args.rank,
            }, save_to)

        self.logger.debug(f'Saved model to {save_to}')
Beispiel #9
0
def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate,
          iters_per_checkpoint, batch_size, seed, checkpoint_path):
    """Train a WaveNet on a pre-built npz dataset, checkpointing periodically.

    Args:
        num_gpus: number of GPUs; >1 enables distributed init/allreduce.
        rank: process rank (rank 0 creates the output directory).
        group_name: distributed process-group name.
        output_directory: where checkpoints are written.
        epochs: total number of epochs.
        learning_rate: Adam learning rate.
        iters_per_checkpoint: checkpoint + audio dump every N iterations.
        batch_size: training batch size.
        seed: manual seed for CPU and CUDA RNGs.
        checkpoint_path: checkpoint to resume from ("" = start fresh).

    NOTE(review): relies on module-level globals — `dist_config`,
    `wavenet_config`, `data_config`, and a `writer` (SummaryWriter) —
    defined elsewhere in the module; confirm.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    #=====END:   ADDED FOR DISTRIBUTED======

    criterion = CrossEntropyLoss()
    # NOTE(review): the model is placed on CPU here while the batches go
    # through to_gpu() below — confirm to_gpu is a no-op in the CPU-only
    # setup, otherwise this is a device mismatch.
    model = WaveNet(**wavenet_config).cpu()

    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)
    #=====END:   ADDED FOR DISTRIBUTED======

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
                                                      optimizer)
        iteration += 1  # next iteration is iteration + 1

    print(f"receptive_field: {model.receptive_field()}")
    # Each item carries the receptive field plus 1000 samples of context
    # ahead of the target window.
    trainset = WavenetDataset(
        dataset_file='data/dataset.npz',
        item_length=model.receptive_field() + 1000 + model.output_length - 1,
        target_length=model.output_length,
        file_location='data/',
        test_stride=500,
    )
    print(trainset._length)
    print('the dataset has ' + str(len(trainset)) + ' items')
    train_loader = DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=False,
    )

    # Get shared output_directory ready
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)

    model.train()
    # Resume mid-training at the epoch the saved iteration count implies.
    epoch_offset = max(0, int(iteration / len(train_loader)))
    # ================ MAIN TRAINNIG LOOP! ===================
    start = time.time()
    for epoch in range(epoch_offset, epochs):
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            model.zero_grad()
            y, target = batch
            y = to_gpu(y).float()
            target = to_gpu(target)
            # No teacher-forcing input: the model consumes its own signal.
            y_pred = model((None, y))
            # Score only the last output_length timesteps against the target.
            loss = criterion(y_pred[:, :, -model.output_length:], target)
            loss.backward()
            optimizer.step()

            print("{}:\t{:.9f}".format(iteration, loss))
            # Estimated time remaining over the whole run.
            print_etr(start,
                      total_iterations=(epochs - epoch_offset) *
                      len(train_loader),
                      current_iteration=epoch * len(train_loader) + i + 1)
            writer.add_scalar('Loss/train', loss, global_step=iteration)

            if (iteration % iters_per_checkpoint == 0):
                # Sample audio from the predicted distribution of the first
                # batch element and dump input/output wavs for inspection.
                y_choice = y_pred[0].detach().cpu().transpose(0, 1)
                y_prob = F.softmax(y_choice, dim=1)
                y_prob_collapsed = torch.multinomial(y_prob,
                                                     num_samples=1).squeeze(1)
                y_pred_audio = mu_law_decode_numpy(y_prob_collapsed.numpy(),
                                                   model.n_out_channels)
                # NOTE(review): import inside the loop — presumably to keep
                # torchaudio optional; consider hoisting to module level.
                import torchaudio
                y_audio = mu_law_decode_numpy(y.numpy(), model.n_out_channels)
                torchaudio.save("test_in.wav", torch.tensor(y_audio), 16000)
                torchaudio.save("test_out.wav", torch.tensor(y_pred_audio),
                                16000)
                writer.add_audio('Audio',
                                 y_pred_audio,
                                 global_step=iteration,
                                 sample_rate=data_config['sampling_rate'])
                checkpoint_path = "{}/wavenet_{}".format(
                    output_directory, iteration)
                save_checkpoint(model, optimizer, learning_rate, iteration,
                                checkpoint_path)

            writer.flush()
            iteration += 1
Beispiel #10
0
def main():
    """Entry point: configure from CLI flags, build data/model, run one of
    several utility modes (pattern finding, mask/pattern visualization,
    accuracy testing) or fall through to training.

    Command-line arguments are copied onto the module-level ``cfg`` object,
    which the rest of the project reads.
    """
    args = parse_args()
    cfg.resume      = args.resume
    cfg.exp_name    = args.exp
    cfg.work_root   = '/zhzhao/code/wavenet_torch/torch_lyuan/exp_result/'
    cfg.workdir     = cfg.work_root + args.exp + '/debug'
    cfg.sparse_mode = args.sparse_mode
    cfg.batch_size  = args.batch_size
    cfg.lr          = args.lr
    cfg.load_from   = args.load_from
    cfg.save_excel   = args.save_excel

    # Allow batch jobs to skip experiments whose workdir already exists.
    if args.skip_exist == True:
        if os.path.exists(cfg.workdir):
            exit()

    print('initial training...')
    print(f'work_dir:{cfg.workdir}, \n\
            pretrained: {cfg.load_from},  \n\
            batch_size: {cfg.batch_size}, \n\
            lr        : {cfg.lr},         \n\
            epochs    : {cfg.epochs},     \n\
            sparse    : {cfg.sparse_mode}')
    writer = SummaryWriter(log_dir=cfg.workdir+'/runs')

    # build train data
    vctk_train = VCTK(cfg, 'train')
    train_loader = DataLoader(vctk_train, batch_size=cfg.batch_size, num_workers=8, shuffle=True, pin_memory=True)

    # train_loader = dataset.create("data/v28/train.record", cfg.batch_size, repeat=True)
    vctk_val = VCTK(cfg, 'val')
    val_loader = DataLoader(vctk_val, batch_size=cfg.batch_size, num_workers=8, shuffle=False, pin_memory=True)

    # build model
    model = WaveNet(num_classes=28, channels_in=20, dilations=[1,2,4,8,16])
    model = nn.DataParallel(model)
    model.cuda()


    # Snapshot parameter names/tensors so weights can be re-initialised below.
    name_list = list()
    para_list = list()
    for name, para in model.named_parameters():
        name_list.append(name)
        para_list.append(para)

    # Xavier-initialise every weight except batch-norm parameters and biases,
    # then load the modified state dict back into the model.
    a = model.state_dict()
    for i, name in enumerate(name_list):
        if name.split(".")[-2] != "bn" and name.split(".")[-1] != "bias":
            raw_w = para_list[i]
            nn.init.xavier_normal_(raw_w, gain=1.0)
            a[name] = raw_w
    model.load_state_dict(a)


    # NOTE(review): os.mkdir (not makedirs) — fails if parent dirs are missing.
    weights_dir = os.path.join(cfg.workdir, 'weights')
    if not os.path.exists(weights_dir):
        os.mkdir(weights_dir)
    if not os.path.exists(cfg.vis_dir):
        os.mkdir(cfg.vis_dir)
    if args.vis_pattern == True or args.vis_mask == True:
        cfg.vis_dir = os.path.join(cfg.vis_dir, cfg.exp_name)
        if not os.path.exists(cfg.vis_dir):
            os.mkdir(cfg.vis_dir)
    model.train()

    # Resume from the best checkpoint of this experiment, if present.
    if cfg.resume and os.path.exists(cfg.workdir + '/weights/best.pth'):
        model.load_state_dict(torch.load(cfg.workdir + '/weights/best.pth'), strict=False)
        print("loading", cfg.workdir + '/weights/best.pth')

    if os.path.exists(cfg.load_from):
        model.load_state_dict(torch.load(cfg.load_from), strict=False)
        print("loading", cfg.load_from)

    # Optionally load pretrained weights stored as an HDF5 file (via deepdish).
    if os.path.exists(args.load_from_h5):
        # model.load_state_dict(torch.load(args.load_from_h5), strict=True)
        print("loading", args.load_from_h5)
        model.train()
        model_dict = model.state_dict()
        print(model_dict.keys())
        # First convert the numpy parameter values into tensors
        pretrained_dict = dd.io.load(args.load_from_h5)
        print(pretrained_dict.keys())
        new_pre_dict = {}
        for k,v in pretrained_dict.items():
            new_pre_dict[k] = torch.Tensor(v)
        # update
        model_dict.update(new_pre_dict)
        # load
        model.load_state_dict(model_dict)

    # Utility mode: search 128x128 weight matrices for recurring sparsity
    # patterns, dump an analysis spreadsheet, then exit.
    if args.find_pattern == True:

        cfg.find_pattern_num   = 16
        cfg.find_pattern_shape = [int(args.find_pattern_shape.split('_')[0]), int(args.find_pattern_shape.split('_')[1])]
        cfg.find_zero_threshold = float(args.find_pattern_para.split('_')[0])
        cfg.find_score_threshold = int(args.find_pattern_para.split('_')[1])

        name_list = list()
        para_list = list()
        for name, para in model.named_parameters():
            name_list.append(name)
            para_list.append(para)

        a = model.state_dict()
        for i, name in enumerate(name_list):
            if name.split(".")[-2] != "bn" and name.split(".")[-1] != "bias":
                raw_w = para_list[i]
                if raw_w.size(0) == 128 and raw_w.size(1) == 128:
                    patterns, pattern_match_num, pattern_coo_nnz, pattern_nnz, pattern_inner_nnz \
                                    = find_pattern_by_similarity(raw_w
                                        , cfg.find_pattern_num
                                        , cfg.find_pattern_shape
                                        , cfg.find_zero_threshold
                                        , cfg.find_score_threshold)

                    pattern_num_memory_dict, pattern_num_coo_nnz_dict \
                                    = pattern_curve_analyse(raw_w.shape
                                        , cfg.find_pattern_shape
                                        , patterns
                                        , pattern_match_num
                                        , pattern_coo_nnz
                                        , pattern_nnz
                                        , pattern_inner_nnz)

                    write_pattern_curve_analyse(os.path.join(cfg.work_root, args.save_pattern_count_excel)
                                        , cfg.exp_name + " " + args.find_pattern_shape + " " + args.find_pattern_para
                                        , patterns, pattern_match_num, pattern_coo_nnz, pattern_nnz
                                        , pattern_num_memory_dict, pattern_num_coo_nnz_dict)

                    # write_pattern_count(os.path.join(cfg.work_root, args.save_pattern_count_excel)
                    #                     , cfg.exp_name + " " + args.find_pattern_shape +" " + args.find_pattern_para
                    #                     , all_nnzs.values(), all_patterns.values())
                    # Only the first matching layer is analysed.
                    exit()



    # Configure the chosen pruning scheme. Parameter strings are
    # underscore-separated tuples, e.g. "num_h_w_nnz" for pattern pruning.
    if cfg.sparse_mode == 'sparse_pruning':
        cfg.sparsity = args.sparsity
        print(f'sparse_pruning {cfg.sparsity}')

    elif cfg.sparse_mode == 'pattern_pruning':
        print(args.pattern_para)
        pattern_num   = int(args.pattern_para.split('_')[0])
        pattern_shape = [int(args.pattern_para.split('_')[1]), int(args.pattern_para.split('_')[2])]
        pattern_nnz   = int(args.pattern_para.split('_')[3])
        print(f'pattern_pruning {pattern_num} [{pattern_shape[0]}, {pattern_shape[1]}] {pattern_nnz}')
        cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)

    elif cfg.sparse_mode == 'coo_pruning':
        cfg.coo_shape   = [int(args.coo_para.split('_')[0]), int(args.coo_para.split('_')[1])]
        cfg.coo_nnz   = int(args.coo_para.split('_')[2])
        # cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        print(f'coo_pruning [{cfg.coo_shape[0]}, {cfg.coo_shape[1]}] {cfg.coo_nnz}')

    elif cfg.sparse_mode == 'ptcoo_pruning':
        # NOTE(review): reads args.pattern_para here but args.ptcoo_para for
        # every other field — possibly a copy-paste bug; confirm intent.
        cfg.pattern_num   = int(args.pattern_para.split('_')[0])
        cfg.pattern_shape = [int(args.ptcoo_para.split('_')[1]), int(args.ptcoo_para.split('_')[2])]
        cfg.pt_nnz   = int(args.ptcoo_para.split('_')[3])
        cfg.coo_nnz   = int(args.ptcoo_para.split('_')[4])
        cfg.patterns = generate_pattern(cfg.pattern_num, cfg.pattern_shape, cfg.pt_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)
        print(f'ptcoo_pruning {cfg.pattern_num} [{cfg.pattern_shape[0]}, {cfg.pattern_shape[1]}] {cfg.pt_nnz} {cfg.coo_nnz}')


    # Utility mode: visualise the zero/non-zero mask of each prunable weight.
    if args.vis_mask == True:
        name_list = list()
        para_list = list()
        for name, para in model.named_parameters():
            name_list.append(name)
            para_list.append(para)

        for i, name in enumerate(name_list):
            if name.split(".")[-2] != "bn" and name.split(".")[-1] != "bias":
                raw_w = para_list[i]

                zero = torch.zeros_like(raw_w)
                one = torch.ones_like(raw_w)

                mask = torch.where(raw_w == 0, zero, one)
                vis.save_visualized_mask(mask, name)
        exit()

    # Utility mode: visualise the 8x8 patterns found in the model.
    if args.vis_pattern == True:
        pattern_count_dict = find_pattern_model(model, [8,8])
        patterns = list(pattern_count_dict.keys())
        counts = list(pattern_count_dict.values())
        print(len(patterns))
        print(counts)
        vis.save_visualized_pattern(patterns)
        exit()
    # build loss
    # blank=27 matches the 28-class output (27 = CTC blank index).
    loss_fn = nn.CTCLoss(blank=27)
    # loss_fn = nn.CTCLoss()

    # NOTE(review): despite the name, `scheduler` holds the Adam optimizer,
    # not an LR scheduler (see the commented-out MultiStepLR below).
    scheduler = optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)
    # scheduler = optim.lr_scheduler.MultiStepLR(train_step, milestones=[50, 150, 250], gamma=0.5)
    # Utility mode: evaluate accuracy, write it to a spreadsheet, and exit.
    if args.test_acc == True:
        f1, val_loss, tps, preds, poses = test_acc(val_loader, model, loss_fn)
        write_test_acc(os.path.join(cfg.work_root, args.test_acc_excel), 
                        cfg.exp_name, f1, val_loss, tps, preds, poses)
        exit()
    # train
    train(train_loader, scheduler, model, loss_fn, val_loader, writer)
    hs = encoder(xs)  # bs x ch_in * 3 x seqlen
    return hs.view(bs, -1, CH_INPUT, LEN_INPUT)\
             .permute(0, 2, 3, 1).contiguous()  # bs x ch_in x seqlen x nclass

    # sign_l = hs[:, :CH_INPUT, :]
    # mag_l = hs[:, CH_INPUT:, :]

    # return sign_l, mag_l

    # percent_resid = encoder(xs)
    # ys = F.softplus(xs + (xs * percent_resid))
    # ys.percent_resid = percent_resid
    # return ys


# Optimizer over the encoder parameters only (the classifier variant is
# commented out below).
params = encoder.parameters()
# params = list(clf.parameters()) + list(encoder.parameters())
opt = Adam(params, lr=1e-2)  #, weight_decay=None)

# Loss functions used by the training step defined further down.
mse = nn.MSELoss()
bcel = nn.BCEWithLogitsLoss()
cel = nn.CrossEntropyLoss()

# BCE loss of a maximally wrong prediction (0 vs target 1) — presumably used
# as an upper bound / normaliser for the sign loss; confirm against the
# training step. NOTE(review): Variable is legacy (pre-0.4) PyTorch.
MAX_SIGN_LOSS = nn.BCELoss()(Variable(torch.zeros(1)), Variable(torch.ones(1)))


def step(xs, ys, i_step):
    xs = Variable(xs)
    ys = Variable(ys).contiguous()
    opt.zero_grad()
Beispiel #12
0
def main():
    """Train a WaveNet vocoder on mu-law-encoded waveforms with aux features.

    Command-line driven: parses arguments, configures logging, fixes seeds,
    builds the model and optimizer, streams minibatches from a background
    generator thread, and periodically logs progress and writes checkpoints
    under --expdir. Requires at least one CUDA device.
    """
    parser = argparse.ArgumentParser()
    # path setting
    parser.add_argument("--waveforms",
                        required=True,
                        type=str,
                        help="directory or list of wav files")
    parser.add_argument("--feats",
                        required=True,
                        type=str,
                        help="directory or list of aux feat files")
    parser.add_argument("--stats",
                        required=True,
                        type=str,
                        help="hdf5 file including statistics")
    parser.add_argument("--expdir",
                        required=True,
                        type=str,
                        help="directory to save the model")
    parser.add_argument("--feature_type",
                        default="world",
                        choices=["world", "melspc"],
                        type=str,
                        help="feature type")
    # network structure setting
    parser.add_argument("--n_quantize",
                        default=256,
                        type=int,
                        help="number of quantization")
    parser.add_argument("--n_aux",
                        default=28,
                        type=int,
                        help="number of dimension of aux feats")
    parser.add_argument("--n_resch",
                        default=512,
                        type=int,
                        help="number of channels of residual output")
    parser.add_argument("--n_skipch",
                        default=256,
                        type=int,
                        help="number of channels of skip output")
    parser.add_argument("--dilation_depth",
                        default=10,
                        type=int,
                        help="depth of dilation")
    parser.add_argument("--dilation_repeat",
                        default=1,
                        type=int,
                        help="number of repeating of dilation")
    parser.add_argument("--kernel_size",
                        default=2,
                        type=int,
                        help="kernel size of dilated causal convolution")
    parser.add_argument("--upsampling_factor",
                        default=80,
                        type=int,
                        help="upsampling factor of aux features")
    parser.add_argument("--use_upsampling_layer",
                        default=True,
                        type=strtobool,
                        help="flag to use upsampling layer")
    parser.add_argument("--use_speaker_code",
                        default=False,
                        type=strtobool,
                        help="flag to use speaker code")
    # network training setting
    parser.add_argument("--lr", default=1e-4, type=float, help="learning rate")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="weight decay coefficient")
    parser.add_argument(
        "--batch_length",
        default=20000,
        type=int,
        help="batch length (if set 0, utterance batch will be used)")
    parser.add_argument(
        "--batch_size",
        default=1,
        type=int,
        help="batch size (if use utterance batch, batch_size will be 1.")
    parser.add_argument("--iters",
                        default=200000,
                        type=int,
                        help="number of iterations")
    # other setting
    parser.add_argument("--checkpoints",
                        default=10000,
                        type=int,
                        help="how frequent saving model")
    parser.add_argument("--intervals",
                        default=100,
                        type=int,
                        help="log interval")
    parser.add_argument("--seed", default=1, type=int, help="seed number")
    parser.add_argument("--resume",
                        default=None,
                        nargs="?",
                        type=str,
                        help="model path to restart training")
    parser.add_argument("--n_gpus", default=1, type=int, help="number of gpus")
    parser.add_argument("--verbose", default=1, type=int, help="log level")
    args = parser.parse_args()

    # set log level
    if args.verbose == 1:
        logging.basicConfig(
            level=logging.INFO,
            format=
            '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S')
    elif args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            format=
            '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S')
    else:
        logging.basicConfig(
            level=logging.WARNING,
            format=
            '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S')
        logging.warning("logging is disabled.")

    # show argmument
    for key, value in vars(args).items():
        logging.info("%s = %s" % (key, str(value)))

    # make experimental directory
    if not os.path.exists(args.expdir):
        os.makedirs(args.expdir)

    # fix seed
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # fix slow computation of dilated conv
    # https://github.com/pytorch/pytorch/issues/15054#issuecomment-450191923
    torch.backends.cudnn.benchmark = True

    # save args as conf
    torch.save(args, args.expdir + "/model.conf")

    # define network
    # upsampling_factor == 0 tells WaveNet to skip the upsampling layer
    if args.use_upsampling_layer:
        upsampling_factor = args.upsampling_factor
    else:
        upsampling_factor = 0
    model = WaveNet(n_quantize=args.n_quantize,
                    n_aux=args.n_aux,
                    n_resch=args.n_resch,
                    n_skipch=args.n_skipch,
                    dilation_depth=args.dilation_depth,
                    dilation_repeat=args.dilation_repeat,
                    kernel_size=args.kernel_size,
                    upsampling_factor=upsampling_factor)
    logging.info(model)
    model.apply(initialize)
    model.train()

    if args.n_gpus > 1:
        device_ids = range(args.n_gpus)
        model = torch.nn.DataParallel(model, device_ids)
        # re-expose receptive_field, which DataParallel hides behind .module
        model.receptive_field = model.module.receptive_field
        if args.n_gpus > args.batch_size:
            logging.warning("batch size is less than number of gpus.")

    # define optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss()

    # define transforms
    # scaler is populated from precomputed statistics rather than fitted here
    scaler = StandardScaler()
    scaler.mean_ = read_hdf5(args.stats, "/" + args.feature_type + "/mean")
    scaler.scale_ = read_hdf5(args.stats, "/" + args.feature_type + "/scale")
    wav_transform = transforms.Compose(
        [lambda x: encode_mu_law(x, args.n_quantize)])
    feat_transform = transforms.Compose([lambda x: scaler.transform(x)])

    # define generator
    if os.path.isdir(args.waveforms):
        filenames = sorted(
            find_files(args.waveforms, "*.wav", use_dir_name=False))
        wav_list = [args.waveforms + "/" + filename for filename in filenames]
        feat_list = [
            args.feats + "/" + filename.replace(".wav", ".h5")
            for filename in filenames
        ]
    elif os.path.isfile(args.waveforms):
        wav_list = read_txt(args.waveforms)
        feat_list = read_txt(args.feats)
    else:
        logging.error("--waveforms should be directory or list.")
        sys.exit(1)
    assert len(wav_list) == len(feat_list)
    logging.info("number of training data = %d." % len(wav_list))
    generator = train_generator(wav_list,
                                feat_list,
                                receptive_field=model.receptive_field,
                                batch_length=args.batch_length,
                                batch_size=args.batch_size,
                                feature_type=args.feature_type,
                                wav_transform=wav_transform,
                                feat_transform=feat_transform,
                                shuffle=True,
                                upsampling_factor=args.upsampling_factor,
                                use_upsampling_layer=args.use_upsampling_layer,
                                use_speaker_code=args.use_speaker_code)

    # charge minibatch in queue
    while not generator.queue.full():
        time.sleep(0.1)

    # resume model and optimizer
    if args.resume is not None and len(args.resume) != 0:
        checkpoint = torch.load(args.resume,
                                map_location=lambda storage, loc: storage)
        iterations = checkpoint["iterations"]
        if args.n_gpus > 1:
            model.module.load_state_dict(checkpoint["model"])
        else:
            model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        logging.info("restored from %d-iter checkpoint." % iterations)
    else:
        iterations = 0

    # check gpu and then send to gpu
    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()
        # optimizer state restored from a CPU checkpoint must follow the model
        for state in optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.cuda()
    else:
        logging.error("gpu is not available. please check the setting.")
        sys.exit(1)

    # train
    loss = 0
    total = 0
    for i in six.moves.range(iterations, args.iters):
        start = time.time()
        (batch_x, batch_h), batch_t = generator.next()
        batch_output = model(batch_x, batch_h)
        # skip the first receptive_field samples, which lack full context
        batch_loss = criterion(
            batch_output[:, model.receptive_field:].contiguous().view(
                -1, args.n_quantize),
            batch_t[:, model.receptive_field:].contiguous().view(-1))
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        loss += batch_loss.item()
        total += time.time() - start
        logging.debug("batch loss = %.3f (%.3f sec / batch)" %
                      (batch_loss.item(), time.time() - start))

        # report progress
        if (i + 1) % args.intervals == 0:
            logging.info(
                "(iter:%d) average loss = %.6f (%.3f sec / batch)" %
                (i + 1, loss / args.intervals, total / args.intervals))
            logging.info(
                "estimated required time = "
                "{0.days:02}:{0.hours:02}:{0.minutes:02}:{0.seconds:02}".
                format(
                    relativedelta(seconds=int((args.iters - (i + 1)) *
                                              (total / args.intervals)))))
            loss = 0
            total = 0

        # save intermidiate model
        if (i + 1) % args.checkpoints == 0:
            if args.n_gpus > 1:
                save_checkpoint(args.expdir, model.module, optimizer, i + 1)
            else:
                save_checkpoint(args.expdir, model, optimizer, i + 1)

    # save final model
    if args.n_gpus > 1:
        torch.save({"model": model.module.state_dict()},
                   args.expdir + "/checkpoint-final.pkl")
    else:
        torch.save({"model": model.state_dict()},
                   args.expdir + "/checkpoint-final.pkl")
    logging.info("final checkpoint created.")
# Example #13
    val_loader = DataLoader(val_set, batch_size=BATCH_SIZE,
                            shuffle=True,
                            num_workers=1, pin_memory=True)

    ### Featurizer
    featurizer = MelSpectrogram(MelSpectrogramConfig(), device).to(device)
    ### Model
    model = WaveNet(hidden_ch=120, skip_ch=240, num_layers=30, mu=256)
    model = model.to(device)
    # wandb
    wandb.init(project='wavenet-pytorch')
    wandb.watch(model)
    print('num of model parameters', count_parameters(model))
    ### Optimizer
    opt = torch.optim.Adam(model.parameters(), lr=3e-4)
    ### Encoder and decoder for mu-law
    mu_law_encoder = torchaudio.transforms.MuLawEncoding(quantization_channels=256).to(device)
    mu_law_decoder = torchaudio.transforms.MuLawDecoding(quantization_channels=256).to(device)

    ### Train loop
    for i in tqdm(range(NUM_EPOCHS)):
        for el in train_loader:
            wav = el['audio'].to(device)
            melspec = featurizer(wav)
            wav = mu_law_encoder(wav).unsqueeze(1).type(torch.float)

            opt.zero_grad()

            new_wav = model(melspec, wav[:, :, :-1])
# Example #14
def main():
    """Entry point for the pruning-aware WaveNet CTC trainer.

    Parses CLI args into the global ``cfg``, builds the VCTK data loaders and
    the WaveNet model, configures the selected sparsity/pruning mode,
    optionally dumps mask/pattern visualisations (and exits), then trains.
    """
    args = parse_args()
    cfg.resume      = args.resume
    cfg.exp_name    = args.exp
    cfg.workdir     = '/zhzhao/code/wavenet_torch/torch_lyuan/exp_result/' + args.exp + '/debug'
    cfg.sparse_mode = args.sparse_mode
    cfg.batch_size  = args.batch_size
    cfg.lr          = args.lr
    cfg.load_from   = args.load_from

    print('initial training...')
    print(f'work_dir:{cfg.workdir}, \n\
            pretrained: {cfg.load_from},  \n\
            batch_size: {cfg.batch_size}, \n\
            lr        : {cfg.lr},         \n\
            epochs    : {cfg.epochs},     \n\
            sparse    : {cfg.sparse_mode}')
    writer = SummaryWriter(log_dir=cfg.workdir+'/runs')

    # build train data
    vctk_train = VCTK(cfg, 'train')
    train_loader = DataLoader(vctk_train,batch_size=cfg.batch_size, num_workers=8, shuffle=True, pin_memory=True)

    vctk_val = VCTK(cfg, 'val')
    val_loader = DataLoader(vctk_val, batch_size=cfg.batch_size, num_workers=8, shuffle=False, pin_memory=True)

    # build model
    model = WaveNet(num_classes=28, channels_in=20, dilations=[1,2,4,8,16])
    model = nn.DataParallel(model)
    model.cuda()

    # makedirs (not mkdir) so missing parent directories are created too, and
    # exist_ok avoids the check-then-create race of the previous version.
    weights_dir = os.path.join(cfg.workdir, 'weights')
    os.makedirs(weights_dir, exist_ok=True)
    os.makedirs(cfg.vis_dir, exist_ok=True)
    cfg.vis_dir = os.path.join(cfg.vis_dir, cfg.exp_name)
    os.makedirs(cfg.vis_dir, exist_ok=True)
    model.train()

    if cfg.resume and os.path.exists(cfg.workdir + '/weights/best.pth'):
        model.load_state_dict(torch.load(cfg.workdir + '/weights/best.pth'))
        print("loading", cfg.workdir + '/weights/best.pth')

    if os.path.exists(cfg.load_from):
        model.load_state_dict(torch.load(cfg.load_from))
        print("loading", cfg.load_from)

    # configure the selected pruning mode; the *_para strings are
    # underscore-separated parameter packs
    if cfg.sparse_mode == 'sparse_pruning':
        cfg.sparsity = args.sparsity
        print(f'sparse_pruning {cfg.sparsity}')
    elif cfg.sparse_mode == 'pattern_pruning':
        print(args.pattern_para)
        para = args.pattern_para.split('_')  # num_shapeH_shapeW_nnz
        pattern_num   = int(para[0])
        pattern_shape = [int(para[1]), int(para[2])]
        pattern_nnz   = int(para[3])
        print(f'pattern_pruning {pattern_num} [{pattern_shape[0]}, {pattern_shape[1]}] {pattern_nnz}')
        cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)
    elif cfg.sparse_mode == 'coo_pruning':
        para = args.coo_para.split('_')      # shapeH_shapeW_nnz
        cfg.coo_shape = [int(para[0]), int(para[1])]
        cfg.coo_nnz   = int(para[2])
        print(f'coo_pruning [{cfg.coo_shape[0]}, {cfg.coo_shape[1]}] {cfg.coo_nnz}')
    elif cfg.sparse_mode == 'ptcoo_pruning':
        # BUGFIX: pattern_num was previously parsed from args.pattern_para
        # while every other ptcoo field came from args.ptcoo_para; read the
        # whole pack from ptcoo_para.
        para = args.ptcoo_para.split('_')    # num_shapeH_shapeW_ptnnz_coonnz
        cfg.pattern_num   = int(para[0])
        cfg.pattern_shape = [int(para[1]), int(para[2])]
        cfg.pt_nnz  = int(para[3])
        cfg.coo_nnz = int(para[4])
        cfg.patterns = generate_pattern(cfg.pattern_num, cfg.pattern_shape, cfg.pt_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)
        print(f'ptcoo_pruning {cfg.pattern_num} [{cfg.pattern_shape[0]}, {cfg.pattern_shape[1]}] {cfg.pt_nnz} {cfg.coo_nnz}')

    if args.vis_mask:
        # Dump a binary (zero / non-zero) mask image for every prunable
        # weight tensor, skipping batch-norm parameters and biases.
        for name, para in model.named_parameters():
            if name.split(".")[-2] != "bn" and name.split(".")[-1] != "bias":
                raw_w = para
                zero = torch.zeros_like(raw_w)
                one = torch.ones_like(raw_w)
                mask = torch.where(raw_w == 0, zero, one)
                vis.save_visualized_mask(mask, name)
        exit()

    if args.vis_pattern:
        # Visualise the distinct 16x16 weight patterns found in the model.
        pattern_count_dict = find_pattern_model(model, [16,16])
        patterns = list(pattern_count_dict.keys())
        vis.save_visualized_pattern(patterns)
        exit()
    # build loss
    loss_fn = nn.CTCLoss(blank=0, reduction='none')

    # NOTE: despite the name, this is the optimizer (Adam), not an LR
    # scheduler; `train` expects it in this position.
    scheduler = optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)
    # scheduler = optim.lr_scheduler.MultiStepLR(train_step, milestones=[50, 150, 250], gamma=0.5)

    # train
    train(train_loader, scheduler, model, loss_fn, val_loader, writer)
def main():
    """Train a WaveNet vocoder (older variant without feature_type/multi-GPU).

    Parses arguments, sets up file-based logging under --expdir, fixes seeds,
    builds the model and optimizer, streams minibatches from a background
    generator, and writes periodic plus final checkpoints. Requires CUDA.
    """
    parser = argparse.ArgumentParser()
    # path setting
    parser.add_argument("--waveforms",
                        required=True,
                        type=str,
                        help="directory or list of wav files")
    parser.add_argument("--feats",
                        required=True,
                        type=str,
                        help="directory or list of aux feat files")
    parser.add_argument("--stats",
                        required=True,
                        type=str,
                        help="hdf5 file including statistics")
    parser.add_argument("--expdir",
                        required=True,
                        type=str,
                        help="directory to save the model")
    # network structure setting
    parser.add_argument("--n_quantize",
                        default=256,
                        type=int,
                        help="number of quantization")
    parser.add_argument("--n_aux",
                        default=28,
                        type=int,
                        help="number of dimension of aux feats")
    parser.add_argument("--n_resch",
                        default=512,
                        type=int,
                        help="number of channels of residual output")
    parser.add_argument("--n_skipch",
                        default=256,
                        type=int,
                        help="number of channels of skip output")
    parser.add_argument("--dilation_depth",
                        default=10,
                        type=int,
                        help="depth of dilation")
    parser.add_argument("--dilation_repeat",
                        default=1,
                        type=int,
                        help="number of repeating of dilation")
    parser.add_argument("--kernel_size",
                        default=2,
                        type=int,
                        help="kernel size of dilated causal convolution")
    parser.add_argument("--upsampling_factor",
                        default=0,
                        type=int,
                        help="upsampling factor of aux features"
                        "(if set 0, do not apply)")
    parser.add_argument("--use_speaker_code",
                        default=False,
                        type=strtobool,
                        help="flag to use speaker code")
    # network training setting
    parser.add_argument("--lr", default=1e-4, type=float, help="learning rate")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="weight decay coefficient")
    parser.add_argument(
        "--batch_size",
        default=20000,
        type=int,
        help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--iters",
                        default=200000,
                        type=int,
                        help="number of iterations")
    # other setting
    parser.add_argument("--checkpoints",
                        default=10000,
                        type=int,
                        help="how frequent saving model")
    parser.add_argument("--intervals",
                        default=100,
                        type=int,
                        help="log interval")
    parser.add_argument("--seed", default=1, type=int, help="seed number")
    parser.add_argument("--resume",
                        default=None,
                        type=str,
                        help="model path to restart training")
    parser.add_argument("--verbose", default=1, type=int, help="log level")
    args = parser.parse_args()

    # make experimental directory
    if not os.path.exists(args.expdir):
        os.makedirs(args.expdir)

    # set log level
    if args.verbose == 1:
        logging.basicConfig(
            level=logging.INFO,
            format=
            '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S',
            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    elif args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            format=
            '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S',
            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    else:
        logging.basicConfig(
            level=logging.WARN,
            format=
            '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S',
            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
        # logging.warn is deprecated; warning() is the supported spelling
        logging.warning("logging is disabled.")

    # fix seed
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # save args as conf
    torch.save(args, args.expdir + "/model.conf")

    # # define network
    model = WaveNet(n_quantize=args.n_quantize,
                    n_aux=args.n_aux,
                    n_resch=args.n_resch,
                    n_skipch=args.n_skipch,
                    dilation_depth=args.dilation_depth,
                    dilation_repeat=args.dilation_repeat,
                    kernel_size=args.kernel_size,
                    upsampling_factor=args.upsampling_factor)
    logging.info(model)
    model.apply(initialize)
    model.train()

    # define loss and optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss()

    # define transforms
    # scaler is populated from precomputed statistics rather than fitted here
    scaler = StandardScaler()
    scaler.mean_ = read_hdf5(args.stats, "/mean")
    scaler.scale_ = read_hdf5(args.stats, "/scale")
    wav_transform = transforms.Compose(
        [lambda x: encode_mu_law(x, args.n_quantize)])
    feat_transform = transforms.Compose([lambda x: scaler.transform(x)])

    # define generator
    if os.path.isdir(args.waveforms):
        filenames = sorted(
            find_files(args.waveforms, "*.wav", use_dir_name=False))
        wav_list = [args.waveforms + "/" + filename for filename in filenames]
        feat_list = [
            args.feats + "/" + filename.replace(".wav", ".h5")
            for filename in filenames
        ]
    elif os.path.isfile(args.waveforms):
        wav_list = read_txt(args.waveforms)
        feat_list = read_txt(args.feats)
    else:
        logging.error("--waveforms should be directory or list.")
        sys.exit(1)
    assert len(wav_list) == len(feat_list)
    logging.info("number of training data = %d." % len(wav_list))
    generator = train_generator(wav_list,
                                feat_list,
                                receptive_field=model.receptive_field,
                                batch_size=args.batch_size,
                                wav_transform=wav_transform,
                                feat_transform=feat_transform,
                                shuffle=True,
                                upsampling_factor=args.upsampling_factor,
                                use_speaker_code=args.use_speaker_code)
    # wait until the background generator has charged its queue
    while not generator.queue.full():
        time.sleep(0.1)

    # resume
    if args.resume is not None:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        iterations = checkpoint["iterations"]
        logging.info("restored from %d-iter checkpoint." % iterations)
    else:
        iterations = 0

    # send to gpu
    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()
    else:
        logging.error("gpu is not available. please check the setting.")
        sys.exit(1)

    # train
    loss = 0
    total = 0
    for i in six.moves.range(iterations, args.iters):
        start = time.time()
        (batch_x, batch_h), batch_t = generator.next()
        batch_output = model(batch_x, batch_h)[0]
        # skip the first receptive_field samples, which lack full context
        batch_loss = criterion(batch_output[model.receptive_field:],
                               batch_t[model.receptive_field:])
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        # .item() replaces the pre-0.4 `.data[0]`, which raises IndexError on
        # 0-dim tensors in modern PyTorch (matches the other trainer here).
        loss += batch_loss.item()
        total += time.time() - start
        logging.debug("batch loss = %.3f (%.3f sec / batch)" %
                      (batch_loss.item(), time.time() - start))

        # report progress
        if (i + 1) % args.intervals == 0:
            logging.info(
                "(iter:%d) average loss = %.6f (%.3f sec / batch)" %
                (i + 1, loss / args.intervals, total / args.intervals))
            logging.info(
                "estimated required time = "
                "{0.days:02}:{0.hours:02}:{0.minutes:02}:{0.seconds:02}".
                format(
                    relativedelta(seconds=int((args.iters - (i + 1)) *
                                              (total / args.intervals)))))
            loss = 0
            total = 0

        # save intermediate model
        if (i + 1) % args.checkpoints == 0:
            save_checkpoint(args.expdir, model, optimizer, i + 1)

    # save final model
    model.cpu()
    torch.save({"model": model.state_dict()},
               args.expdir + "/checkpoint-final.pkl")
    logging.info("final checkpoint created.")
# Example #16
import torch
from matplotlib import pyplot as plt
from torch import nn, optim
from sine import sine_generator
from wavenet import WaveNet

# Generator of mu-law-quantized sine-wave snippets used as training data.
g = sine_generator(seq_size=2200, mu=64)
net = WaveNet(n_out=64,
              n_residue=24,
              n_skip=128,
              dilation_depth=10,
              n_layers=2)
optimizer = optim.Adam(net.parameters(), lr=0.01)
batch_size = 64

loss_save = []
max_epoch = 2000
for epoch in range(max_epoch):
    optimizer.zero_grad()
    loss = 0
    # iterate over the data: accumulate the loss over batch_size sequences,
    # then take a single optimizer step on the averaged loss
    for _ in range(batch_size):
        batch = next(g)
        x = batch[:-1]  # input is everything except the final sample
        logits = net(x)
        sz = logits.size(0)
        # the network output is shorter than the input (causal convolutions);
        # align the targets to its last `sz` samples
        loss = loss + nn.functional.cross_entropy(logits, batch[-sz:])
    loss = loss / batch_size
    loss.backward()
    optimizer.step()
    # .item() replaces the pre-0.4 `loss.data[0]`, which raises IndexError on
    # 0-dim tensors in modern PyTorch.
    loss_save.append(loss.item())
# Example #17
class Trainer:
    """Fine-tunes a WaveNet decoder against a frozen, pretrained encoder.

    The encoder is loaded from ``args.checkpoint`` and only evaluated; the
    decoder is trained to reconstruct the input conditioned on the encoder's
    latent. Checkpoints and ``args.pth`` are written under
    ``checkpoints/<expName>``.
    """
    def __init__(self, args):
        """Build datasets, models, optimizer and LR schedule from `args`."""
        self.args = args
        self.args.n_datasets = len(self.args.data)
        self.expPath = Path('checkpoints') / args.expName

        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        self.logger = create_output_dir(args, self.expPath)
        self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]

        # per-dataset reconstruction-loss meters for training ...
        self.losses_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.loss_total = LossMeter('total')

        # ... and their counterparts for evaluation
        self.evals_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.eval_total = LossMeter('eval total')

        self.encoder = Encoder(args)
        self.decoder = WaveNet(args)

        assert args.checkpoint, 'you MUST pass a checkpoint for the encoder'

        if args.continue_training:
            # resume epoch counter from the args.pth saved next to the checkpoint
            checkpoint_args_path = os.path.dirname(
                args.checkpoint) + '/args.pth'
            checkpoint_args = torch.load(checkpoint_args_path)

            self.start_epoch = checkpoint_args[-1] + 1
        else:
            self.start_epoch = 0

        # encoder weights always come from the checkpoint; decoder weights
        # only when resuming
        states = torch.load(args.checkpoint)
        self.encoder.load_state_dict(states['encoder_state'])
        if args.continue_training:
            self.decoder.load_state_dict(states['decoder_state'])
        self.logger.info('Loaded checkpoint parameters')

        self.encoder = torch.nn.DataParallel(self.encoder).cuda()
        self.decoder = torch.nn.DataParallel(self.decoder).cuda()

        # only the decoder is optimized; the encoder stays fixed (see
        # train_batch, which detaches its output)
        self.model_optimizer = optim.Adam(self.decoder.parameters(),
                                          lr=args.lr)

        if args.continue_training:
            self.model_optimizer.load_state_dict(
                states['model_optimizer_state'])

        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
            self.model_optimizer, args.lr_decay)
        self.lr_manager.last_epoch = self.start_epoch
        self.lr_manager.step()

    def eval_batch(self, x, x_aug, dset_num):
        """Evaluate one batch; records per-dataset and total eval losses.

        Returns the mean reconstruction loss as a float. `x_aug` is unused
        here (evaluation encodes the clean signal).
        """
        x, x_aug = x.float(), x_aug.float()

        z = self.encoder(x)
        y = self.decoder(x, z)

        recon_loss = cross_entropy_loss(y, x)
        self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        total_loss = recon_loss.mean().data.item()
        self.eval_total.add(total_loss)

        return total_loss

    def train_batch(self, x, x_aug, dset_num):
        """Run one optimization step on the decoder; returns the loss value.

        Encodes the augmented signal, detaches the latent so no gradient
        reaches the encoder, then trains the decoder to reconstruct `x`.
        """
        x, x_aug = x.float(), x_aug.float()

        # optimize G - reconstructs well
        z = self.encoder(x_aug)
        z = z.detach()  # stop gradients
        y = self.decoder(x, z)

        recon_loss = cross_entropy_loss(y, x)
        self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        loss = recon_loss.mean()
        self.model_optimizer.zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.decoder.parameters(), self.args.grad_clip)
        self.model_optimizer.step()

        self.loss_total.add(loss.data.item())

        return loss.data.item()

    def train_epoch(self, epoch):
        """Train for `args.epoch_len` batches, cycling over the datasets."""
        for meter in self.losses_recon:
            meter.reset()
        self.loss_total.reset()

        # encoder stays in eval mode: it is frozen w.r.t. this optimizer
        self.encoder.eval()
        self.decoder.train()

        n_batches = self.args.epoch_len

        with tqdm(total=n_batches,
                  desc='Train epoch %d' % epoch) as train_enum:
            for batch_num in range(n_batches):
                # --short truncates the epoch for quick smoke tests
                if self.args.short and batch_num == 3:
                    break

                dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].train_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.train_batch(x, x_aug, dset_num)

                train_enum.set_description(
                    f'Train (loss: {batch_loss:.2f}) epoch {epoch}')
                train_enum.update()

    def evaluate_epoch(self, epoch):
        """Evaluate on ~1/10 of an epoch's worth of batches, without grads."""
        for meter in self.evals_recon:
            meter.reset()
        self.eval_total.reset()

        self.encoder.eval()
        self.decoder.eval()

        n_batches = int(np.ceil(self.args.epoch_len / 10))

        with tqdm(total=n_batches) as valid_enum, \
                torch.no_grad():
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 10:
                    break

                dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].valid_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.eval_batch(x, x_aug, dset_num)

                valid_enum.set_description(
                    f'Test (loss: {batch_loss:.2f}) epoch {epoch}')
                valid_enum.update()

    @staticmethod
    def format_losses(meters):
        """Return the meters' epoch summaries as a comma-separated string."""
        losses = [meter.summarize_epoch() for meter in meters]
        return ', '.join('{:.4f}'.format(x) for x in losses)

    def train_losses(self):
        """Formatted per-dataset training reconstruction losses."""
        meters = [*self.losses_recon]
        return self.format_losses(meters)

    def eval_losses(self):
        """Formatted per-dataset evaluation reconstruction losses."""
        meters = [*self.evals_recon]
        return self.format_losses(meters)

    def train(self):
        """Main loop: train/eval each epoch, checkpoint, track best model."""
        best_eval = float('inf')

        # Begin!
        for epoch in range(self.start_epoch,
                           self.start_epoch + self.args.epochs):
            self.logger.info(
                f'Starting epoch, Rank {self.args.rank}, Dataset: {self.args.data[self.args.rank]}'
            )
            self.train_epoch(epoch)
            self.evaluate_epoch(epoch)

            self.logger.info(
                f'Epoch %s Rank {self.args.rank} - Train loss: (%s), Test loss (%s)',
                epoch, self.train_losses(), self.eval_losses())
            self.lr_manager.step()
            val_loss = self.eval_total.summarize_epoch()

            # keep the best model by validation loss, plus a rolling "last"
            if val_loss < best_eval:
                self.save_model(f'bestmodel_{self.args.rank}.pth')
                best_eval = val_loss

            if not self.args.per_epoch:
                self.save_model(f'lastmodel_{self.args.rank}.pth')
            else:
                self.save_model(f'lastmodel_{epoch}_rank_{self.args.rank}.pth')

            # persist args + epoch so --continue_training can resume
            torch.save([self.args, epoch], '%s/args.pth' % self.expPath)

            self.logger.debug('Ended epoch')

    def save_model(self, filename):
        """Save decoder/optimizer state (plus the original encoder state).

        The encoder state is re-read from the original checkpoint each time,
        since it is never modified during training.
        """
        save_path = self.expPath / filename

        states = torch.load(self.args.checkpoint)

        torch.save(
            {
                'encoder_state': states['encoder_state'],
                'decoder_state': self.decoder.module.state_dict(),
                'model_optimizer_state': self.model_optimizer.state_dict(),
                'dataset': self.args.rank,
            }, save_path)

        self.logger.debug(f'Saved model to {save_path}')
# Example #18
class Finetuner:
    """Fine-tunes a pretrained encoder/decoder pair on the datasets in args.data.

    The encoder is fully frozen; only the last ``args.decoder_update``
    layers of the WaveNet decoder remain trainable.
    """

    def __init__(self, args):
        self.args = args
        self.args.n_datasets = len(args.data)
        self.modelPath = Path('checkpoints') / args.expName

        self.logger = create_output_dir(args, self.modelPath)
        self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]

        # Per-dataset reconstruction meters plus running totals,
        # kept separately for the train and eval splits.
        self.losses_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.loss_total = LossMeter('total')

        self.evals_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.eval_total = LossMeter('eval total')

        self.start_epoch = 0

        #torch.manual_seed(args.seed)
        #torch.cuda.manual_seed(args.seed)

        # Pick the pretrained checkpoint whose id matches one of args.decoder.
        # Raises IndexError if no checkpoint matches (same as before).
        candidates = args.checkpoint.parent.glob(args.checkpoint.name + '_*.pth')
        checkpoint = [c for c in candidates
                      if extract_id(c) in args.decoder][0]

        model_args = torch.load(args.checkpoint.parent / 'args.pth')[0]

        # Fix: the original constructed Encoder/WaveNet twice each and read
        # the checkpoint file from disk twice; build once, load once.
        states = torch.load(checkpoint)

        self.encoder = Encoder(model_args)
        self.encoder.load_state_dict(states['encoder_state'])

        # Freeze the encoder entirely.
        for param in self.encoder.parameters():
            param.requires_grad = False

        self.decoder = WaveNet(model_args)
        self.decoder.load_state_dict(states['decoder_state'])

        # Freeze all decoder layers except the last args.decoder_update ones.
        for param in self.decoder.layers[:-args.decoder_update].parameters():
            param.requires_grad = False

        self.encoder = torch.nn.DataParallel(self.encoder).cuda()
        self.decoder = torch.nn.DataParallel(self.decoder).cuda()
        self.model_optimizer = optim.Adam(chain(self.encoder.parameters(),
                                                self.decoder.parameters()),
                                          lr=args.lr)

        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
            self.model_optimizer, args.lr_decay)
        # NOTE(review): stepping immediately after creation looks like the
        # pre-PyTorch-1.1 scheduler idiom — confirm the first epoch is meant
        # to run at an already-decayed learning rate.
        self.lr_manager.step()

    def train_batch(self, x, x_aug, dset_num):
        """One optimization step: encode the augmented input, decode, backprop.

        Returns the scalar batch loss and records it in the meters for
        dataset ``dset_num``.
        """
        x, x_aug = x.float(), x_aug.float()
        latent = self.encoder(x_aug)
        prediction = self.decoder(x, latent)

        recon_loss = cross_entropy_loss(prediction, x)
        self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        total = recon_loss.mean()
        self.model_optimizer.zero_grad()
        total.backward()
        self.model_optimizer.step()

        batch_loss = total.data.item()
        self.loss_total.add(batch_loss)
        return batch_loss

    def train_epoch(self, epoch):
        """Run one training epoch of ``args.epoch_len`` batches.

        Distributed runs pin each worker to the dataset matching its rank;
        otherwise batches round-robin across datasets.
        """
        for meter in self.losses_recon:
            meter.reset()
        self.loss_total.reset()

        self.decoder.train()

        total_batches = self.args.epoch_len

        with tqdm(total=total_batches,
                  desc='Train epoch %d' % epoch) as progress:
            for batch_num in range(total_batches):
                # Smoke-test mode: stop after a handful of batches.
                if self.args.short and batch_num == 3:
                    break

                if self.args.distributed:
                    assert self.args.rank < self.args.n_datasets, "No. of workers must be equal to #dataset"
                    dset_num = self.args.rank
                else:
                    dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].train_iter)
                batch_loss = self.train_batch(wrap(x), wrap(x_aug), dset_num)

                progress.set_description(
                    f'Train (loss: {batch_loss:.2f}) epoch {epoch}')
                progress.update()

    def eval_batch(self, x, x_aug, dset_num):
        """Score one validation batch; no gradient step is taken."""
        x, x_aug = x.float(), x_aug.float()
        # NOTE(review): the clean signal x (not x_aug) is fed to the encoder
        # here, unlike train_batch — confirm this asymmetry is intentional.
        latent = self.encoder(x)
        prediction = self.decoder(x, latent)

        recon_loss = cross_entropy_loss(prediction, x)
        self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        batch_loss = recon_loss.mean().data.item()
        self.eval_total.add(batch_loss)

        return batch_loss

    def evaluate_epoch(self, epoch):
        """Evaluate on roughly a tenth of epoch_len batches, gradients disabled."""
        for meter in self.evals_recon:
            meter.reset()
        self.eval_total.reset()

        self.encoder.eval()
        self.decoder.eval()

        total_batches = int(np.ceil(self.args.epoch_len / 10))

        with tqdm(total=total_batches) as progress, torch.no_grad():
            for batch_num in range(total_batches):
                # Smoke-test mode: stop early.
                if self.args.short and batch_num == 10:
                    break

                if self.args.distributed:
                    assert self.args.rank < self.args.n_datasets, "No. of workers must be equal to #dataset"
                    dset_num = self.args.rank
                else:
                    dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].valid_iter)
                batch_loss = self.eval_batch(wrap(x), wrap(x_aug), dset_num)

                progress.set_description(
                    f'Test (loss: {batch_loss:.2f}) epoch {epoch}')
                progress.update()

    @staticmethod
    def format_losses(meters):
        losses = [meter.summarize_epoch() for meter in meters]
        return ', '.join('{:.4f}'.format(x) for x in losses)

    def train_losses(self):
        meters = [*self.losses_recon]
        return self.format_losses(meters)

    def eval_losses(self):
        meters = [*self.evals_recon]
        return self.format_losses(meters)

    def finetune(self):
        """Full fine-tuning loop: train, evaluate, and checkpoint each epoch."""
        best_eval = float('inf')

        last_epoch = self.start_epoch + self.args.epochs
        for epoch in range(self.start_epoch, last_epoch):
            self.logger.info(
                f'Starting epoch, Rank {self.args.rank}, Dataset: {self.args.data[self.args.rank]}'
            )
            self.train_epoch(epoch)
            self.evaluate_epoch(epoch)

            # The literal %s placeholders are filled lazily by the logger;
            # only the rank is baked in via the f-string.
            self.logger.info(
                f'Epoch %s Rank {self.args.rank} - Train loss: (%s), Test loss (%s)',
                epoch, self.train_losses(), self.eval_losses())
            self.lr_manager.step()

            val_loss = self.eval_total.summarize_epoch()
            if val_loss < best_eval:
                best_eval = val_loss
                self.save_model(f'bestmodel_{self.args.rank}.pth')

            if self.args.per_epoch:
                self.save_model(f'lastmodel_{epoch}_rank_{self.args.rank}.pth')
            else:
                self.save_model(f'lastmodel_{self.args.rank}.pth')

            # Only the master process persists the run arguments.
            if self.args.is_master:
                torch.save([self.args, epoch], '%s/args.pth' % self.modelPath)
            self.logger.debug('Ended epoch')

    def save_model(self, filename):
        """Persist encoder/decoder weights, optimizer state, and dataset rank."""
        destination = self.modelPath / filename

        # .module unwraps the DataParallel wrapper so the checkpoint can be
        # loaded without DataParallel later.
        payload = {
            'encoder_state': self.encoder.module.state_dict(),
            'decoder_state': self.decoder.module.state_dict(),
            'model_optimizer_state': self.model_optimizer.state_dict(),
            'dataset': self.args.rank,
        }
        torch.save(payload, destination)

        self.logger.debug(f'Saved model to {destination}')