Example #1
    def gpu_decode(feat_list, gpu):
        # set default gpu and do not track gradient
        torch.cuda.set_device(gpu)
        torch.set_grad_enabled(False)

        # define model and load parameters
        if config.use_upsampling_layer:
            upsampling_factor = config.upsampling_factor
        else:
            upsampling_factor = 0
        model = WaveNet(n_quantize=config.n_quantize,
                        n_aux=config.n_aux,
                        n_resch=config.n_resch,
                        n_skipch=config.n_skipch,
                        dilation_depth=config.dilation_depth,
                        dilation_repeat=config.dilation_repeat,
                        kernel_size=config.kernel_size,
                        upsampling_factor=upsampling_factor)
        model.load_state_dict(
            torch.load(args.checkpoint,
                       map_location=lambda storage, loc: storage)["model"])
        model.eval()
        model.cuda()

        # define generator
        generator = decode_generator(
            feat_list,
            batch_size=args.batch_size,
            feature_type=config.feature_type,
            wav_transform=wav_transform,
            feat_transform=feat_transform,
            upsampling_factor=config.upsampling_factor,
            use_upsampling_layer=config.use_upsampling_layer,
            use_speaker_code=config.use_speaker_code)

        # decode
        if args.batch_size > 1:
            for feat_ids, (batch_x, batch_h, n_samples_list) in generator:
                logging.info("decoding start")
                samples_list = model.batch_fast_generate(
                    batch_x, batch_h, n_samples_list, args.intervals)
                for feat_id, samples in zip(feat_ids, samples_list):
                    wav = decode_mu_law(samples, config.n_quantize)
                    sf.write(args.outdir + "/" + feat_id + ".wav", wav,
                             args.fs, "PCM_16")
                    logging.info("wrote %s.wav in %s." %
                                 (feat_id, args.outdir))
        else:
            for feat_id, (x, h, n_samples) in generator:
                logging.info("decoding %s (length = %d)" %
                             (feat_id, n_samples))
                samples = model.fast_generate(x, h, n_samples, args.intervals)
                wav = decode_mu_law(samples, config.n_quantize)
                sf.write(args.outdir + "/" + feat_id + ".wav", wav, args.fs,
                         "PCM_16")
                logging.info("wrote %s.wav in %s." % (feat_id, args.outdir))
Example #2
def main():
    args = parse_args()
    cfg.resume = args.resume
    cfg.exp_name = args.exp
    cfg.work_root = '/zhzhao/code/wavenet_torch/torch_lyuan/exp_result/'
    cfg.workdir = cfg.work_root + args.exp + '/debug'
    cfg.sparse_mode = args.sparse_mode
    cfg.batch_size = args.batch_size
    cfg.lr = args.lr
    cfg.load_from = args.load_from
    cfg.save_excel = args.save_excel

    weights_dir = os.path.join(cfg.workdir, 'weights')
    check_and_mkdir(weights_dir)

    print('initial training...')
    print(f'work_dir:{cfg.workdir}, \n\
            pretrained: {cfg.load_from},  \n\
            batch_size: {cfg.batch_size}, \n\
            lr        : {cfg.lr},         \n\
            epochs    : {cfg.epochs},     \n\
            sparse    : {cfg.sparse_mode}')
    writer = SummaryWriter(log_dir=cfg.workdir + '/runs')

    # build train data
    vctk_train = VCTK(cfg, 'train')
    train_loader = DataLoader(vctk_train,
                              batch_size=cfg.batch_size,
                              num_workers=4,
                              shuffle=True,
                              pin_memory=True)
    vctk_val = VCTK(cfg, 'val')
    val_loader = DataLoader(vctk_val,
                            batch_size=cfg.batch_size,
                            num_workers=4,
                            shuffle=False,
                            pin_memory=True)

    # build model
    model = WaveNet(num_classes=28, channels_in=40, dilations=[1, 2, 4, 8, 16])
    model = nn.DataParallel(model)
    model.cuda()
    model.train()

    # build loss
    loss_fn = nn.CTCLoss(blank=27)

    if cfg.resume and os.path.exists(cfg.workdir + '/weights/best.pth'):
        model.load_state_dict(torch.load(cfg.workdir + '/weights/best.pth'),
                              strict=True)
        print("loading", cfg.workdir + '/weights/best.pth')
        cfg.load_from = cfg.workdir + '/weights/best.pth'

    optimizer = optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)
    train(train_loader, optimizer, model, loss_fn, val_loader, writer)
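
check_and_mkdir is not defined in this snippet; a minimal sketch of what it presumably does:

import os

def check_and_mkdir(path):
    # create the directory (including parents) if it does not already exist
    if not os.path.exists(path):
        os.makedirs(path)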
Example #3
    def gpu_decode(feat_list, gpu):
        with torch.cuda.device(gpu):
            # define model and load parameters
            model = WaveNet(n_quantize=config.n_quantize,
                            n_aux=config.n_aux,
                            n_resch=config.n_resch,
                            n_skipch=config.n_skipch,
                            dilation_depth=config.dilation_depth,
                            dilation_repeat=config.dilation_repeat,
                            kernel_size=config.kernel_size,
                            upsampling_factor=config.upsampling_factor)
            model.load_state_dict(
                torch.load(args.checkpoint,
                           map_location=lambda storage, loc: storage.cuda(gpu))
                ["model"])
            model.eval()
            model.cuda()
            torch.backends.cudnn.benchmark = True

            # define generator
            generator = decode_generator(
                feat_list,
                batch_size=args.batch_size,
                wav_transform=wav_transform,
                feat_transform=feat_transform,
                use_speaker_code=config.use_speaker_code,
                upsampling_factor=config.upsampling_factor)

            # decode
            if args.batch_size > 1:
                for feat_ids, (batch_x, batch_h, n_samples_list) in generator:
                    logging.info("decoding start")
                    samples_list = model.batch_fast_generate(
                        batch_x, batch_h, n_samples_list, args.intervals)
                    for feat_id, samples in zip(feat_ids, samples_list):
                        wav = decode_mu_law(samples, config.n_quantize)
                        sf.write(args.outdir + "/" + feat_id + ".wav", wav,
                                 args.fs, "PCM_16")
                        logging.info("wrote %s.wav in %s." %
                                     (feat_id, args.outdir))
            else:
                for feat_id, (x, h, n_samples) in generator:
                    logging.info("decoding %s (length = %d)" %
                                 (feat_id, n_samples))
                    samples = model.fast_generate(x, h, n_samples,
                                                  args.intervals)
                    wav = decode_mu_law(samples, config.n_quantize)
                    sf.write(args.outdir + "/" + feat_id + ".wav", wav,
                             args.fs, "PCM_16")
                    logging.info("wrote %s.wav in %s." %
                                 (feat_id, args.outdir))
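
Examples #1 and #3 both call decode_mu_law, the inverse of the encode_mu_law transform applied at training time (see Example #5). A standard mu-law companding pair looks like the sketch below; the exact implementation in the source repository may differ:

import numpy as np

def encode_mu_law(x, mu=256):
    # compand a waveform in [-1, 1] and quantize to integers in [0, mu - 1]
    mu = mu - 1
    fx = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
    return np.floor((fx + 1) / 2 * mu + 0.5).astype(np.int64)

def decode_mu_law(y, mu=256):
    # expand quantized integers back to waveform values in [-1, 1]
    mu = mu - 1
    fx = (y - 0.5) / mu * 2 - 1
    return np.sign(fx) / mu * ((1 + mu) ** np.abs(fx) - 1)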
Example #4
def main(args):
    print('Starting')
    matplotlib.use('agg')
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    checkpoints = args.checkpoint.parent.glob(args.checkpoint.name + '_*.pth')
    checkpoints = [c for c in checkpoints if extract_id(c) in args.decoders]
    assert len(checkpoints) >= 1, "No checkpoints found."

    model_args = torch.load(args.model.parent / 'args.pth')[0]
    encoder = wavenet_models.Encoder(model_args)
    encoder.load_state_dict(torch.load(checkpoints[0])['encoder_state'])
    encoder.eval()
    encoder = encoder.cuda()

    decoders = []
    decoder_ids = []
    for checkpoint in checkpoints:
        decoder = WaveNet(model_args)
        decoder.load_state_dict(torch.load(checkpoint)['decoder_state'])
        decoder.eval()
        decoder = decoder.cuda()
        if args.py:
            decoder = WavenetGenerator(decoder,
                                       args.batch_size,
                                       wav_freq=args.rate)
        else:
            decoder = NVWavenetGenerator(decoder,
                                         args.rate * (args.split_size // 20),
                                         args.batch_size, 3)

        decoders += [decoder]
        decoder_ids += [extract_id(checkpoint)]

    xs = []
    assert args.output_next_to_orig ^ (args.output is not None)

    if len(args.files) == 1 and args.files[0].is_dir():
        top = args.files[0]
        file_paths = list(top.glob('**/*.wav')) + list(top.glob('**/*.h5'))
    else:
        file_paths = args.files

    if not args.skip_filter:
        file_paths = [f for f in file_paths if '_' not in f.name]

    for file_path in file_paths:
        if file_path.suffix == '.wav':
            data, rate = librosa.load(file_path, sr=16000)
            assert rate == 16000
            data = utils.mu_law(data)
        elif file_path.suffix == '.h5':
            data = utils.mu_law(h5py.File(file_path, 'r')['wav'][:] / (2**15))
            if data.shape[-1] % args.rate != 0:
                data = data[:-(data.shape[-1] % args.rate)]
            assert data.shape[-1] % args.rate == 0
        else:
            raise ValueError(f'Unsupported file type: {file_path}')

        if args.sample_len:
            data = data[:args.sample_len]
        else:
            args.sample_len = len(data)
        xs.append(torch.tensor(data).unsqueeze(0).float().cuda())

    xs = torch.stack(xs).contiguous()
    print(f'xs size: {xs.size()}')

    def save(x, decoder_ix, filepath):
        wav = utils.inv_mu_law(x.cpu().numpy())
        print(f'X size: {x.shape}')
        print(f'X min: {x.min()}, max: {x.max()}')

        if args.output_next_to_orig:
            save_audio(wav.squeeze(),
                       filepath.parent / f'{filepath.stem}_{decoder_ix}.wav',
                       rate=args.rate)
        else:
            save_audio(wav.squeeze(),
                       args.output / str(extract_id(args.model)) /
                       str(args.update) / filepath.with_suffix('.wav').name,
                       rate=args.rate)

    yy = {}
    with torch.no_grad():
        zz = []
        for xs_batch in torch.split(xs, args.batch_size):
            zz += [encoder(xs_batch)]
        zz = torch.cat(zz, dim=0)

        with utils.timeit("Generation timer"):
            for i, decoder_id in enumerate(decoder_ids):
                yy[decoder_id] = []
                decoder = decoders[i]
                for zz_batch in torch.split(zz, args.batch_size):
                    print(zz_batch.shape)
                    splits = torch.split(zz_batch, args.split_size, -1)
                    audio_data = []
                    decoder.reset()
                    for cond in tqdm.tqdm(splits):
                        audio_data += [decoder.generate(cond).cpu()]
                    audio_data = torch.cat(audio_data, -1)
                    yy[decoder_id] += [audio_data]
                yy[decoder_id] = torch.cat(yy[decoder_id], dim=0)
                del decoder

    for decoder_ix, decoder_result in yy.items():
        for sample_result, filepath in zip(decoder_result, file_paths):
            save(sample_result, decoder_ix, filepath)
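
extract_id is not defined in this snippet. Since checkpoints are globbed as name + '_*.pth', a plausible, purely hypothetical implementation parses the trailing number:

from pathlib import Path

def extract_id(path):
    # hypothetical helper: assumes names like "checkpoint_3.pth" -> 3
    return int(Path(path).stem.split('_')[-1])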
Example #5
def main():
    parser = argparse.ArgumentParser()
    # path setting
    parser.add_argument("--waveforms",
                        required=True,
                        type=str,
                        help="directory or list of wav files")
    parser.add_argument("--feats",
                        required=True,
                        type=str,
                        help="directory or list of aux feat files")
    parser.add_argument("--stats",
                        required=True,
                        type=str,
                        help="hdf5 file including statistics")
    parser.add_argument("--expdir",
                        required=True,
                        type=str,
                        help="directory to save the model")
    # network structure setting
    parser.add_argument("--n_quantize",
                        default=256,
                        type=int,
                        help="number of quantization")
    parser.add_argument("--n_aux",
                        default=28,
                        type=int,
                        help="number of dimension of aux feats")
    parser.add_argument("--n_resch",
                        default=512,
                        type=int,
                        help="number of channels of residual output")
    parser.add_argument("--n_skipch",
                        default=256,
                        type=int,
                        help="number of channels of skip output")
    parser.add_argument("--dilation_depth",
                        default=10,
                        type=int,
                        help="depth of dilation")
    parser.add_argument("--dilation_repeat",
                        default=1,
                        type=int,
                        help="number of repeating of dilation")
    parser.add_argument("--kernel_size",
                        default=2,
                        type=int,
                        help="kernel size of dilated causal convolution")
    parser.add_argument("--upsampling_factor",
                        default=0,
                        type=int,
                        help="upsampling factor of aux features"
                        "(if set 0, do not apply)")
    parser.add_argument("--use_speaker_code",
                        default=False,
                        type=strtobool,
                        help="flag to use speaker code")
    # network training setting
    parser.add_argument("--lr", default=1e-4, type=float, help="learning rate")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="weight decay coefficient")
    parser.add_argument(
        "--batch_size",
        default=20000,
        type=int,
        help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--iters",
                        default=200000,
                        type=int,
                        help="number of iterations")
    # other setting
    parser.add_argument("--checkpoints",
                        default=10000,
                        type=int,
                        help="how frequent saving model")
    parser.add_argument("--intervals",
                        default=100,
                        type=int,
                        help="log interval")
    parser.add_argument("--seed", default=1, type=int, help="seed number")
    parser.add_argument("--resume",
                        default=None,
                        type=str,
                        help="model path to restart training")
    parser.add_argument("--verbose", default=1, type=int, help="log level")
    args = parser.parse_args()

    # make experimental directory
    if not os.path.exists(args.expdir):
        os.makedirs(args.expdir)

    # set log level
    if args.verbose == 1:
        logging.basicConfig(
            level=logging.INFO,
            format=
            '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S',
            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    elif args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            format=
            '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S',
            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    else:
        logging.basicConfig(
            level=logging.WARN,
            format=
            '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S',
            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
        logging.warn("logging is disabled.")

    # fix seed
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # save args as conf
    torch.save(args, args.expdir + "/model.conf")

    # define network
    model = WaveNet(n_quantize=args.n_quantize,
                    n_aux=args.n_aux,
                    n_resch=args.n_resch,
                    n_skipch=args.n_skipch,
                    dilation_depth=args.dilation_depth,
                    dilation_repeat=args.dilation_repeat,
                    kernel_size=args.kernel_size,
                    upsampling_factor=args.upsampling_factor)
    logging.info(model)
    model.apply(initialize)
    model.train()

    # define loss and optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss()

    # define transforms
    scaler = StandardScaler()
    scaler.mean_ = read_hdf5(args.stats, "/mean")
    scaler.scale_ = read_hdf5(args.stats, "/scale")
    wav_transform = transforms.Compose(
        [lambda x: encode_mu_law(x, args.n_quantize)])
    feat_transform = transforms.Compose([lambda x: scaler.transform(x)])

    # define generator
    if os.path.isdir(args.waveforms):
        filenames = sorted(
            find_files(args.waveforms, "*.wav", use_dir_name=False))
        wav_list = [args.waveforms + "/" + filename for filename in filenames]
        feat_list = [
            args.feats + "/" + filename.replace(".wav", ".h5")
            for filename in filenames
        ]
    elif os.path.isfile(args.waveforms):
        wav_list = read_txt(args.waveforms)
        feat_list = read_txt(args.feats)
    else:
        logging.error("--waveforms should be directory or list.")
        sys.exit(1)
    assert len(wav_list) == len(feat_list)
    logging.info("number of training data = %d." % len(wav_list))
    generator = train_generator(wav_list,
                                feat_list,
                                receptive_field=model.receptive_field,
                                batch_size=args.batch_size,
                                wav_transform=wav_transform,
                                feat_transform=feat_transform,
                                shuffle=True,
                                upsampling_factor=args.upsampling_factor,
                                use_speaker_code=args.use_speaker_code)
    while not generator.queue.full():
        time.sleep(0.1)

    # resume
    if args.resume is not None:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        iterations = checkpoint["iterations"]
        logging.info("restored from %d-iter checkpoint." % iterations)
    else:
        iterations = 0

    # send to gpu
    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()
    else:
        logging.error("gpu is not available. please check the setting.")
        sys.exit(1)

    # train
    loss = 0
    total = 0
    for i in six.moves.range(iterations, args.iters):
        start = time.time()
        (batch_x, batch_h), batch_t = generator.next()
        batch_output = model(batch_x, batch_h)[0]
        batch_loss = criterion(batch_output[model.receptive_field:],
                               batch_t[model.receptive_field:])
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        loss += batch_loss.item()
        total += time.time() - start
        logging.debug("batch loss = %.3f (%.3f sec / batch)" %
                      (batch_loss.item(), time.time() - start))

        # report progress
        if (i + 1) % args.intervals == 0:
            logging.info(
                "(iter:%d) average loss = %.6f (%.3f sec / batch)" %
                (i + 1, loss / args.intervals, total / args.intervals))
            logging.info(
                "estimated required time = "
                "{0.days:02}:{0.hours:02}:{0.minutes:02}:{0.seconds:02}".
                format(
                    relativedelta(seconds=int((args.iters - (i + 1)) *
                                              (total / args.intervals)))))
            loss = 0
            total = 0

        # save intermediate model
        if (i + 1) % args.checkpoints == 0:
            save_checkpoint(args.expdir, model, optimizer, i + 1)

    # save final model
    model.cpu()
    torch.save({"model": model.state_dict()},
               args.expdir + "/checkpoint-final.pkl")
    logging.info("final checkpoint created.")
Example #6
def main():
    args = parse_args()
    cfg.resume      = args.resume
    cfg.exp_name    = args.exp
    cfg.work_root   = '/zhzhao/code/wavenet_torch/torch_lyuan/exp_result/'
    cfg.workdir     = cfg.work_root + args.exp + '/debug'
    cfg.sparse_mode = args.sparse_mode
    cfg.batch_size  = args.batch_size
    cfg.lr          = args.lr
    cfg.load_from   = args.load_from
    cfg.save_excel  = args.save_excel
    
    if args.find_pattern:
        cfg.find_pattern_num = 16
        shape = args.find_pattern_shape.split('_')
        cfg.find_pattern_shape = [int(shape[0]), int(shape[1])]
        para = args.find_pattern_para.split('_')
        cfg.find_zero_threshold = float(para[0])
        cfg.find_score_threshold = int(para[1])
        if cfg.find_pattern_shape[0] * cfg.find_pattern_shape[1] <= cfg.find_score_threshold:
            exit()

    if args.skip_exist and os.path.exists(cfg.workdir):
        exit()

    print('initial training...')
    print(f'work_dir:{cfg.workdir}, \n\
            pretrained: {cfg.load_from},  \n\
            batch_size: {cfg.batch_size}, \n\
            lr        : {cfg.lr},         \n\
            epochs    : {cfg.epochs},     \n\
            sparse    : {cfg.sparse_mode}')
    writer = SummaryWriter(log_dir=cfg.workdir+'/runs')

    # build train data
    vctk_train = VCTK(cfg, 'train')
    train_loader = DataLoader(vctk_train, batch_size=cfg.batch_size, num_workers=4, shuffle=True, pin_memory=True)

    # train_loader = dataset.create("data/v28/train.record", cfg.batch_size, repeat=True)
    vctk_val = VCTK(cfg, 'val')
    if args.test_acc_cmodel:
        val_loader = DataLoader(vctk_val, batch_size=1, num_workers=4, shuffle=False, pin_memory=True)
    else:
        val_loader = DataLoader(vctk_val, batch_size=cfg.batch_size, num_workers=4, shuffle=False, pin_memory=True)

    # build model
    model = WaveNet(num_classes=28, channels_in=40, dilations=[1, 2, 4, 8, 16])
    model = nn.DataParallel(model)
    model.cuda()


    name_list = list()
    para_list = list()
    for name, para in model.named_parameters():
        name_list.append(name)
        para_list.append(para)

    a = model.state_dict()
    for i, name in enumerate(name_list):
        if name.split(".")[-2] != "bn" \
            and name.split(".")[-2] != "bn2" \
            and name.split(".")[-2] != "bn3" \
            and name.split(".")[-1] != "bias":
            raw_w = para_list[i]
            nn.init.xavier_normal_(raw_w, gain=1.0)
            a[name] = raw_w
    model.load_state_dict(a)
    

    weights_dir = os.path.join(cfg.workdir, 'weights')
    if not os.path.exists(weights_dir):
        os.mkdir(weights_dir)
    if not os.path.exists(cfg.vis_dir):
        os.mkdir(cfg.vis_dir)
    if args.vis_pattern or args.vis_mask:
        cfg.vis_dir = os.path.join(cfg.vis_dir, cfg.exp_name)
        if not os.path.exists(cfg.vis_dir):
            os.mkdir(cfg.vis_dir)
    model.train()

    if cfg.resume and os.path.exists(cfg.workdir + '/weights/best.pth'):
        model.load_state_dict(torch.load(cfg.workdir + '/weights/best.pth'), strict=True)
        print("loading", cfg.workdir + '/weights/best.pth')
        cfg.load_from = cfg.workdir + '/weights/best.pth'

    if args.test_acc:
        if os.path.exists(cfg.load_from):
            model.load_state_dict(torch.load(cfg.load_from), strict=True)
            print("loading", cfg.load_from)
        else:
            print("Error: model file not exists, ", cfg.load_from)
            exit()
    else:
        if os.path.exists(cfg.load_from):
            model.load_state_dict(torch.load(cfg.load_from), strict=True)
            print("loading", cfg.load_from)
            # Export the model
            print("exporting onnx ...")
            model.eval()
            batch_size = 1
            x = torch.randn(batch_size, 40, 720, requires_grad=True).cuda()
            torch.onnx.export(model.module,               # model being run
                            x,                         # model input (or a tuple for multiple inputs)
                            "wavenet.onnx",   # where to save the model (can be a file or file-like object)
                            export_params=True,        # store the trained parameter weights inside the model file
                            opset_version=10,          # the ONNX version to export the model to
                            do_constant_folding=True,  # whether to execute constant folding for optimization
                            input_names = ['input'],   # the model's input names
                            output_names = ['output'], # the model's output names
                            dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                                          'output': {0: 'batch_size'}})

    if os.path.exists(args.load_from_h5):
        # model.load_state_dict(torch.load(args.load_from_h5), strict=True)
        print("loading", args.load_from_h5)
        model.train()
        model_dict = model.state_dict()
        print(model_dict.keys())
        # first convert the numpy parameter values to tensors
        pretrained_dict = dd.io.load(args.load_from_h5)
        print(pretrained_dict.keys())
        new_pre_dict = {}
        for k, v in pretrained_dict.items():
            new_pre_dict[k] = torch.Tensor(v)
        # update the state dict
        model_dict.update(new_pre_dict)
        # load it back into the model
        model.load_state_dict(model_dict)

    if args.find_pattern:

        name_list = list()
        para_list = list()
        for name, para in model.named_parameters():
            name_list.append(name)
            para_list.append(para)

        a = model.state_dict()
        for i, name in enumerate(name_list):
            if name.split(".")[-2] != "bn" \
                and name.split(".")[-2] != "bn2" \
                and name.split(".")[-2] != "bn3" \
                and name.split(".")[-1] != "bias":
                raw_w = para_list[i]
                if raw_w.size(0) == 128 and raw_w.size(1) == 128:
                    patterns, pattern_match_num, pattern_coo_nnz, pattern_nnz, pattern_inner_nnz \
                                    = find_pattern_by_similarity(raw_w
                                        , cfg.find_pattern_num
                                        , cfg.find_pattern_shape
                                        , cfg.find_zero_threshold
                                        , cfg.find_score_threshold)

                    pattern_num_memory_dict, pattern_num_cal_num_dict, pattern_num_coo_nnz_dict \
                                    = pattern_curve_analyse(raw_w.shape
                                        , cfg.find_pattern_shape
                                        , patterns
                                        , pattern_match_num
                                        , pattern_coo_nnz
                                        , pattern_nnz
                                        , pattern_inner_nnz)
                                        
                    write_pattern_curve_analyse(os.path.join(cfg.work_root, args.save_pattern_count_excel)
                                        , cfg.exp_name + " " + args.find_pattern_shape + " " + args.find_pattern_para
                                        , patterns, pattern_match_num, pattern_coo_nnz, pattern_nnz
                                        , pattern_inner_nnz
                                        , pattern_num_memory_dict, pattern_num_cal_num_dict, pattern_num_coo_nnz_dict)

                    exit()



    if cfg.sparse_mode == 'sparse_pruning':
        cfg.sparsity = args.sparsity
        print(f'sparse_pruning {cfg.sparsity}')

    elif cfg.sparse_mode == 'pattern_pruning':
        print(args.pattern_para)
        para = args.pattern_para.split('_')
        pattern_num = int(para[0])
        pattern_shape = [int(para[1]), int(para[2])]
        pattern_nnz = int(para[3])
        print(f'pattern_pruning {pattern_num} [{pattern_shape[0]}, {pattern_shape[1]}] {pattern_nnz}')
        cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)

    elif cfg.sparse_mode == 'coo_pruning':
        para = args.coo_para.split('_')
        cfg.coo_shape = [int(para[0]), int(para[1])]
        cfg.coo_nnz = int(para[2])
        print(f'coo_pruning [{cfg.coo_shape[0]}, {cfg.coo_shape[1]}] {cfg.coo_nnz}')
        
    elif cfg.sparse_mode == 'ptcoo_pruning':
        para = args.ptcoo_para.split('_')
        cfg.pattern_num = int(para[0])
        cfg.pattern_shape = [int(para[1]), int(para[2])]
        cfg.pt_nnz = int(para[3])
        cfg.coo_nnz = int(para[4])
        cfg.patterns = generate_pattern(cfg.pattern_num, cfg.pattern_shape, cfg.pt_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)
        print(f'ptcoo_pruning {cfg.pattern_num} [{cfg.pattern_shape[0]}, {cfg.pattern_shape[1]}] {cfg.pt_nnz} {cfg.coo_nnz}')
        
    elif cfg.sparse_mode == 'find_retrain':
        para = args.find_retrain_para.split('_')
        cfg.pattern_num = int(para[0])
        cfg.pattern_shape = [int(para[1]), int(para[2])]
        cfg.pattern_nnz = int(para[3])
        cfg.coo_num = float(para[4])
        cfg.layer_or_model_wise = para[5]
        # cfg.fd_rtn_pattern_candidates = generate_complete_pattern_set(
        #                                 cfg.pattern_shape, cfg.pattern_nnz)
        print(f'find_retrain {cfg.pattern_num} [{cfg.pattern_shape[0]}, {cfg.pattern_shape[1]}] {cfg.pattern_nnz} {cfg.coo_num} {cfg.layer_or_model_wise}')

    elif cfg.sparse_mode == 'hcgs_pruning':
        print(args.hcgs_para)
        para = args.hcgs_para.split('_')
        cfg.block_shape = [int(para[0]), int(para[1])]
        cfg.reserve_num1 = int(para[2])
        cfg.reserve_num2 = int(para[3])
        print(f'hcgs_pruning {cfg.reserve_num1}/8 {cfg.reserve_num2}/16')
        cfg.hcgs_mask = generate_hcgs_mask(model, cfg.block_shape, cfg.reserve_num1, cfg.reserve_num2)

    if args.vis_mask:
        name_list = list()
        para_list = list()
        for name, para in model.named_parameters():
            name_list.append(name)
            para_list.append(para)

        for i, name in enumerate(name_list):
            if name.split(".")[-2] != "bn" \
                and name.split(".")[-2] != "bn2" \
                and name.split(".")[-2] != "bn3" \
                and name.split(".")[-1] != "bias":
                raw_w = para_list[i]

                zero = torch.zeros_like(raw_w)
                one = torch.ones_like(raw_w)

                mask = torch.where(raw_w == 0, zero, one)
                vis.save_visualized_mask(mask, name)
        exit()

    if args.vis_pattern:
        pattern_count_dict = find_pattern_model(model, [8, 8])
        patterns = list(pattern_count_dict.keys())
        counts = list(pattern_count_dict.values())
        print(len(patterns))
        print(counts)
        vis.save_visualized_pattern(patterns)
        exit()
    # build loss
    loss_fn = nn.CTCLoss(blank=27)
    # loss_fn = nn.CTCLoss()

    # build optimizer
    optimizer = optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)
    if args.test_acc:
        f1, val_loss, tps, preds, poses = test_acc(val_loader, model, loss_fn)
        write_test_acc(os.path.join(cfg.work_root, args.test_acc_excel), 
                        cfg.exp_name, f1, val_loss, tps, preds, poses)
        exit()

    if args.test_acc_cmodel:
        f1, val_loss, tps, preds, poses = test_acc_cmodel(val_loader, model, loss_fn)
        write_test_acc(os.path.join(cfg.work_root, args.test_acc_excel), 
                        cfg.exp_name, f1, val_loss, tps, preds, poses)
        exit()
    # train
    train(train_loader, optimizer, model, loss_fn, val_loader, writer)
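
The underscore-delimited parameter strings in Example #6 are parsed field by field; a small helper would keep that logic in one place. A sketch (the helper name is ours, not the repository's):

def split_ints(s):
    # "16_8_8_4" -> [16, 8, 8, 4]
    return [int(t) for t in s.split('_')]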
Example #7
def main():
    args = parse_args()
    cfg.resume      = args.resume
    cfg.exp_name    = args.exp
    cfg.work_root   = '/zhzhao/code/wavenet_torch/torch_lyuan/exp_result/'
    cfg.workdir     = cfg.work_root + args.exp + '/debug'
    cfg.sparse_mode = args.sparse_mode
    cfg.batch_size  = args.batch_size
    cfg.lr          = args.lr
    cfg.load_from   = args.load_from
    cfg.save_excel  = args.save_excel

    if args.skip_exist and os.path.exists(cfg.workdir):
        exit()

    print('initial training...')
    print(f'work_dir:{cfg.workdir}, \n\
            pretrained: {cfg.load_from},  \n\
            batch_size: {cfg.batch_size}, \n\
            lr        : {cfg.lr},         \n\
            epochs    : {cfg.epochs},     \n\
            sparse    : {cfg.sparse_mode}')
    writer = SummaryWriter(log_dir=cfg.workdir+'/runs')

    # build train data
    vctk_train = VCTK(cfg, 'train')
    train_loader = DataLoader(vctk_train, batch_size=cfg.batch_size, num_workers=8, shuffle=True, pin_memory=True)

    # train_loader = dataset.create("data/v28/train.record", cfg.batch_size, repeat=True)
    vctk_val = VCTK(cfg, 'val')
    val_loader = DataLoader(vctk_val, batch_size=cfg.batch_size, num_workers=8, shuffle=False, pin_memory=True)

    # build model
    model = WaveNet(num_classes=28, channels_in=20, dilations=[1, 2, 4, 8, 16])
    model = nn.DataParallel(model)
    model.cuda()


    name_list = list()
    para_list = list()
    for name, para in model.named_parameters():
        name_list.append(name)
        para_list.append(para)

    a = model.state_dict()
    for i, name in enumerate(name_list):
        if name.split(".")[-2] != "bn" and name.split(".")[-1] != "bias":
            raw_w = para_list[i]
            nn.init.xavier_normal_(raw_w, gain=1.0)
            a[name] = raw_w
    model.load_state_dict(a)
    

    weights_dir = os.path.join(cfg.workdir, 'weights')
    if not os.path.exists(weights_dir):
        os.mkdir(weights_dir)
    if not os.path.exists(cfg.vis_dir):
        os.mkdir(cfg.vis_dir)
    if args.vis_pattern or args.vis_mask:
        cfg.vis_dir = os.path.join(cfg.vis_dir, cfg.exp_name)
        if not os.path.exists(cfg.vis_dir):
            os.mkdir(cfg.vis_dir)
    model.train()

    if cfg.resume and os.path.exists(cfg.workdir + '/weights/best.pth'):
        model.load_state_dict(torch.load(cfg.workdir + '/weights/best.pth'), strict=False)
        print("loading", cfg.workdir + '/weights/best.pth')

    if os.path.exists(cfg.load_from):
        model.load_state_dict(torch.load(cfg.load_from), strict=False)
        print("loading", cfg.load_from)

    if os.path.exists(args.load_from_h5):
        # model.load_state_dict(torch.load(args.load_from_h5), strict=True)
        print("loading", args.load_from_h5)
        model.train()
        model_dict = model.state_dict()
        print(model_dict.keys())
        # first convert the numpy parameter values to tensors
        pretrained_dict = dd.io.load(args.load_from_h5)
        print(pretrained_dict.keys())
        new_pre_dict = {}
        for k, v in pretrained_dict.items():
            new_pre_dict[k] = torch.Tensor(v)
        # update the state dict
        model_dict.update(new_pre_dict)
        # load it back into the model
        model.load_state_dict(model_dict)

    if args.find_pattern:
        cfg.find_pattern_num = 16
        shape = args.find_pattern_shape.split('_')
        cfg.find_pattern_shape = [int(shape[0]), int(shape[1])]
        para = args.find_pattern_para.split('_')
        cfg.find_zero_threshold = float(para[0])
        cfg.find_score_threshold = int(para[1])

        name_list = list()
        para_list = list()
        for name, para in model.named_parameters():
            name_list.append(name)
            para_list.append(para)

        a = model.state_dict()
        for i, name in enumerate(name_list):
            if name.split(".")[-2] != "bn" and name.split(".")[-1] != "bias":
                raw_w = para_list[i]
                if raw_w.size(0) == 128 and raw_w.size(1) == 128:
                    patterns, pattern_match_num, pattern_coo_nnz, pattern_nnz, pattern_inner_nnz \
                                    = find_pattern_by_similarity(raw_w
                                        , cfg.find_pattern_num
                                        , cfg.find_pattern_shape
                                        , cfg.find_zero_threshold
                                        , cfg.find_score_threshold)

                    pattern_num_memory_dict, pattern_num_coo_nnz_dict \
                                    = pattern_curve_analyse(raw_w.shape
                                        , cfg.find_pattern_shape
                                        , patterns
                                        , pattern_match_num
                                        , pattern_coo_nnz
                                        , pattern_nnz
                                        , pattern_inner_nnz)
                                        
                    write_pattern_curve_analyse(os.path.join(cfg.work_root, args.save_pattern_count_excel)
                                        , cfg.exp_name + " " + args.find_pattern_shape + " " + args.find_pattern_para
                                        , patterns, pattern_match_num, pattern_coo_nnz, pattern_nnz
                                        , pattern_num_memory_dict, pattern_num_coo_nnz_dict)

                    exit()



    if cfg.sparse_mode == 'sparse_pruning':
        cfg.sparsity = args.sparsity
        print(f'sparse_pruning {cfg.sparsity}')

    elif cfg.sparse_mode == 'pattern_pruning':
        print(args.pattern_para)
        para = args.pattern_para.split('_')
        pattern_num = int(para[0])
        pattern_shape = [int(para[1]), int(para[2])]
        pattern_nnz = int(para[3])
        print(f'pattern_pruning {pattern_num} [{pattern_shape[0]}, {pattern_shape[1]}] {pattern_nnz}')
        cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)

    elif cfg.sparse_mode == 'coo_pruning':
        para = args.coo_para.split('_')
        cfg.coo_shape = [int(para[0]), int(para[1])]
        cfg.coo_nnz = int(para[2])
        print(f'coo_pruning [{cfg.coo_shape[0]}, {cfg.coo_shape[1]}] {cfg.coo_nnz}')
        
    elif cfg.sparse_mode == 'ptcoo_pruning':
        para = args.ptcoo_para.split('_')
        cfg.pattern_num = int(para[0])
        cfg.pattern_shape = [int(para[1]), int(para[2])]
        cfg.pt_nnz = int(para[3])
        cfg.coo_nnz = int(para[4])
        cfg.patterns = generate_pattern(cfg.pattern_num, cfg.pattern_shape, cfg.pt_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)
        print(f'ptcoo_pruning {cfg.pattern_num} [{cfg.pattern_shape[0]}, {cfg.pattern_shape[1]}] {cfg.pt_nnz} {cfg.coo_nnz}')


    if args.vis_mask:
        name_list = list()
        para_list = list()
        for name, para in model.named_parameters():
            name_list.append(name)
            para_list.append(para)

        for i, name in enumerate(name_list):
            if name.split(".")[-2] != "bn" and name.split(".")[-1] != "bias":
                raw_w = para_list[i]

                zero = torch.zeros_like(raw_w)
                one = torch.ones_like(raw_w)

                mask = torch.where(raw_w == 0, zero, one)
                vis.save_visualized_mask(mask, name)
        exit()

    if args.vis_pattern:
        pattern_count_dict = find_pattern_model(model, [8, 8])
        patterns = list(pattern_count_dict.keys())
        counts = list(pattern_count_dict.values())
        print(len(patterns))
        print(counts)
        vis.save_visualized_pattern(patterns)
        exit()
    # build loss
    loss_fn = nn.CTCLoss(blank=27)
    # loss_fn = nn.CTCLoss()

    # build optimizer
    optimizer = optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)
    if args.test_acc:
        f1, val_loss, tps, preds, poses = test_acc(val_loader, model, loss_fn)
        write_test_acc(os.path.join(cfg.work_root, args.test_acc_excel), 
                        cfg.exp_name, f1, val_loss, tps, preds, poses)
        exit()
    # train
    train(train_loader, optimizer, model, loss_fn, val_loader, writer)
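
The pruning branches in Examples #6 and #7 only build masks (cfg.pattern_mask, cfg.hcgs_mask); how they are enforced during training is not shown. One common approach, offered as an assumption rather than this repository's confirmed method, is to re-apply the masks after every optimizer step:

import torch

def apply_masks(model, masks):
    # masks maps parameter name -> 0/1 tensor of the same shape as the weight
    with torch.no_grad():
        for name, para in model.named_parameters():
            if name in masks:
                para.mul_(masks[name])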
Example #8
def main():
    parser = argparse.ArgumentParser()
    # path setting
    parser.add_argument("--waveforms",
                        required=True,
                        type=str,
                        help="directory or list of wav files")
    parser.add_argument("--feats",
                        required=True,
                        type=str,
                        help="directory or list of aux feat files")
    parser.add_argument("--stats",
                        required=True,
                        type=str,
                        help="hdf5 file including statistics")
    parser.add_argument("--expdir",
                        required=True,
                        type=str,
                        help="directory to save the model")
    parser.add_argument("--feature_type",
                        default="world",
                        choices=["world", "melspc"],
                        type=str,
                        help="feature type")
    # network structure setting
    parser.add_argument("--n_quantize",
                        default=256,
                        type=int,
                        help="number of quantization")
    parser.add_argument("--n_aux",
                        default=28,
                        type=int,
                        help="number of dimension of aux feats")
    parser.add_argument("--n_resch",
                        default=512,
                        type=int,
                        help="number of channels of residual output")
    parser.add_argument("--n_skipch",
                        default=256,
                        type=int,
                        help="number of channels of skip output")
    parser.add_argument("--dilation_depth",
                        default=10,
                        type=int,
                        help="depth of dilation")
    parser.add_argument("--dilation_repeat",
                        default=1,
                        type=int,
                        help="number of repeating of dilation")
    parser.add_argument("--kernel_size",
                        default=2,
                        type=int,
                        help="kernel size of dilated causal convolution")
    parser.add_argument("--upsampling_factor",
                        default=80,
                        type=int,
                        help="upsampling factor of aux features")
    parser.add_argument("--use_upsampling_layer",
                        default=True,
                        type=strtobool,
                        help="flag to use upsampling layer")
    parser.add_argument("--use_speaker_code",
                        default=False,
                        type=strtobool,
                        help="flag to use speaker code")
    # network training setting
    parser.add_argument("--lr", default=1e-4, type=float, help="learning rate")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="weight decay coefficient")
    parser.add_argument(
        "--batch_length",
        default=20000,
        type=int,
        help="batch length (if set 0, utterance batch will be used)")
    parser.add_argument(
        "--batch_size",
        default=1,
        type=int,
        help="batch size (if use utterance batch, batch_size will be 1.")
    parser.add_argument("--iters",
                        default=200000,
                        type=int,
                        help="number of iterations")
    # other setting
    parser.add_argument("--checkpoints",
                        default=10000,
                        type=int,
                        help="how frequent saving model")
    parser.add_argument("--intervals",
                        default=100,
                        type=int,
                        help="log interval")
    parser.add_argument("--seed", default=1, type=int, help="seed number")
    parser.add_argument("--resume",
                        default=None,
                        nargs="?",
                        type=str,
                        help="model path to restart training")
    parser.add_argument("--n_gpus", default=1, type=int, help="number of gpus")
    parser.add_argument("--verbose", default=1, type=int, help="log level")
    args = parser.parse_args()

    # set log level
    if args.verbose == 1:
        logging.basicConfig(
            level=logging.INFO,
            format=
            '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S')
    elif args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            format=
            '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S')
    else:
        logging.basicConfig(
            level=logging.WARNING,
            format=
            '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S')
        logging.warning("logging is disabled.")

    # show arguments
    for key, value in vars(args).items():
        logging.info("%s = %s" % (key, str(value)))

    # make experimental directory
    if not os.path.exists(args.expdir):
        os.makedirs(args.expdir)

    # fix seed
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # fix slow computation of dilated conv
    # https://github.com/pytorch/pytorch/issues/15054#issuecomment-450191923
    torch.backends.cudnn.benchmark = True

    # save args as conf
    torch.save(args, args.expdir + "/model.conf")

    # define network
    if args.use_upsampling_layer:
        upsampling_factor = args.upsampling_factor
    else:
        upsampling_factor = 0
    model = WaveNet(n_quantize=args.n_quantize,
                    n_aux=args.n_aux,
                    n_resch=args.n_resch,
                    n_skipch=args.n_skipch,
                    dilation_depth=args.dilation_depth,
                    dilation_repeat=args.dilation_repeat,
                    kernel_size=args.kernel_size,
                    upsampling_factor=upsampling_factor)
    logging.info(model)
    model.apply(initialize)
    model.train()

    if args.n_gpus > 1:
        device_ids = range(args.n_gpus)
        model = torch.nn.DataParallel(model, device_ids)
        model.receptive_field = model.module.receptive_field
        if args.n_gpus > args.batch_size:
            logging.warning("batch size is less than number of gpus.")

    # define optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss()

    # define transforms
    scaler = StandardScaler()
    scaler.mean_ = read_hdf5(args.stats, "/" + args.feature_type + "/mean")
    scaler.scale_ = read_hdf5(args.stats, "/" + args.feature_type + "/scale")
    wav_transform = transforms.Compose(
        [lambda x: encode_mu_law(x, args.n_quantize)])
    feat_transform = transforms.Compose([lambda x: scaler.transform(x)])

    # define generator
    if os.path.isdir(args.waveforms):
        filenames = sorted(
            find_files(args.waveforms, "*.wav", use_dir_name=False))
        wav_list = [args.waveforms + "/" + filename for filename in filenames]
        feat_list = [
            args.feats + "/" + filename.replace(".wav", ".h5")
            for filename in filenames
        ]
    elif os.path.isfile(args.waveforms):
        wav_list = read_txt(args.waveforms)
        feat_list = read_txt(args.feats)
    else:
        logging.error("--waveforms should be directory or list.")
        sys.exit(1)
    assert len(wav_list) == len(feat_list)
    logging.info("number of training data = %d." % len(wav_list))
    generator = train_generator(wav_list,
                                feat_list,
                                receptive_field=model.receptive_field,
                                batch_length=args.batch_length,
                                batch_size=args.batch_size,
                                feature_type=args.feature_type,
                                wav_transform=wav_transform,
                                feat_transform=feat_transform,
                                shuffle=True,
                                upsampling_factor=args.upsampling_factor,
                                use_upsampling_layer=args.use_upsampling_layer,
                                use_speaker_code=args.use_speaker_code)

    # charge minibatch in queue
    while not generator.queue.full():
        time.sleep(0.1)

    # resume model and optimizer
    if args.resume is not None and len(args.resume) != 0:
        checkpoint = torch.load(args.resume,
                                map_location=lambda storage, loc: storage)
        iterations = checkpoint["iterations"]
        if args.n_gpus > 1:
            model.module.load_state_dict(checkpoint["model"])
        else:
            model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        logging.info("restored from %d-iter checkpoint." % iterations)
    else:
        iterations = 0

    # check gpu and then send to gpu
    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()
        for state in optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.cuda()
    else:
        logging.error("gpu is not available. please check the setting.")
        sys.exit(1)

    # train
    loss = 0
    total = 0
    for i in six.moves.range(iterations, args.iters):
        start = time.time()
        (batch_x, batch_h), batch_t = generator.next()
        batch_output = model(batch_x, batch_h)
        batch_loss = criterion(
            batch_output[:, model.receptive_field:].contiguous().view(
                -1, args.n_quantize),
            batch_t[:, model.receptive_field:].contiguous().view(-1))
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        loss += batch_loss.item()
        total += time.time() - start
        logging.debug("batch loss = %.3f (%.3f sec / batch)" %
                      (batch_loss.item(), time.time() - start))

        # report progress
        if (i + 1) % args.intervals == 0:
            logging.info(
                "(iter:%d) average loss = %.6f (%.3f sec / batch)" %
                (i + 1, loss / args.intervals, total / args.intervals))
            logging.info(
                "estimated required time = "
                "{0.days:02}:{0.hours:02}:{0.minutes:02}:{0.seconds:02}".
                format(
                    relativedelta(seconds=int((args.iters - (i + 1)) *
                                              (total / args.intervals)))))
            loss = 0
            total = 0

        # save intermediate model
        if (i + 1) % args.checkpoints == 0:
            if args.n_gpus > 1:
                save_checkpoint(args.expdir, model.module, optimizer, i + 1)
            else:
                save_checkpoint(args.expdir, model, optimizer, i + 1)

    # save final model
    if args.n_gpus > 1:
        torch.save({"model": model.module.state_dict()},
                   args.expdir + "/checkpoint-final.pkl")
    else:
        torch.save({"model": model.state_dict()},
                   args.expdir + "/checkpoint-final.pkl")
    logging.info("final checkpoint created.")
Example #9
def main():
    args = parse_args()
    cfg.resume      = args.resume
    cfg.exp_name    = args.exp
    cfg.workdir     = '/zhzhao/code/wavenet_torch/torch_lyuan/exp_result/' + args.exp + '/debug'
    cfg.sparse_mode = args.sparse_mode
    cfg.batch_size  = args.batch_size
    cfg.lr          = args.lr
    cfg.load_from   = args.load_from

    print('initial training...')
    print(f'work_dir:{cfg.workdir}, \n\
            pretrained: {cfg.load_from},  \n\
            batch_size: {cfg.batch_size}, \n\
            lr        : {cfg.lr},         \n\
            epochs    : {cfg.epochs},     \n\
            sparse    : {cfg.sparse_mode}')
    writer = SummaryWriter(log_dir=cfg.workdir+'/runs')

    # build train data
    vctk_train = VCTK(cfg, 'train')
    train_loader = DataLoader(vctk_train, batch_size=cfg.batch_size, num_workers=8, shuffle=True, pin_memory=True)

    vctk_val = VCTK(cfg, 'val')
    val_loader = DataLoader(vctk_val, batch_size=cfg.batch_size, num_workers=8, shuffle=False, pin_memory=True)

    # build model
    model = WaveNet(num_classes=28, channels_in=20, dilations=[1, 2, 4, 8, 16])
    model = nn.DataParallel(model)
    model.cuda()

    weights_dir = os.path.join(cfg.workdir, 'weights')
    if not os.path.exists(weights_dir):
        os.mkdir(weights_dir)
    if not os.path.exists(cfg.vis_dir):
        os.mkdir(cfg.vis_dir)
    cfg.vis_dir = os.path.join(cfg.vis_dir, cfg.exp_name)
    if not os.path.exists(cfg.vis_dir):
        os.mkdir(cfg.vis_dir)
    model.train()

    if cfg.resume and os.path.exists(cfg.workdir + '/weights/best.pth'):
        model.load_state_dict(torch.load(cfg.workdir + '/weights/best.pth'))
        print("loading", cfg.workdir + '/weights/best.pth')

    if os.path.exists(cfg.load_from):
        model.load_state_dict(torch.load(cfg.load_from))
        print("loading", cfg.load_from)


    if cfg.sparse_mode == 'sparse_pruning':
        cfg.sparsity = args.sparsity
        print(f'sparse_pruning {cfg.sparsity}')
    elif cfg.sparse_mode == 'pattern_pruning':
        print(args.pattern_para)
        para = args.pattern_para.split('_')
        pattern_num = int(para[0])
        pattern_shape = [int(para[1]), int(para[2])]
        pattern_nnz = int(para[3])
        print(f'pattern_pruning {pattern_num} [{pattern_shape[0]}, {pattern_shape[1]}] {pattern_nnz}')
        cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)
    elif cfg.sparse_mode == 'coo_pruning':
        para = args.coo_para.split('_')
        cfg.coo_shape = [int(para[0]), int(para[1])]
        cfg.coo_nnz = int(para[2])
        print(f'coo_pruning [{cfg.coo_shape[0]}, {cfg.coo_shape[1]}] {cfg.coo_nnz}')
    elif cfg.sparse_mode == 'ptcoo_pruning':
        para = args.ptcoo_para.split('_')
        cfg.pattern_num = int(para[0])
        cfg.pattern_shape = [int(para[1]), int(para[2])]
        cfg.pt_nnz = int(para[3])
        cfg.coo_nnz = int(para[4])
        cfg.patterns = generate_pattern(cfg.pattern_num, cfg.pattern_shape, cfg.pt_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)
        print(f'ptcoo_pruning {cfg.pattern_num} [{cfg.pattern_shape[0]}, {cfg.pattern_shape[1]}] {cfg.pt_nnz} {cfg.coo_nnz}')


    if args.vis_mask:
        name_list = list()
        para_list = list()
        for name, para in model.named_parameters():
            name_list.append(name)
            para_list.append(para)

        for i, name in enumerate(name_list):
            if name.split(".")[-2] != "bn" and name.split(".")[-1] != "bias":
                raw_w = para_list[i]

                zero = torch.zeros_like(raw_w)
                one = torch.ones_like(raw_w)

                mask = torch.where(raw_w == 0, zero, one)
                vis.save_visualized_mask(mask, name)
        exit()

    if args.vis_pattern:
        pattern_count_dict = find_pattern_model(model, [16, 16])
        patterns = list(pattern_count_dict.keys())
        vis.save_visualized_pattern(patterns)
        exit()
    # build loss
    loss_fn = nn.CTCLoss(blank=0, reduction='none')

    # build optimizer
    optimizer = optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)

    # train
    train(train_loader, optimizer, model, loss_fn, val_loader, writer)
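
Examples #2, #6, #7, and #9 all train against nn.CTCLoss. For reference, a self-contained sketch of the input shapes PyTorch expects (independent of the VCTK loaders above):

import torch
import torch.nn as nn

ctc = nn.CTCLoss(blank=27)
log_probs = torch.randn(100, 4, 28).log_softmax(2)          # (T, N, C)
targets = torch.randint(0, 27, (4, 20), dtype=torch.long)   # no blank labels
input_lengths = torch.full((4,), 100, dtype=torch.long)
target_lengths = torch.full((4,), 20, dtype=torch.long)
loss = ctc(log_probs, targets, input_lengths, target_lengths)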