class TensorboardLogger(Logger):
    def __init__(self,
                 log_interval=50,
                 validation_interval=200,
                 generate_interval=500,
                 trainer=None,
                 generate_function=None,
                 log_dir='logs'):
        super().__init__(log_interval, validation_interval, generate_interval,
                         trainer, generate_function)
        self.writer = SummaryWriter(log_dir)

    def log_loss(self, current_step):
        # loss
        avg_loss = self.accumulated_loss / self.log_interval
        self.writer.add_scalar('loss', avg_loss, current_step)

    def validate(self, current_step):
        avg_loss, avg_accuracy = self.trainer.validate()
        self.writer.add_scalar('validation/loss', avg_loss, current_step)
        self.writer.add_scalar('validation/accuracy', avg_accuracy,
                               current_step)

    def log_audio(self, step):
        samples = self.generate_function()
        self.writer.add_audio('audio sample', samples, step, sample_rate=16000)

    def image_summary(self, tag, images, step):
        """Log a list of images."""
        self.writer.add_images(tag, images, step)

    def audio_summary(self, tag, sample, step, sr=16000):
        self.writer.add_audio(tag, sample, step, sample_rate=sr)
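A minimal usage sketch (assuming the Logger base class wires up the intervals and maintains accumulated_loss; fake_generate is a placeholder):

import torch

def fake_generate():
    # placeholder: one second of silent mono audio at 16 kHz
    return torch.zeros(1, 16000)

logger = TensorboardLogger(log_interval=50,
                           generate_function=fake_generate,
                           log_dir='logs/run1')
logger.accumulated_loss = 12.5  # normally accumulated by the Logger base class
logger.log_loss(current_step=100)
logger.log_audio(step=100)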
Example #2
class TensorboardHandler(object):
    def __init__(self, path="runs"):
        shutil.rmtree(path, ignore_errors=True)  # clear any previous run (requires `import shutil`)
        self.writer = SummaryWriter(path)
        self._counter = 0
        self.models = []
        self.losses = []

    def update(self, models, losses, epoch=None):
        if epoch is None:
            epoch = self._counter
        apply_method(self, "add_model", models, epoch=epoch)
        apply_method(self, "add_loss", losses, epoch=epoch)
        self._counter += 1

    def add_model(self, model, epoch):
        # self.writer.add_graph(model, verbose=True)
        for n, p in model.named_parameters():
            self.writer.add_histogram('model/' + n, p.detach().cpu().numpy(), epoch)

    def add_loss(self, losses, epoch):
        for loss in losses:
            loss_names = list(loss.loss_history[list(loss.loss_history.keys())[0]].keys())
            for l in loss_names:
                # write the most recent value of each tracked loss
                last_loss = {k: np.array(v[l]['values'][-1]) if len(v) > 0 else 0.
                             for k, v in loss.loss_history.items()}
                self.writer.add_scalars('loss_%s/' % l, last_loss, epoch)

    def add_image(self, name, pic, epoch):
        self.writer.add_image(name, pic, epoch)

    def add_audio(self, name, audio, sr, epoch=None):
        self.writer.add_audio(name, audio, epoch, sample_rate=sr)
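A sketch of direct use, assuming only that models expose named_parameters() (apply_method in update is an external helper that fans a method out over a list):

import torch.nn as nn

handler = TensorboardHandler(path='runs/debug')
model = nn.Linear(4, 2)
handler.add_model(model, epoch=0)  # one histogram per named parameter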
Example #3
def validate(model: Model,
             val_loader: DataLoader,
             writer: SummaryWriter,
             iteration: int,
             waveglow_path: Path = None,
             fp16: bool = False):
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    np.random.seed(iteration % 42)
    metric_values, success_rates = [], []
    model.eval()
    start = time.time()

    # choose a random mel-spectrogram and synthesize audio from it
    random_idx = np.random.randint(len(val_loader))
    for i, (text, input_lengths, mel, stop_target) in enumerate(val_loader):
        text = text.to("cuda", non_blocking=True)
        input_lengths = input_lengths.to("cuda", non_blocking=True)
        mel = mel.to("cuda", non_blocking=True)
        stop_target = stop_target.to("cuda", non_blocking=True)
        with torch.no_grad():
            mel_pred, mel_pred_postnet, stop_predictions, alignment = model(
                text, input_lengths, mel)
            mel_loss = F.mse_loss(mel_pred, mel)
            mel_postnet_loss = F.mse_loss(mel_pred_postnet, mel)
            stop_loss = F.binary_cross_entropy(stop_predictions, stop_target)

            loss = mel_loss + mel_postnet_loss + stop_loss
            loss_value = loss.item()
            metric_values.append(loss_value)
            success_rates.append(success_rate(alignment))
            if i == random_idx:
                mel_to_gen = mel_pred_postnet[0]
                gen_alignment = alignment[0]
    avg_loss = np.mean(metric_values)
    avg_success_rate = np.mean(success_rates)
    writer.add_scalar('loss/validation', avg_loss, iteration)
    writer.add_scalar('success_rate/validation', avg_success_rate, iteration)
    writer.add_image("mel_pred/validation",
                     show_figure(mel_to_gen.float().cpu().numpy()),
                     iteration,
                     dataformats='HWC')
    writer.add_image("alignment/validation",
                     show_figure(gen_alignment.float().cpu().numpy(),
                                 origin='lower'),
                     iteration,
                     dataformats='HWC')
    if waveglow_path:
        print('starting audio synthesis')
        writer.add_audio("audio/validation",
                         waveglow_gen(waveglow_path,
                                      mel_to_gen[None],
                                      fp16=fp16),
                         iteration,
                         sample_rate=22050)
    end = time.time()
    print(
        f"validation {iteration}, loss={avg_loss:.3f}, {end - start:.2f} s.")
    return avg_loss
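A hedged sketch of how this might be driven from a training loop (model, make_val_loader and the 1000-step stride are placeholders, not part of the snippet above):

writer = SummaryWriter('logs/tacotron')
val_loader = make_val_loader()  # hypothetical DataLoader over (text, lengths, mel, stop) batches
best_loss = float('inf')
for iteration in range(1000, 100_001, 1000):
    # ... ~1000 training steps elided ...
    val_loss = validate(model, val_loader, writer, iteration)
    best_loss = min(best_loss, val_loss)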
Example #4
class TensorLogger(object):
    # creating file in given logdir ... defaults to ./runs/
    def __init__(self, _logdir='./runs/'):
        if not os.path.exists(_logdir):
            os.makedirs(_logdir)

        self.writer = SummaryWriter(log_dir=_logdir)

    # adding scalar value to tb file
    def scalar_summary(self, _tag, _value, _step):
        self.writer.add_scalar(_tag, _value, _step)

    # adding image value to tb file
    def image_summary(self, _tag, _image, _step, _format='CHW'):
        """
            default dataformat for image tensor is (3, H, W)
            can be changed to
                : (1, H, W) - dataformat = CHW
                : (H, W, 3) - dataformat HWC
                : (H, W) - datformat HW
        """
        #
        self.writer.add_image(_tag, _image, _step, dataformat=_format)

    # adding matplotlib figure to tb file
    def figure_summary(self, _tag, _figure, _step):
        self.writer.add_figure(_tag, _figure, _step)

    # adding video to tb file
    def video_summary(self, _tag, _video, _step, _fps=4):
        """
            default torch fps is 4, can be changed
            also, video tensor should be of format (N, T, C, H, W)
            values should be between [0, 255] for uint8 and [0, 1] for float32
        """
        # default value of video fps is 4 - can be changed
        self.writer.add_video(_tag, _video, _step, _fps)

    # adding audio to tb file
    def audio_summary(self, _tag, _sound, _step, _sampleRate=44100):
        """
            default torch sample rate is 44100, can be changed
            also, sound tensor should be of format (1, L)
            values should lie between [-1,1]
        """
        self.writer.add_audio(_tag, _sound, _step, sample_rate=_sampleRate)

    # adding text to tb file
    def text_summary(self, _tag, _textString, _step):
        self.writer.add_text(_tag, _textString, _step)

    # adding histograms to tb file
    def histogram_summary(self, _tag, _histogram, _step, _bins='tensorflow'):
        self.writer.add_histogram(_tag, _histogram, _step, bins=_bins)
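For reference, a minimal call sequence with dummy values:

import torch

logger = TensorLogger('./runs/exp1/')
for step in range(100):
    logger.scalar_summary('train/loss', 1.0 / (step + 1), step)
logger.audio_summary('val/sample', torch.zeros(1, 44100), 99)  # 1 s of silence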
Example #5
class TensorBoardLogger(Logger):
    def __init__(self, *args, **kwargs):
        super().__init__()

        path = Path(kwargs['log_dir'])
        path.mkdir(parents=True, exist_ok=True)
        self.log_dir = path
        self.writer = SummaryWriter(path)
        self.track_info = kwargs.get('track_info', None)
        if self.track_info is not None:
            self.generate_multitrack = partial(generate_multitrack,
                                               **self.track_info)

    def add_scalar(self, name, value, step):
        self.writer.add_scalar(name, value, global_step=step)

    def add_image(self, name, value, step):
        # value is img_tensor with dim (3, H, W) or (1, H, W)
        self.writer.add_image(name, value, global_step=step)

    def add_figure(self, name, fig, step):
        self.writer.add_figure(name, fig, global_step=step)

    def add_audio(self, name, value, sample_rate, step):
        self.writer.add_audio(name,
                              value,
                              global_step=step,
                              sample_rate=sample_rate)

    def add_histogram(self, name, value, step):
        self.writer.add_histogram(name, value, global_step=step)

    def add_pianoroll_img(self, name, pianoroll, step):
        # pianoroll has dim (seq_length, instruments, n_pitches)
        for i, instrument in enumerate(self.track_info['instruments']):
            self.add_image("{}_{}".format(name, instrument),
                           np.expand_dims(pianoroll[:, i, :].T, 0), step)

    def add_pianoroll_audio(self, name, pianoroll, step):
        import librosa
        from midi2audio import FluidSynth

        midi_file = str(self.log_dir / "tmp.mid")
        wav_file = str(self.log_dir / "tmp.wav")

        multitrack = self.generate_multitrack(pianoroll)
        multitrack.write(midi_file)
        FluidSynth().midi_to_audio(midi_file, wav_file)
        y, sr = librosa.load(wav_file)
        self.add_audio(name, y, sr, step)

    def close(self):
        self.writer.close()
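A usage sketch for the image path only, assuming the Logger base class needs no arguments (the audio path additionally needs FluidSynth and a generate_multitrack helper whose keyword arguments come from track_info, so the exact keys are repo-specific assumptions):

import numpy as np

logger = TensorBoardLogger(log_dir='logs/pianoroll',
                           track_info={'instruments': ['piano', 'drums']})
pianoroll = np.random.rand(256, 2, 88)  # (seq_length, instruments, n_pitches) in [0, 1]
logger.add_pianoroll_img('train/pianoroll', pianoroll, step=0)
logger.close()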
Example #6
class VisualizerTensorboard:
    def __init__(self, opts):
        self.dtype = {}
        self.iteration = 1
        self.writer = SummaryWriter(opts.logs_dir)

    def register(self, modules):
        # here modules are assumed to be a dictionary
        for key in modules:
            self.dtype[key] = modules[key]['dtype']

    def update(self, modules):
        for key, value in modules.items():
            if self.dtype[key] == 'scalar':
                self.writer.add_scalar(key, value, self.iteration)
            elif self.dtype[key] == 'scalars':
                self.writer.add_scalars(key, value, self.iteration)
            elif self.dtype[key] == 'histogram':
                self.writer.add_histogram(key, value, self.iteration)
            elif self.dtype[key] == 'image':
                self.writer.add_image(key, value, self.iteration)
            elif self.dtype[key] == 'images':
                self.writer.add_images(key, value, self.iteration)
            elif self.dtype[key] == 'figure':
                self.writer.add_figure(key, value, self.iteration)
            elif self.dtype[key] == 'video':
                self.writer.add_video(key, value, self.iteration)
            elif self.dtype[key] == 'audio':
                self.writer.add_audio(key, value, self.iteration)
            elif self.dtype[key] == 'text':
                self.writer.add_text(key, value, self.iteration)
            elif self.dtype[key] == 'embedding':
                self.writer.add_embedding(value, tag=key,
                                          global_step=self.iteration)
            elif self.dtype[key] == 'pr_curve':
                self.writer.add_pr_curve(key, value['labels'],
                                         value['predictions'], self.iteration)
            elif self.dtype[key] == 'mesh':
                self.writer.add_mesh(key, value, global_step=self.iteration)
            elif self.dtype[key] == 'hparams':
                self.writer.add_hparams(value['hparam_dict'],
                                        value['metric_dict'])
            else:
                raise ValueError(
                    'Data type not supported; please update the visualizer plugin and rerun.'
                )

        self.iteration = self.iteration + 1
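A short sketch of the register/update contract (the dtype string selects the writer call; values are dummies):

class Opts:
    logs_dir = 'logs/vis'

vis = VisualizerTensorboard(Opts())
vis.register({'loss': {'dtype': 'scalar'}, 'note': {'dtype': 'text'}})
vis.update({'loss': 0.42, 'note': 'first iteration'})  # logged at iteration 1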
Example #7
class SummaryWriter:
    def __init__(self, logdir, flush_secs=120):

        self.writer = TensorboardSummaryWriter(
            log_dir=logdir,
            purge_step=None,
            max_queue=10,
            flush_secs=flush_secs,
            filename_suffix='')

        self.global_step = None
        self.active = True

        # ------------------------------------------------------------------------
        # register add_* and set_* functions in summary module on instantiation
        # ------------------------------------------------------------------------
        this_module = sys.modules[__name__]
        list_of_names = dir(SummaryWriter)
        for name in list_of_names:

            # add functions (without the 'add' prefix)
            if name.startswith('add_'):
                setattr(this_module, name[4:], getattr(self, name))

            #  set functions
            if name.startswith('set_'):
                setattr(this_module, name, getattr(self, name))

    def set_global_step(self, value):
        self.global_step = value

    def set_active(self, value):
        self.active = value

    def add_audio(self, tag, snd_tensor, global_step=None, sample_rate=44100, walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_audio(
                tag, snd_tensor, global_step=global_step, sample_rate=sample_rate, walltime=walltime)

    def add_custom_scalars(self, layout):
        if self.active:
            self.writer.add_custom_scalars(layout)

    def add_custom_scalars_marginchart(self, tags, category='default', title='untitled'):
        if self.active:
            self.writer.add_custom_scalars_marginchart(tags, category=category, title=title)

    def add_custom_scalars_multilinechart(self, tags, category='default', title='untitled'):
        if self.active:
            self.writer.add_custom_scalars_multilinechart(tags, category=category, title=title)

    def add_embedding(self, mat, metadata=None, label_img=None, global_step=None,
                      tag='default', metadata_header=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_embedding(
                mat, metadata=metadata, label_img=label_img, global_step=global_step,
                tag=tag, metadata_header=metadata_header)

    def add_figure(self, tag, figure, global_step=None, close=True, walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_figure(
                tag, figure, global_step=global_step, close=close, walltime=walltime)

    def add_graph(self, model, input_to_model=None, verbose=False):
        if self.active:
            self.writer.add_graph(model, input_to_model=input_to_model, verbose=verbose)

    def add_histogram(self, tag, values, global_step=None, bins='tensorflow', walltime=None, max_bins=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_histogram(
                tag, values, global_step=global_step, bins=bins,
                walltime=walltime, max_bins=max_bins)

    def add_histogram_raw(self, tag, min, max, num, sum, sum_squares,
                          bucket_limits, bucket_counts, global_step=None,
                          walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_histogram_raw(
                tag, min=min, max=max, num=num, sum=sum, sum_squares=sum_squares,
                bucket_limits=bucket_limits, bucket_counts=bucket_counts,
                global_step=global_step, walltime=walltime)

    def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_image(
                tag, img_tensor, global_step=global_step, walltime=walltime, dataformats=dataformats)

    def add_image_with_boxes(self, tag, img_tensor, box_tensor, global_step=None,
                             walltime=None, rescale=1, dataformats='CHW'):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_image_with_boxes(
                tag, img_tensor, box_tensor,
                global_step=global_step, walltime=walltime,
                rescale=rescale, dataformats=dataformats)

    def add_images(self, tag, img_tensor, global_step=None, walltime=None, dataformats='NCHW'):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_images(
                tag, img_tensor, global_step=global_step, walltime=walltime, dataformats=dataformats)

    def add_mesh(self, tag, vertices, colors=None, faces=None, config_dict=None, global_step=None, walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_mesh(
                tag, vertices, colors=colors, faces=faces, config_dict=config_dict,
                global_step=global_step, walltime=walltime)

    def add_onnx_graph(self, graph):
        if self.active:
            self.writer.add_onnx_graph(graph)

    def add_pr_curve(self, tag, labels, predictions, global_step=None,
                     num_thresholds=127, weights=None, walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_pr_curve(
                tag, labels, predictions, global_step=global_step,
                num_thresholds=num_thresholds, weights=weights, walltime=walltime)

    def add_pr_curve_raw(self, tag, true_positive_counts,
                         false_positive_counts,
                         true_negative_counts,
                         false_negative_counts,
                         precision,
                         recall,
                         global_step=None,
                         num_thresholds=127,
                         weights=None,
                         walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_pr_curve_raw(
                tag, true_positive_counts,
                false_positive_counts,
                true_negative_counts,
                false_negative_counts,
                precision,
                recall,
                global_step=global_step,
                num_thresholds=num_thresholds,
                weights=weights,
                walltime=walltime)

    def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_scalar(
                tag, scalar_value, global_step=global_step, walltime=walltime)

    def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_scalars(
                main_tag, tag_scalar_dict, global_step=global_step, walltime=walltime)

    def add_text(self, tag, text_string, global_step=None, walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_text(
                tag, text_string, global_step=global_step, walltime=walltime)

    def add_video(self, tag, vid_tensor, global_step=None, fps=4, walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_video(
                tag, vid_tensor, global_step=global_step, fps=fps, walltime=walltime)

    def close(self):
        self.writer.close()

    def __enter__(self):
        return self.writer.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        return self.writer.__exit__(exc_type, exc_val, exc_tb)
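Because __init__ registers the add_*/set_* methods on the defining module (add_ methods lose their prefix), code in that module can log without holding the writer, as in this sketch:

writer = SummaryWriter('logs/wrapped')
set_global_step(100)       # registered by __init__
scalar('train/loss', 0.5)  # forwards to writer.add_scalar at step 100
set_active(False)          # subsequent writes are silently dropped
scalar('train/loss', 0.4)  # ignored
writer.close()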
Example #8
def train_melgan(args):
    # args = parse_args()

    root = Path(args.save_path)
    load_root = Path(args.load_path) if args.load_path else None
    root.mkdir(parents=True, exist_ok=True)

    ####################################
    # Dump arguments and create logger #
    ####################################
    # with open(root / "args.yml", "w") as f:
    #     yaml.dump(args, f)
    with open(root / "args.json", "w", encoding="utf8") as f:
        json.dump(args.__dict__, f, indent=4, ensure_ascii=False)
    eventdir = root / "events"
    eventdir.mkdir(exist_ok=True)
    writer = SummaryWriter(str(eventdir))

    #######################
    # Load PyTorch Models #
    #######################
    ratios = [int(w) for w in args.ratios.split()]
    netG = Generator(args.n_mel_channels, args.ngf, args.n_residual_layers, ratios=ratios).to(_device)
    netD = Discriminator(
        args.num_D, args.ndf, args.n_layers_D, args.downsamp_factor
    ).to(_device)
    # fft = Audio2Mel(n_mel_channels=args.n_mel_channels).to(_device)
    if args.mode == 'default':
        fft = audio2mel
    elif args.mode == 'synthesizer':
        fft = audio2mel_synthesizer
    elif args.mode == 'mellotron':
        fft = audio2mel_mellotron
    else:
        raise KeyError
    # print(netG)
    # print(netD)

    #####################
    # Create optimizers #
    #####################
    optG = torch.optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optD = torch.optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))

    if load_root and load_root.exists():
        netG.load_state_dict(torch.load(load_root))
        # optG.load_state_dict(torch.load(load_root / "optG.pt"))
        # netD.load_state_dict(torch.load(load_root / "netD.pt"))
        # optD.load_state_dict(torch.load(load_root / "optD.pt"))

    #######################
    # Create data loaders #
    #######################
    train_set = AudioDataset(
        Path(args.data_path), args.seq_len, sampling_rate=args.sample_rate
    )
    test_set = AudioDataset(
        Path(args.data_path),  # test file
        args.sample_rate * 4,
        sampling_rate=args.sample_rate,
        augment=False,
    )

    train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=4)
    test_loader = DataLoader(test_set, batch_size=1)

    ##########################
    # Dumping original audio #
    ##########################
    test_voc = []
    test_audio = []
    for i, x_t in enumerate(test_loader):
        x_t = x_t.to(_device)
        s_t = fft(x_t).detach()

        test_voc.append(s_t.to(_device))
        test_audio.append(x_t)

        audio = x_t.squeeze().cpu()
        oridir = root / "original"
        oridir.mkdir(exist_ok=True)
        save_sample(oridir / ("original_{}_{}.wav".format("test", i)), args.sample_rate, audio)
        writer.add_audio("original/{}/sample_{}.wav".format("test", i), audio, 0, sample_rate=args.sample_rate)

        if i == args.n_test_samples - 1:
            break

    costs = []
    start = time.time()

    # enable cudnn autotuner to speed up training
    torch.backends.cudnn.benchmark = True

    best_mel_reconst = 1000000
    step_begin = args.start_step
    look_steps = {step_begin + 10, step_begin + 100, step_begin + 1000, step_begin + 10000}
    steps = step_begin
    for epoch in range(1, args.epochs + 1):
        print("\nEpoch {} beginning. Current step: {}".format(epoch, steps))
        for iterno, x_t in enumerate(tqdm(train_loader, desc="iter", ncols=100)):
            # torch.Size([4, 1, 8192]) torch.Size([4, 80, 32])
            # 8192 = 32 x 256
            x_t = x_t.to(_device)
            s_t = fft(x_t).detach()
            x_pred_t = netG(s_t.to(_device))

            with torch.no_grad():
                s_pred_t = fft(x_pred_t.detach())
                s_error = F.l1_loss(s_t, s_pred_t).item()

            #######################
            # Train Discriminator #
            #######################
            D_fake_det = netD(x_pred_t.to(_device).detach())
            D_real = netD(x_t.to(_device))

            loss_D = 0
            for scale in D_fake_det:
                loss_D += F.relu(1 + scale[-1]).mean()

            for scale in D_real:
                loss_D += F.relu(1 - scale[-1]).mean()

            netD.zero_grad()
            loss_D.backward()
            optD.step()

            ###################
            # Train Generator #
            ###################
            D_fake = netD(x_pred_t.to(_device))

            loss_G = 0
            for scale in D_fake:
                loss_G += -scale[-1].mean()

            loss_feat = 0
            feat_weights = 4.0 / (args.n_layers_D + 1)
            D_weights = 1.0 / args.num_D
            wt = D_weights * feat_weights
            for i in range(args.num_D):
                for j in range(len(D_fake[i]) - 1):
                    loss_feat += wt * F.l1_loss(D_fake[i][j], D_real[i][j].detach())

            netG.zero_grad()
            (loss_G + args.lambda_feat * loss_feat).backward()
            optG.step()

            ######################
            # Update tensorboard #
            ######################

            costs.append([loss_D.item(), loss_G.item(), loss_feat.item(), s_error])
            steps += 1
            writer.add_scalar("loss/discriminator", costs[-1][0], steps)
            writer.add_scalar("loss/generator", costs[-1][1], steps)
            writer.add_scalar("loss/feature_matching", costs[-1][2], steps)
            writer.add_scalar("loss/mel_reconstruction", costs[-1][3], steps)

            if steps % args.save_interval == 0 or steps in look_steps:
                st = time.time()
                with torch.no_grad():
                    for i, (voc, _) in enumerate(zip(test_voc, test_audio)):
                        pred_audio = netG(voc)
                        pred_audio = pred_audio.squeeze().cpu()
                        gendir = root / "generated"
                        gendir.mkdir(exist_ok=True)
                        save_sample(gendir / ("generated_step{}_{}.wav".format(steps, i)), args.sample_rate, pred_audio)
                        writer.add_audio(
                            "generated/step{}/sample_{}.wav".format(steps, i),
                            pred_audio,
                            epoch,
                            sample_rate=args.sample_rate,
                        )

                ptdir = root / "models"
                ptdir.mkdir(exist_ok=True)
                torch.save(netG.state_dict(), ptdir / "step{}_netG.pt".format(steps))
                torch.save(optG.state_dict(), ptdir / "step{}_optG.pt".format(steps))

                torch.save(netD.state_dict(), ptdir / "step{}_netD.pt".format(steps))
                torch.save(optD.state_dict(), ptdir / "step{}_optD.pt".format(steps))

                if np.asarray(costs).mean(0)[-1] < best_mel_reconst:
                    best_mel_reconst = np.asarray(costs).mean(0)[-1]
                    torch.save(netD.state_dict(), ptdir / "best_step{}_netD.pt".format(steps))
                    torch.save(netG.state_dict(), ptdir / "best_step{}_netG.pt".format(steps))
                # print("\nTook %5.4fs to generate samples" % (time.time() - st))
                # print("-" * 100)

            if steps % args.log_interval == 0 or steps in look_steps:
                print(
                    "\nEpoch {} | Iters {} / {} | ms/batch {:5.2f} | loss {}".format(
                        epoch,
                        iterno,
                        len(train_loader),
                        1000 * (time.time() - start) / args.log_interval,
                        np.asarray(costs).mean(0),
                    )
                )
                costs = []
                start = time.time()
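A hedged invocation sketch; only a subset of the argparse fields train_melgan reads is shown, with values loosely following the public MelGAN defaults (treat them as placeholders):

from argparse import Namespace

args = Namespace(save_path='melgan_out', load_path=None, data_path='data/wavs',
                 seq_len=8192, sample_rate=22050, batch_size=4, epochs=3000,
                 n_test_samples=4, start_step=0, mode='default', ratios='8 8 2 2',
                 n_mel_channels=80, ngf=32, n_residual_layers=3,
                 num_D=3, ndf=16, n_layers_D=4, downsamp_factor=4,
                 lambda_feat=10, save_interval=1000, log_interval=100)
train_melgan(args)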
Example #9
class VocTrainer:
    def __init__(self, paths: Paths, dsp: DSP, config: Dict[str, Any]) -> None:
        self.paths = paths
        self.writer = SummaryWriter(log_dir=paths.voc_log, comment='v1')
        self.dsp = dsp
        self.config = config
        self.train_cfg = config['vocoder']['training']
        self.loss_func = F.cross_entropy if self.dsp.voc_mode == 'RAW' else discretized_mix_logistic_loss
        path_top_k = paths.voc_top_k / 'top_k.pkl'
        if os.path.exists(path_top_k):
            self.top_k_models = unpickle_binary(path_top_k)
            # log recent top models
            for i, (mel_loss, g_wav, m_step,
                    m_name) in enumerate(self.top_k_models, 1):
                self.writer.add_audio(tag=f'Top_K_Models/generated_top_{i}',
                                      snd_tensor=g_wav,
                                      global_step=m_step,
                                      sample_rate=self.dsp.sample_rate)
        else:
            self.top_k_models = []

    def train(self,
              model: WaveRNN,
              optimizer: Optimizer,
              train_gta=False) -> None:
        voc_schedule = self.train_cfg['schedule']
        voc_schedule = parse_schedule(voc_schedule)
        for i, session_params in enumerate(voc_schedule, 1):
            lr, max_step, bs = session_params
            if model.get_step() < max_step:
                train_set, val_set, val_set_samples = get_vocoder_datasets(
                    path=self.paths.data,
                    batch_size=bs,
                    train_gta=train_gta,
                    max_mel_len=self.train_cfg['max_mel_len'],
                    hop_length=self.dsp.hop_length,
                    voc_pad=model.pad,
                    voc_seq_len=self.train_cfg['seq_len'],
                    voc_mode=self.dsp.voc_mode,
                    bits=self.dsp.bits,
                    num_gen_samples=self.train_cfg['num_gen_samples'])
                session = VocSession(index=i,
                                     lr=lr,
                                     max_step=max_step,
                                     bs=bs,
                                     train_set=train_set,
                                     val_set=val_set,
                                     val_set_samples=val_set_samples)
                self.train_session(model, optimizer, session, train_gta)

    def train_session(self, model: WaveRNN, optimizer: Optimizer,
                      session: VocSession, train_gta: bool) -> None:
        current_step = model.get_step()
        training_steps = session.max_step - current_step
        total_iters = len(session.train_set)
        epochs = training_steps // total_iters + 1
        simple_table([('Steps', str(training_steps // 1000) + 'k'),
                      ('Batch Size', session.bs),
                      ('Learning Rate', session.lr),
                      ('Sequence Length', self.train_cfg['seq_len']),
                      ('GTA Training', train_gta)])
        for g in optimizer.param_groups:
            g['lr'] = session.lr

        loss_avg = Averager()
        duration_avg = Averager()
        device = next(
            model.parameters()).device  # use same device as model parameters

        for e in range(1, epochs + 1):
            for i, batch in enumerate(session.train_set, 1):
                start = time.time()
                model.train()
                batch = to_device(batch, device=device)
                x, y = batch['x'], batch['y']
                y_hat = model(x, batch['mel'])
                if model.mode == 'RAW':
                    y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
                elif model.mode == 'MOL':
                    y = batch['y'].float()
                y = y.unsqueeze(-1)

                loss = self.loss_func(y_hat, y)
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), self.train_cfg['clip_grad_norm'])
                optimizer.step()
                loss_avg.add(loss.item())
                step = model.get_step()
                k = step // 1000

                duration_avg.add(time.time() - start)
                speed = 1. / duration_avg.get()
                msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {loss_avg.get():#.4} ' \
                      f'| {speed:#.2} steps/s | Step: {k}k | '

                if step % self.train_cfg['gen_samples_every'] == 0:
                    stream(msg + 'generating samples...')
                    gen_result = self.generate_samples(model, session)
                    if gen_result is not None:
                        mel_loss, gen_wav = gen_result
                        self.writer.add_scalar('Loss/generated_mel_l1',
                                               mel_loss, model.get_step())
                        self.track_top_models(mel_loss, gen_wav, model)

                if step % self.train_cfg['checkpoint_every'] == 0:
                    save_checkpoint(model=model,
                                    optim=optimizer,
                                    config=self.config,
                                    path=self.paths.voc_checkpoints /
                                    f'wavernn_step{k}k.pt')

                self.writer.add_scalar('Loss/train', loss, model.get_step())
                self.writer.add_scalar('Params/batch_size', session.bs,
                                       model.get_step())
                self.writer.add_scalar('Params/learning_rate', session.lr,
                                       model.get_step())

                stream(msg)

            val_loss = self.evaluate(model, session.val_set)
            self.writer.add_scalar('Loss/val', val_loss, model.get_step())
            save_checkpoint(model=model,
                            optim=optimizer,
                            config=self.config,
                            path=self.paths.voc_checkpoints /
                            'latest_model.pt')

            loss_avg.reset()
            duration_avg.reset()
            print(' ')

    def evaluate(self, model: WaveRNN, val_set: Dataset) -> float:
        model.eval()
        val_loss = 0
        device = next(model.parameters()).device
        for i, batch in enumerate(val_set, 1):
            batch = to_device(batch, device=device)
            x, y, m = batch['x'], batch['y'], batch['mel']
            with torch.no_grad():
                y_hat = model(x, m)
                if model.mode == 'RAW':
                    y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
                elif model.mode == 'MOL':
                    y = y.float()
                y = y.unsqueeze(-1)
                loss = self.loss_func(y_hat, y)
                val_loss += loss.item()
        return val_loss / len(val_set)

    @ignore_exception
    def generate_samples(self, model: WaveRNN,
                         session: VocSession) -> Tuple[float, list]:
        """
        Generates audio samples to cherry-pick models. To evaluate audio quality
        we calculate the l1 distance between mels of predictions and targets.
        """
        model.eval()
        mel_losses = []
        gen_wavs = []
        device = next(model.parameters()).device
        for i, sample in enumerate(session.val_set_samples, 1):
            m, x = sample['mel'], sample['x']
            if i > self.train_cfg['num_gen_samples']:
                break
            x = x[0].numpy()
            bits = 16 if self.dsp.voc_mode == 'MOL' else self.dsp.bits
            if self.dsp.mu_law and self.dsp.voc_mode != 'MOL':
                x = DSP.decode_mu_law(x, 2**bits, from_labels=True)
            else:
                x = DSP.label_2_float(x, bits)
            gen_wav = model.generate(mels=m,
                                     batched=self.train_cfg['gen_batched'],
                                     target=self.train_cfg['target'],
                                     overlap=self.train_cfg['overlap'],
                                     mu_law=self.dsp.mu_law,
                                     silent=True)

            gen_wavs.append(gen_wav)
            y_mel = self.dsp.wav_to_mel(x.squeeze(), normalize=False)
            y_mel = torch.tensor(y_mel).to(device)
            y_hat_mel = self.dsp.wav_to_mel(gen_wav, normalize=False)
            y_hat_mel = torch.tensor(y_hat_mel).to(device)
            loss = F.l1_loss(y_hat_mel, y_mel)
            mel_losses.append(loss.item())

            self.writer.add_audio(tag=f'Validation_Samples/target_{i}',
                                  snd_tensor=x,
                                  global_step=model.step,
                                  sample_rate=self.dsp.sample_rate)
            self.writer.add_audio(tag=f'Validation_Samples/generated_{i}',
                                  snd_tensor=gen_wav,
                                  global_step=model.step,
                                  sample_rate=self.dsp.sample_rate)

        return sum(mel_losses) / len(mel_losses), gen_wavs[0]

    def track_top_models(self, mel_loss, gen_wav, model):
        """ Keeps track of top k models and saves them according to their current rank """
        for j, (l, g, m, m_n) in enumerate(self.top_k_models):
            print(f'{j} {l} {m} {m_n}')
        if len(self.top_k_models) < self.train_cfg['keep_top_k'] \
                or mel_loss < self.top_k_models[-1][0]:
            m_step = model.get_step()
            model_name = f'model_loss{mel_loss:#0.5}_step{m_step}_weights.pyt'
            self.top_k_models.append(
                (mel_loss, gen_wav, model.get_step(), model_name))
            self.top_k_models.sort(key=lambda t: t[0])
            self.top_k_models = self.top_k_models[:self.train_cfg['keep_top_k']]
            model.save(self.paths.voc_top_k / model_name)
            all_models = get_files(self.paths.voc_top_k, extension='pyt')
            top_k_names = {m[-1] for m in self.top_k_models}
            for model_file in all_models:
                if model_file.name not in top_k_names:
                    print(f'removing {model_file}')
                    os.remove(model_file)
            pickle_binary(self.top_k_models,
                          self.paths.voc_top_k / 'top_k.pkl')

            for i, (mel_loss, g_wav, m_step,
                    m_name) in enumerate(self.top_k_models, 1):
                self.writer.add_audio(tag=f'Top_K_Models/generated_top_{i}',
                                      snd_tensor=g_wav,
                                      global_step=m_step,
                                      sample_rate=self.dsp.sample_rate)
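The top-k bookkeeping in track_top_models maintains an ascending-by-loss list, appending only when a candidate beats the current worst; the same invariant in isolation:

top_k, keep = [], 3
for step, loss in enumerate([0.9, 0.5, 0.7, 0.4, 0.8]):
    if len(top_k) < keep or loss < top_k[-1][0]:
        top_k.append((loss, step))
        top_k.sort(key=lambda t: t[0])
        top_k = top_k[:keep]
print(top_k)  # [(0.4, 3), (0.5, 1), (0.7, 2)]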
Example #10
def train(rank, a, h):
    if h.num_gpus > 1:
        init_process_group(backend=h.dist_config['dist_backend'],
                           init_method=h.dist_config['dist_url'],
                           world_size=h.dist_config['world_size'] * h.num_gpus,
                           rank=rank)

    torch.cuda.manual_seed(h.seed)
    device = torch.device('cuda:{:d}'.format(rank))

    generator = Generator(h).to(device)
    mpd = MultiPeriodDiscriminator(
        h["discriminator_periods"] if "discriminator_periods" in h
        else None).to(device)
    msd = MultiScaleDiscriminator().to(device)

    if rank == 0:
        print(generator)
        os.makedirs(a.checkpoint_path, exist_ok=True)
        print("checkpoints directory : ", a.checkpoint_path)

    cp_g, cp_do = None, None
    missing_keys = []
    if os.path.isdir(a.checkpoint_path):
        cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
        cp_do = scan_checkpoint(a.checkpoint_path, 'do_')

    steps = 0
    if cp_g is not None:
        state_dict_g = load_checkpoint(cp_g, device)
        gsd = generator.state_dict()
        gsd.update({
            k: v
            for k, v in state_dict_g['generator'].items()
            if k in gsd and state_dict_g['generator'][k].shape == gsd[k].shape
        })
        missing_keys = {
            k: v
            for k, v in state_dict_g['generator'].items()
            if not (k in gsd
                    and state_dict_g['generator'][k].shape == gsd[k].shape)
        }.keys()
        generator.load_state_dict(gsd)
        del gsd, state_dict_g

    if cp_do is None or len(missing_keys) or a.from_zero:
        state_dict_do = None
        last_epoch = -1
    else:
        state_dict_do = load_checkpoint(cp_do, device)
        mpd.load_state_dict(state_dict_do['mpd'])
        del state_dict_do['mpd']
        msd.load_state_dict(state_dict_do['msd'])
        del state_dict_do['msd']
        steps = state_dict_do['steps'] + 1
        last_epoch = state_dict_do['epoch']

    if h.num_gpus > 1:
        generator = DistributedDataParallel(generator,
                                            device_ids=[rank]).to(device)
        mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
        msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)

    optim_g = torch.optim.AdamW(generator.parameters(),
                                h.learning_rate,
                                betas=[h.adam_b1, h.adam_b2])
    optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(),
                                                mpd.parameters()),
                                h.learning_rate,
                                betas=[h.adam_b1, h.adam_b2])

    if state_dict_do is not None:
        optim_g.load_state_dict(state_dict_do['optim_g'])
        optim_d.load_state_dict(state_dict_do['optim_d'])
        del state_dict_do

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g,
                                                         gamma=h.lr_decay,
                                                         last_epoch=last_epoch)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d,
                                                         gamma=h.lr_decay,
                                                         last_epoch=last_epoch)

    training_filelist, validation_filelist = get_dataset_filelist(
        a, h.segment_size, h.sampling_rate)

    trainset = MelDataset(training_filelist,
                          h.segment_size,
                          h.n_fft,
                          h.num_mels,
                          h.hop_size,
                          h.win_size,
                          h.sampling_rate,
                          h.fmin,
                          h.fmax,
                          n_cache_reuse=0,
                          shuffle=False if h.num_gpus > 1 else True,
                          fmax_loss=h.fmax_for_loss,
                          device=device,
                          fine_tuning=a.fine_tuning,
                          trim_non_voiced=a.trim_non_voiced)

    STFT = STFT_Class(h.sampling_rate, h.num_mels, h.n_fft, h.win_size,
                      h.hop_size, h.fmin, h.fmax)

    train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None

    train_loader = DataLoader(trainset,
                              num_workers=h.num_workers,
                              shuffle=False,
                              sampler=train_sampler,
                              batch_size=h.batch_size,
                              pin_memory=True,
                              drop_last=True)
    assert len(train_loader), 'No audio files in dataset!'

    if rank == 0:
        validset = MelDataset(validation_filelist,
                              h.segment_size,
                              h.n_fft,
                              h.num_mels,
                              h.hop_size,
                              h.win_size,
                              h.sampling_rate,
                              h.fmin,
                              h.fmax,
                              False,
                              False,
                              n_cache_reuse=0,
                              fmax_loss=h.fmax_for_loss,
                              device=device,
                              fine_tuning=a.fine_tuning,
                              trim_non_voiced=a.trim_non_voiced)
        validation_loader = DataLoader(validset,
                                       num_workers=h.num_workers,
                                       shuffle=False,
                                       sampler=None,
                                       batch_size=1,
                                       pin_memory=True,
                                       drop_last=True)

        sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'),
                           max_queue=10000)
        sw.logged_gt_plots = False

    if h.num_gpus > 1:
        import gc
        gc.collect()
        torch.cuda.empty_cache()

    generator.train()
    mpd.train()
    msd.train()
    for epoch in range(max(0, last_epoch), a.training_epochs):
        if rank == 0:
            start = time.time()
            print("Epoch: {}".format(epoch + 1))

        if h.num_gpus > 1:
            train_sampler.set_epoch(epoch)

        for i, batch in enumerate(train_loader):
            if rank == 0:
                start_b = time.time()
            x, y, _, y_mel = batch
            # torch.autograd.Variable is a deprecated no-op wrapper; .to() suffices
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            y_mel = y_mel.to(device, non_blocking=True)
            y = y.unsqueeze(1)

            y_g_hat = generator(x)
            y_g_hat_mel = STFT.get_mel(y_g_hat.squeeze(1))

            optim_d.zero_grad()

            # MPD
            y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
            loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(
                y_df_hat_r, y_df_hat_g)

            # MSD
            y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
            loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(
                y_ds_hat_r, y_ds_hat_g)

            loss_disc_all = loss_disc_s + loss_disc_f

            loss_disc_all.backward()
            optim_d.step()

            # Generator
            optim_g.zero_grad()

            # L1 Mel-Spectrogram Loss
            loss_mel = F.l1_loss(y_mel, y_g_hat_mel)

            y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
            y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
            loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
            loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
            loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
            loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
            loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel * 45

            loss_gen_all.backward()
            optim_g.step()

            if rank == 0:
                torch.set_grad_enabled(False)
                # STDOUT logging
                if steps % a.stdout_interval == 0:
                    print(
                        'Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'
                        .format(steps, loss_gen_all.item(), loss_mel.item(),
                                time.time() - start_b))

                # checkpointing
                if steps % a.checkpoint_interval == 0 and steps != 0:
                    checkpoint_path = "{}/g_{:08d}".format(
                        a.checkpoint_path, steps)
                    save_checkpoint(
                        checkpoint_path, {
                            'generator': (generator.module if h.num_gpus > 1
                                          else generator).state_dict()
                        })
                    checkpoint_path = "{}/do_{:08d}".format(
                        a.checkpoint_path, steps)
                    save_checkpoint(
                        checkpoint_path, {
                            'mpd': (mpd.module
                                    if h.num_gpus > 1 else mpd).state_dict(),
                            'msd': (msd.module
                                    if h.num_gpus > 1 else msd).state_dict(),
                            'optim_g':
                            optim_g.state_dict(),
                            'optim_d':
                            optim_d.state_dict(),
                            'steps':
                            steps,
                            'epoch':
                            epoch
                        })
                    del_old_checkpoints(a.checkpoint_path, 'g_',
                                        a.n_models_to_keep)
                    del_old_checkpoints(a.checkpoint_path, 'do_',
                                        a.n_models_to_keep)

                # Tensorboard summary logging
                if steps % a.summary_interval == 0:
                    sw.add_scalar("training/gen_loss_total", loss_gen_all,
                                  steps)
                    sw.add_scalar("training/mel_spec_error", loss_mel.item(),
                                  steps)

                # Validation
                if steps % a.validation_interval == 0:  # and steps != 0:
                    print("Validating...")
                    n_audios_to_plot = 6
                    generator.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    for j, batch in tqdm(enumerate(validation_loader),
                                         total=len(validation_loader)):
                        x, y, _, y_mel = batch
                        y_g_hat = generator(x.to(device))
                        y_hat_spec = STFT.get_mel(y_g_hat.squeeze(1))
                        val_err_tot += F.l1_loss(y_mel,
                                                 y_hat_spec.to(y_mel)).item()

                        if j < n_audios_to_plot and not sw.logged_gt_plots:
                            sw.add_audio(f'gt/y_{j}', y[0], steps,
                                         h.sampling_rate)
                            sw.add_figure(f'spec_{j:02}/gt_spec',
                                          plot_spectrogram(y_mel[0]), steps)
                        if j < n_audios_to_plot:
                            sw.add_audio(f'generated/y_hat_{j}', y_g_hat[0],
                                         steps, h.sampling_rate)
                            sw.add_figure(
                                f'spec_{j:02}/pred_spec',
                                plot_spectrogram(
                                    y_hat_spec.squeeze(0).cpu().numpy()),
                                steps)

                        if j > 64:  # I am NOT patient enough to complete an entire validation cycle with over 1536 files.
                            break
                    sw.logged_gt_plots = True
                    val_err = val_err_tot / (j + 1)
                    sw.add_scalar("validation/mel_spec_error", val_err, steps)
                    generator.train()
                    print(f"Done. Val_loss = {val_err}")
                torch.set_grad_enabled(True)
            steps += 1

        scheduler_g.step()
        scheduler_d.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(
                epoch + 1, int(time.time() - start)))
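A single-process launch sketch (in the HiFi-GAN codebase, h is an attribute-style dict loaded from the config JSON and num_gpus is filled in from torch.cuda.device_count(); the AttrDict below is a minimal stand-in):

import json
from argparse import Namespace

class AttrDict(dict):
    __getattr__ = dict.__getitem__  # minimal attribute-style access

with open('config_v1.json') as f:
    h = AttrDict(json.load(f))
h['num_gpus'] = 1
a = Namespace(checkpoint_path='cp_hifigan', training_epochs=3100,
              stdout_interval=5, checkpoint_interval=5000, summary_interval=100,
              validation_interval=1000, fine_tuning=False, trim_non_voiced=False,
              n_models_to_keep=2, from_zero=False)
train(0, a, h)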
Example #11
def main(args):
    #torch.backends.cudnn.benchmark=True # This makes dilated conv much faster for CuDNN 7.5

    # MODEL
    num_features = [args.features*i for i in range(1, args.levels+1)] if args.feature_growth == "add" else \
                   [args.features*2**i for i in range(0, args.levels)]
    target_outputs = int(args.output_size * args.sr)
    model = Waveunet(args.channels, num_features, args.channels, args.instruments, kernel_size=args.kernel_size,
                     target_output_size=target_outputs, depth=args.depth, strides=args.strides,
                     conv_type=args.conv_type, res=args.res, separate=args.separate)

    if args.cuda:
        model = utils.DataParallel(model)
        print("move model to gpu")
        model.cuda()

    print('model: ', model)
    print('parameter count: ', str(sum(p.numel() for p in model.parameters())))

    writer = SummaryWriter(args.log_dir)

    ### DATASET
    musdb = get_musdb_folds(args.dataset_dir)
    # If not data augmentation, at least crop targets to fit model output shape
    crop_func = partial(crop, shapes=model.shapes)
    # Data augmentation function for training
    augment_func = partial(random_amplify, shapes=model.shapes, min=0.7, max=1.0)
    train_data = SeparationDataset(musdb, "train", args.instruments, args.sr, args.channels, model.shapes, True, args.hdf_dir, audio_transform=augment_func)
    val_data = SeparationDataset(musdb, "val", args.instruments, args.sr, args.channels, model.shapes, False, args.hdf_dir, audio_transform=crop_func)
    test_data = SeparationDataset(musdb, "test", args.instruments, args.sr, args.channels, model.shapes, False, args.hdf_dir, audio_transform=crop_func)

    dataloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, worker_init_fn=utils.worker_init_fn)

    ##### TRAINING ####

    # Set up the loss function
    if args.loss == "L1":
        criterion = nn.L1Loss()
    elif args.loss == "L2":
        criterion = nn.MSELoss()
    else:
        raise NotImplementedError("Couldn't find this loss!")

    # Set up optimiser
    optimizer = Adam(params=model.parameters(), lr=args.lr)

    # Set up training state dict that will also be saved into checkpoints
    state = {"step" : 0,
             "worse_epochs" : 0,
             "epochs" : 0,
             "best_loss" : np.Inf}

    # LOAD MODEL CHECKPOINT IF DESIRED
    if args.load_model is not None:
        print("Continuing training full model from checkpoint " + str(args.load_model))
        state = utils.load_model(model, optimizer, args.load_model)

    print('TRAINING START')
    while state["worse_epochs"] < args.patience:
        print("Training one epoch from iteration " + str(state["step"]))
        avg_time = 0.
        model.train()
        with tqdm(total=len(train_data) // args.batch_size) as pbar:
            np.random.seed()
            for example_num, (x, targets) in enumerate(dataloader):
                if args.cuda:
                    x = x.cuda()
                    for k in list(targets.keys()):
                        targets[k] = targets[k].cuda()

                t = time.time()

                # Set LR for this iteration
                utils.set_cyclic_lr(optimizer, example_num, len(train_data) // args.batch_size, args.cycles, args.min_lr, args.lr)
                writer.add_scalar("lr", utils.get_lr(optimizer), state["step"])

                # Compute loss for each instrument/model
                optimizer.zero_grad()
                outputs, avg_loss = utils.compute_loss(model, x, targets, criterion, compute_grad=True)

                optimizer.step()

                state["step"] += 1

                t = time.time() - t
                avg_time += (1. / float(example_num + 1)) * (t - avg_time)

                writer.add_scalar("train_loss", avg_loss, state["step"])

                if example_num % args.example_freq == 0:
                    input_centre = torch.mean(x[0, :, model.shapes["output_start_frame"]:model.shapes["output_end_frame"]], 0) # Stereo not supported for logs yet
                    writer.add_audio("input", input_centre, state["step"], sample_rate=args.sr)

                    for inst in outputs.keys():
                        writer.add_audio(inst + "_pred", torch.mean(outputs[inst][0], 0), state["step"], sample_rate=args.sr)
                        writer.add_audio(inst + "_target", torch.mean(targets[inst][0], 0), state["step"], sample_rate=args.sr)

                pbar.update(1)

        # VALIDATE
        val_loss = validate(args, model, criterion, val_data)
        print("VALIDATION FINISHED: LOSS: " + str(val_loss))
        writer.add_scalar("val_loss", val_loss, state["step"])

        # EARLY STOPPING CHECK
        checkpoint_path = os.path.join(args.checkpoint_dir, "checkpoint_" + str(state["step"]))
        if val_loss >= state["best_loss"]:
            state["worse_epochs"] += 1
        else:
            print("MODEL IMPROVED ON VALIDATION SET!")
            state["worse_epochs"] = 0
            state["best_loss"] = val_loss
            state["best_checkpoint"] = checkpoint_path

        # CHECKPOINT
        print("Saving model...")
        utils.save_model(model, optimizer, state, checkpoint_path)

        state["epochs"] += 1

    #### TESTING ####
    # Test loss
    print("TESTING")

    # Load best model based on validation loss
    state = utils.load_model(model, None, state["best_checkpoint"])
    test_loss = validate(args, model, criterion, test_data)
    print("TEST FINISHED: LOSS: " + str(test_loss))
    writer.add_scalar("test_loss", test_loss, state["step"])

    # Mir_eval metrics
    test_metrics = evaluate(args, musdb["test"], model, args.instruments)

    # Dump all metrics results into pickle file for later analysis if needed
    with open(os.path.join(args.checkpoint_dir, "results.pkl"), "wb") as f:
        pickle.dump(test_metrics, f)

    # Write most important metrics into Tensorboard log
    avg_SDRs = {inst : np.mean([np.nanmean(song[inst]["SDR"]) for song in test_metrics]) for inst in args.instruments}
    avg_SIRs = {inst : np.mean([np.nanmean(song[inst]["SIR"]) for song in test_metrics]) for inst in args.instruments}
    for inst in args.instruments:
        writer.add_scalar("test_SDR_" + inst, avg_SDRs[inst], state["step"])
        writer.add_scalar("test_SIR_" + inst, avg_SIRs[inst], state["step"])
    overall_SDR = np.mean(list(avg_SDRs.values()))
    writer.add_scalar("test_SDR", overall_SDR, state["step"])
    print("SDR: " + str(overall_SDR))

    writer.close()
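
The early-stopping loop above drives its learning-rate schedule through utils.set_cyclic_lr and utils.get_lr. As a rough sketch, such helpers could look like the following (a hypothetical reconstruction assuming cosine annealing between min_lr and lr over `cycles` cycles per epoch; the actual utils module may differ):

import math

def set_cyclic_lr(optimizer, it, iters_per_epoch, cycles, min_lr, max_lr):
    # Position of iteration `it` inside its current cycle, scaled to [0, 1).
    cycle_length = max(1, iters_per_epoch // cycles)
    curr_cycle = min(it // cycle_length, cycles - 1)
    progress = (it - cycle_length * curr_cycle) / float(cycle_length)

    # Cosine annealing from max_lr down to min_lr within each cycle.
    new_lr = min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
    for group in optimizer.param_groups:
        group['lr'] = new_lr

def get_lr(optimizer):
    # Assumes all parameter groups share one learning rate.
    return optimizer.param_groups[0]['lr']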
Example #12
class TacoTrainer:
    def __init__(self, paths: Paths) -> None:
        self.paths = paths
        self.writer = SummaryWriter(log_dir=paths.tts_log, comment='v1')

    def train(self, model: Tacotron, optimizer: Optimizer) -> None:
        for i, session_params in enumerate(hp.tts_schedule, 1):
            r, lr, max_step, bs = session_params
            if model.get_step() < max_step:
                train_set, val_set = get_tts_datasets(path=self.paths.data,
                                                      batch_size=bs,
                                                      r=r,
                                                      model_type='tacotron')
                session = TTSSession(index=i,
                                     r=r,
                                     lr=lr,
                                     max_step=max_step,
                                     bs=bs,
                                     train_set=train_set,
                                     val_set=val_set)
                self.train_session(model, optimizer, session)

    def train_session(self, model: Tacotron, optimizer: Optimizer,
                      session: TTSSession) -> None:
        current_step = model.get_step()
        training_steps = session.max_step - current_step
        total_iters = len(session.train_set)
        epochs = training_steps // total_iters + 1
        model.r = session.r
        simple_table([(f'Steps with r={session.r}',
                       str(training_steps // 1000) + 'k Steps'),
                      ('Batch Size', session.bs),
                      ('Learning Rate', session.lr),
                      ('Outputs/Step (r)', model.r)])
        for g in optimizer.param_groups:
            g['lr'] = session.lr

        loss_avg = Averager()
        duration_avg = Averager()
        device = next(
            model.parameters()).device  # use same device as model parameters
        for e in range(1, epochs + 1):
            for i, (x, m, ids, x_lens,
                    mel_lens) in enumerate(session.train_set, 1):
                start = time.time()
                model.train()
                x, m = x.to(device), m.to(device)

                m1_hat, m2_hat, attention = model(x, m)

                m1_loss = F.l1_loss(m1_hat, m)
                m2_loss = F.l1_loss(m2_hat, m)
                loss = m1_loss + m2_loss
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               hp.tts_clip_grad_norm)
                optimizer.step()
                loss_avg.add(loss.item())
                step = model.get_step()
                k = step // 1000

                duration_avg.add(time.time() - start)
                speed = 1. / duration_avg.get()
                msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {loss_avg.get():#.4} ' \
                      f'| {speed:#.2} steps/s | Step: {k}k | '

                if step % hp.tts_checkpoint_every == 0:
                    ckpt_name = f'taco_step{k}K'
                    save_checkpoint('tts',
                                    self.paths,
                                    model,
                                    optimizer,
                                    name=ckpt_name,
                                    is_silent=True)

                if step % hp.tts_plot_every == 0:
                    self.generate_plots(model, session)

                _, att_score = attention_score(attention, mel_lens)
                att_score = torch.mean(att_score)
                self.writer.add_scalar('Attention_Score/train', att_score,
                                       model.get_step())
                self.writer.add_scalar('Loss/train', loss, model.get_step())
                self.writer.add_scalar('Params/reduction_factor', session.r,
                                       model.get_step())
                self.writer.add_scalar('Params/batch_size', session.bs,
                                       model.get_step())
                self.writer.add_scalar('Params/learning_rate', session.lr,
                                       model.get_step())

                stream(msg)

            val_loss, val_att_score = self.evaluate(model, session.val_set)
            self.writer.add_scalar('Loss/val', val_loss, model.get_step())
            self.writer.add_scalar('Attention_Score/val', val_att_score,
                                   model.get_step())
            save_checkpoint('tts',
                            self.paths,
                            model,
                            optimizer,
                            is_silent=True)

            loss_avg.reset()
            duration_avg.reset()
            print(' ')

    def evaluate(self, model: Tacotron,
                 val_set: Dataset) -> Tuple[float, float]:
        model.eval()
        val_loss = 0
        val_att_score = 0
        device = next(model.parameters()).device
        for i, (x, m, ids, x_lens, mel_lens) in enumerate(val_set, 1):
            x, m = x.to(device), m.to(device)
            with torch.no_grad():
                m1_hat, m2_hat, attention = model(x, m)
                m1_loss = F.l1_loss(m1_hat, m)
                m2_loss = F.l1_loss(m2_hat, m)
                val_loss += m1_loss.item() + m2_loss.item()
            _, att_score = attention_score(attention, mel_lens)
            val_att_score += torch.mean(att_score).item()

        return val_loss / len(val_set), val_att_score / len(val_set)

    @ignore_exception
    def generate_plots(self, model: Tacotron, session: TTSSession) -> None:
        model.eval()
        device = next(model.parameters()).device
        x, m, ids, x_lens, m_lens = session.val_sample
        x, m = x.to(device), m.to(device)

        m1_hat, m2_hat, att = model(x, m)
        att = np_now(att)[0]
        m1_hat = np_now(m1_hat)[0, :600, :]
        m2_hat = np_now(m2_hat)[0, :600, :]
        m = np_now(m)[0, :600, :]

        att_fig = plot_attention(att)
        m1_hat_fig = plot_mel(m1_hat)
        m2_hat_fig = plot_mel(m2_hat)
        m_fig = plot_mel(m)

        self.writer.add_figure('Ground_Truth_Aligned/attention', att_fig,
                               model.step)
        self.writer.add_figure('Ground_Truth_Aligned/target', m_fig,
                               model.step)
        self.writer.add_figure('Ground_Truth_Aligned/linear', m1_hat_fig,
                               model.step)
        self.writer.add_figure('Ground_Truth_Aligned/postnet', m2_hat_fig,
                               model.step)

        m2_hat_wav = reconstruct_waveform(m2_hat)
        target_wav = reconstruct_waveform(m)

        self.writer.add_audio(tag='Ground_Truth_Aligned/target_wav',
                              snd_tensor=target_wav,
                              global_step=model.step,
                              sample_rate=hp.sample_rate)
        self.writer.add_audio(tag='Ground_Truth_Aligned/postnet_wav',
                              snd_tensor=m2_hat_wav,
                              global_step=model.step,
                              sample_rate=hp.sample_rate)

        m1_hat, m2_hat, att = model.generate(x[0].tolist(),
                                             steps=m_lens[0] + 20)
        att_fig = plot_attention(att)
        m1_hat_fig = plot_mel(m1_hat)
        m2_hat_fig = plot_mel(m2_hat)

        self.writer.add_figure('Generated/attention', att_fig, model.step)
        self.writer.add_figure('Generated/target', m_fig, model.step)
        self.writer.add_figure('Generated/linear', m1_hat_fig, model.step)
        self.writer.add_figure('Generated/postnet', m2_hat_fig, model.step)

        m2_hat_wav = reconstruct_waveform(m2_hat)

        self.writer.add_audio(tag='Generated/target_wav',
                              snd_tensor=target_wav,
                              global_step=model.step,
                              sample_rate=hp.sample_rate)
        self.writer.add_audio(tag='Generated/postnet_wav',
                              snd_tensor=m2_hat_wav,
                              global_step=model.step,
                              sample_rate=hp.sample_rate)
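
Both trainers in this example track running means through an Averager with add/get/reset methods. A minimal sketch consistent with those call sites (a hypothetical reconstruction, not necessarily the original class):

class Averager:
    """Running mean over the values added since the last reset."""

    def __init__(self):
        self.reset()

    def add(self, value):
        self.count += 1
        self.total += value

    def get(self):
        return self.total / self.count if self.count > 0 else 0.

    def reset(self):
        self.count = 0
        self.total = 0.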
class ForwardTrainer:

    def __init__(self, paths: Paths) -> None:
        self.paths = paths
        self.writer = SummaryWriter(log_dir=paths.forward_log, comment='v1')
        self.l1_loss = MaskedL1()

    def train(self, model: ForwardTacotron, optimizer: Optimizer) -> None:
        for i, session_params in enumerate(hp.forward_schedule, 1):
            lr, max_step, bs = session_params
            if model.get_step() < max_step:
                train_set, val_set = get_tts_datasets(
                    path=self.paths.data, batch_size=bs, r=1, model_type='forward')
                session = TTSSession(
                    index=i, r=1, lr=lr, max_step=max_step,
                    bs=bs, train_set=train_set, val_set=val_set)
                self.train_session(model, optimizer, session)

    def train_session(self, model: ForwardTacotron,
                      optimizer: Optimizer, session: TTSSession) -> None:
        current_step = model.get_step()
        training_steps = session.max_step - current_step
        total_iters = len(session.train_set)
        epochs = training_steps // total_iters + 1
        simple_table([('Steps', str(training_steps // 1000) + 'k Steps'),
                      ('Batch Size', session.bs),
                      ('Learning Rate', session.lr)])

        for g in optimizer.param_groups:
            g['lr'] = session.lr

        m_loss_avg = Averager()
        dur_loss_avg = Averager()
        duration_avg = Averager()
        pitch_loss_avg = Averager()
        device = next(model.parameters()).device  # use same device as model parameters
        for e in range(1, epochs + 1):
            for i, (x, m, ids, x_lens, mel_lens, dur, pitch, puncts) in enumerate(
                session.train_set, 1
            ):
                start = time.time()
                model.train()
                x, m, dur, x_lens, mel_lens, pitch, puncts = (
                    x.to(device),
                    m.to(device),
                    dur.to(device),
                    x_lens.to(device),
                    mel_lens.to(device),
                    pitch.to(device),
                    puncts.to(device),
                )
                # print("*" * 20)
                # print(x)
                # print("*" * 20)
                m1_hat, m2_hat, dur_hat, pitch_hat = model(
                    x, m, dur, mel_lens, pitch, puncts
                )
                m1_loss = self.l1_loss(m1_hat, m, mel_lens)
                m2_loss = self.l1_loss(m2_hat, m, mel_lens)
                dur_loss = self.l1_loss(dur_hat.unsqueeze(1), dur.unsqueeze(1), x_lens)
                pitch_loss = self.l1_loss(pitch_hat, pitch.unsqueeze(1), x_lens)
                loss = m1_loss + m2_loss + 0.3 * dur_loss + 0.1 * pitch_loss
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), hp.tts_clip_grad_norm)
                optimizer.step()
                m_loss_avg.add(m1_loss.item() + m2_loss.item())
                dur_loss_avg.add(dur_loss.item())
                step = model.get_step()
                k = step // 1000

                duration_avg.add(time.time() - start)
                pitch_loss_avg.add(pitch_loss.item())

                speed = 1. / duration_avg.get()
                msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Mel Loss: {m_loss_avg.get():#.4} ' \
                      f'| Dur Loss: {dur_loss_avg.get():#.4} | Pitch Loss: {pitch_loss_avg.get():#.4} ' \
                      f'| {speed:#.2} steps/s | Step: {k}k | '

                if step % hp.forward_checkpoint_every == 0:
                    ckpt_name = f'forward_step{k}K'
                    save_checkpoint('forward', self.paths, model, optimizer,
                                    name=ckpt_name, is_silent=True)

                if step % hp.forward_plot_every == 0:
                    self.generate_plots(model, session)

                self.writer.add_scalar('Mel_Loss/train', m1_loss + m2_loss, model.get_step())
                self.writer.add_scalar('Pitch_Loss/train', pitch_loss, model.get_step())
                self.writer.add_scalar('Duration_Loss/train', dur_loss, model.get_step())
                self.writer.add_scalar('Params/batch_size', session.bs, model.get_step())
                self.writer.add_scalar('Params/learning_rate', session.lr, model.get_step())

                stream(msg)

            m_val_loss, dur_val_loss, pitch_val_loss = self.evaluate(model, session.val_set)
            self.writer.add_scalar('Mel_Loss/val', m_val_loss, model.get_step())
            self.writer.add_scalar('Duration_Loss/val', dur_val_loss, model.get_step())
            self.writer.add_scalar('Pitch_Loss/val', pitch_val_loss, model.get_step())
            save_checkpoint('forward', self.paths, model, optimizer, is_silent=True)

            m_loss_avg.reset()
            duration_avg.reset()
            pitch_loss_avg.reset()
            print(' ')

    def evaluate(self, model: ForwardTacotron, val_set: Dataset) -> Tuple[float, float, float]:
        model.eval()
        m_val_loss = 0
        dur_val_loss = 0
        pitch_val_loss = 0
        device = next(model.parameters()).device
        for i, (x, m, ids, x_lens, mel_lens, dur, pitch, puncts) in enumerate(
            val_set, 1
        ):
            x, m, dur, x_lens, mel_lens, pitch, puncts = (
                x.to(device),
                m.to(device),
                dur.to(device),
                x_lens.to(device),
                mel_lens.to(device),
                pitch.to(device),
                puncts.to(device),
            )
            with torch.no_grad():
                m1_hat, m2_hat, dur_hat, pitch_hat = model(
                    x, m, dur, mel_lens, pitch, puncts
                )
                m1_loss = self.l1_loss(m1_hat, m, mel_lens)
                m2_loss = self.l1_loss(m2_hat, m, mel_lens)
                dur_loss = self.l1_loss(dur_hat.unsqueeze(1), dur.unsqueeze(1), x_lens)
                pitch_val_loss += self.l1_loss(pitch_hat, pitch.unsqueeze(1), x_lens).item()
                m_val_loss += m1_loss.item() + m2_loss.item()
                dur_val_loss += dur_loss.item()
        m_val_loss /= len(val_set)
        dur_val_loss /= len(val_set)
        pitch_val_loss /= len(val_set)
        return m_val_loss, dur_val_loss, pitch_val_loss

    @ignore_exception
    def generate_plots(self, model: ForwardTacotron, session: TTSSession) -> None:
        model.eval()
        device = next(model.parameters()).device
        x, m, ids, x_lens, mel_lens, dur, pitch, puncts = session.val_sample
        x, m, dur, mel_lens, pitch, puncts = (
            x.to(device),
            m.to(device),
            dur.to(device),
            mel_lens.to(device),
            pitch.to(device),
            puncts.to(device),
        )
        m1_hat, m2_hat, dur_hat, pitch_hat = model(x, m, dur, mel_lens, pitch, puncts)
        m1_hat = np_now(m1_hat)[0, :600, :]
        m2_hat = np_now(m2_hat)[0, :600, :]
        m = np_now(m)[0, :600, :]

        m1_hat_fig = plot_mel(m1_hat)
        m2_hat_fig = plot_mel(m2_hat)
        m_fig = plot_mel(m)
        pitch_fig = plot_pitch(np_now(pitch[0]))
        pitch_gta_fig = plot_pitch(np_now(pitch_hat.squeeze()[0]))

        self.writer.add_figure('Pitch/target', pitch_fig, model.step)
        self.writer.add_figure('Pitch/ground_truth_aligned', pitch_gta_fig, model.step)
        self.writer.add_figure('Ground_Truth_Aligned/target', m_fig, model.step)
        self.writer.add_figure('Ground_Truth_Aligned/linear', m1_hat_fig, model.step)
        self.writer.add_figure('Ground_Truth_Aligned/postnet', m2_hat_fig, model.step)

        m2_hat_wav = reconstruct_waveform(m2_hat)
        target_wav = reconstruct_waveform(m)

        self.writer.add_audio(
            tag='Ground_Truth_Aligned/target_wav', snd_tensor=target_wav,
            global_step=model.step, sample_rate=hp.sample_rate)
        self.writer.add_audio(
            tag='Ground_Truth_Aligned/postnet_wav', snd_tensor=m2_hat_wav,
            global_step=model.step, sample_rate=hp.sample_rate)
        
        m1_hat, m2_hat, dur_hat, pitch_hat = model.generate(
            x[0, : x_lens[0]].tolist(), puncts[0, : x_lens[0]].tolist()
        )

        m1_hat_fig = plot_mel(m1_hat)
        m2_hat_fig = plot_mel(m2_hat)

        pitch_gen_fig = plot_pitch(np_now(pitch_hat.squeeze()))

        self.writer.add_figure('Pitch/generated', pitch_gen_fig, model.step)
        self.writer.add_figure('Generated/target', m_fig, model.step)
        self.writer.add_figure('Generated/linear', m1_hat_fig, model.step)
        self.writer.add_figure('Generated/postnet', m2_hat_fig, model.step)

        m2_hat_wav = reconstruct_waveform(m2_hat)

        self.writer.add_audio(
            tag='Generated/target_wav', snd_tensor=target_wav,
            global_step=model.step, sample_rate=hp.sample_rate)
        self.writer.add_audio(
            tag='Generated/postnet_wav', snd_tensor=m2_hat_wav,
            global_step=model.step, sample_rate=hp.sample_rate)
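
ForwardTrainer computes all its losses with MaskedL1, an L1 criterion that ignores padded positions. A sketch matching the self.l1_loss(pred, target, lens) call sites, assuming [batch, channels, time] inputs with per-item valid lengths in lens (the original class may differ in detail):

import torch
import torch.nn.functional as F

class MaskedL1(torch.nn.Module):
    def forward(self, x, target, lens):
        max_len = target.size(2)
        # mask[b, t] is 1 for valid (non-padded) frames, 0 for padding
        ids = torch.arange(max_len, device=lens.device)
        mask = (ids.unsqueeze(0) < lens.unsqueeze(1)).float()
        mask = mask.unsqueeze(1).expand_as(x)
        loss = F.l1_loss(x * mask, target * mask, reduction='sum')
        return loss / mask.sum()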
Example #14
def main(args):
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,  # collate_fn regroups this into hp.batch_size sub-batches
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    model = nn.DataParallel(STYLER()).to(device)
    print("Model Has Been Defined")

    # Parameters
    num_param = utils.get_param_num(model)
    text_encoder = utils.get_param_num(
        model.module.style_modeling.style_encoder.text_encoder)
    audio_encoder = utils.get_param_num(
        model.module.style_modeling.style_encoder.audio_encoder)
    predictors = utils.get_param_num(model.module.style_modeling.duration_predictor)\
         + utils.get_param_num(model.module.style_modeling.pitch_predictor)\
              + utils.get_param_num(model.module.style_modeling.energy_predictor)
    decoder = utils.get_param_num(model.module.decoder)
    print('Number of Model Parameters          :', num_param)
    print('Number of Text Encoder Parameters   :', text_encoder)
    print('Number of Audio Encoder Parameters  :', audio_encoder)
    print('Number of Predictor Parameters      :', predictors)
    print('Number of Decoder Parameters        :', decoder)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = STYLERLoss().to(device)
    DATLoss = DomainAdversarialTrainingLoss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = hp.checkpoint_path()
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except Exception:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    vocoder = utils.get_vocoder()

    # Init logger
    log_path = hp.log_path()
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(os.path.join(log_path, 'train'))
        os.makedirs(os.path.join(log_path, 'validation'))
    train_logger = SummaryWriter(os.path.join(log_path, 'train'))
    val_logger = SummaryWriter(os.path.join(log_path, 'validation'))

    # Init synthesis directory
    synth_path = hp.synth_path()
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get Training Loader
        total_step = hp.epochs * len(loader) * hp.batch_size

        for i, batches in enumerate(loader):
            for j, data_of_batch in enumerate(batches):
                start_time = time.perf_counter()

                current_step = i*hp.batch_size + j + args.restore_step + \
                    epoch*len(loader)*hp.batch_size + 1

                # Get Data
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                mel_aug = torch.from_numpy(
                    data_of_batch["mel_aug"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                f0_norm = torch.from_numpy(
                    data_of_batch["f0_norm"]).float().to(device)
                f0_norm_aug = torch.from_numpy(
                    data_of_batch["f0_norm_aug"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                energy_input = torch.from_numpy(
                    data_of_batch["energy_input"]).float().to(device)
                energy_input_aug = torch.from_numpy(
                    data_of_batch["energy_input_aug"]).float().to(device)
                speaker_embed = torch.from_numpy(
                    data_of_batch["speaker_embed"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                # Forward
                mel_outputs, mel_postnet_outputs, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _, aug_posteriors = model(
                    text,
                    mel_target,
                    mel_aug,
                    f0_norm,
                    energy_input,
                    src_len,
                    mel_len,
                    D,
                    f0,
                    energy,
                    max_src_len,
                    max_mel_len,
                    speaker_embed=speaker_embed)

                # Cal Loss Clean
                mel_output, mel_postnet_output = mel_outputs[
                    0], mel_postnet_outputs[0]
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss, classifier_loss_a = Loss(
                    log_duration_output, log_D, f0_output, f0, energy_output, energy, mel_output, mel_postnet_output, mel_target, ~src_mask, ~mel_mask, src_len, mel_len,\
                        aug_posteriors, torch.zeros(mel_target.size(0)).long().to(device))

                # Cal Loss Noisy
                mel_output_noisy, mel_postnet_output_noisy = mel_outputs[
                    1], mel_postnet_outputs[1]
                mel_noisy_loss, mel_postnet_noisy_loss = Loss.cal_mel_loss(
                    mel_output_noisy, mel_postnet_output_noisy, mel_aug,
                    ~mel_mask)

                # Forward DAT
                enc_cat = model.module.style_modeling.style_encoder.encoder_input_cat(
                    mel_aug, f0_norm_aug, energy_input_aug, mel_aug)
                duration_encoding, pitch_encoding, energy_encoding, _ = model.module.style_modeling.style_encoder.audio_encoder(
                    enc_cat, mel_len, src_len, mask=None)
                aug_posterior_d = model.module.style_modeling.augmentation_classifier_d(
                    duration_encoding)
                aug_posterior_p = model.module.style_modeling.augmentation_classifier_p(
                    pitch_encoding)
                aug_posterior_e = model.module.style_modeling.augmentation_classifier_e(
                    energy_encoding)

                # Cal Loss DAT
                classifier_loss_a_dat = DATLoss(
                    (aug_posterior_d, aug_posterior_p, aug_posterior_e),
                    torch.ones(mel_target.size(0)).long().to(device))

                # Total loss
                total_loss = mel_loss + mel_postnet_loss + mel_noisy_loss + mel_postnet_noisy_loss + d_loss + f_loss + e_loss\
                    + hp.dat_weight*(classifier_loss_a + classifier_loss_a_dat)

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                m_n_l = mel_noisy_loss.item()
                m_p_n_l = mel_postnet_noisy_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                cl_a = classifier_loss_a.item()
                cl_a_dat = classifier_loss_a_dat.item()

                # Backward
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step == 1 or current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_loss', m_l, current_step)
                    train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_noisy_loss', m_n_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_postnet_noisy_loss',
                                            m_p_n_l, current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l,
                                            current_step)
                    train_logger.add_scalar('Loss/dat_clean_loss', cl_a,
                                            current_step)
                    train_logger.add_scalar('Loss/dat_noisy_loss', cl_a_dat,
                                            current_step)

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step == 1 or current_step % hp.synth_step == 0:
                    length = mel_len[0].item()
                    mel_target_torch = mel_target[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_aug_torch = mel_aug[0, :length].detach().unsqueeze(
                        0).transpose(1, 2)
                    mel_target = mel_target[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_aug = mel_aug[0, :length].detach().cpu().transpose(
                        0, 1)
                    mel_torch = mel_output[0, :length].detach().unsqueeze(
                        0).transpose(1, 2)
                    mel_noisy_torch = mel_output_noisy[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel = mel_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_noisy = mel_output_noisy[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_torch = mel_postnet_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet_noisy_torch = mel_postnet_output_noisy[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_noisy = mel_postnet_output_noisy[
                        0, :length].detach().cpu().transpose(0, 1)
                    # Audio.tools.inv_mel_spec(mel, os.path.join(
                    #     synth_path, "step_{}_{}_griffin_lim.wav".format(current_step, "c")))
                    # Audio.tools.inv_mel_spec(mel_postnet, os.path.join(
                    #     synth_path, "step_{}_{}_postnet_griffin_lim.wav".format(current_step, "c")))
                    # Audio.tools.inv_mel_spec(mel_noisy, os.path.join(
                    #     synth_path, "step_{}_{}_griffin_lim.wav".format(current_step, "n")))
                    # Audio.tools.inv_mel_spec(mel_postnet_noisy, os.path.join(
                    #     synth_path, "step_{}_{}_postnet_griffin_lim.wav".format(current_step, "n")))

                    wav_mel = utils.vocoder_infer(
                        mel_torch, vocoder,
                        os.path.join(
                            hp.synth_path(),
                            'step_{}_{}_{}.wav'.format(current_step, "c",
                                                       hp.vocoder)))
                    wav_mel_postnet = utils.vocoder_infer(
                        mel_postnet_torch, vocoder,
                        os.path.join(
                            hp.synth_path(),
                            'step_{}_{}_postnet_{}.wav'.format(
                                current_step, "c", hp.vocoder)))
                    wav_ground_truth = utils.vocoder_infer(
                        mel_target_torch, vocoder,
                        os.path.join(
                            hp.synth_path(),
                            'step_{}_{}_ground-truth_{}.wav'.format(
                                current_step, "c", hp.vocoder)))
                    wav_mel_noisy = utils.vocoder_infer(
                        mel_noisy_torch, vocoder,
                        os.path.join(
                            hp.synth_path(),
                            'step_{}_{}_{}.wav'.format(current_step, "n",
                                                       hp.vocoder)))
                    wav_mel_postnet_noisy = utils.vocoder_infer(
                        mel_postnet_noisy_torch, vocoder,
                        os.path.join(
                            hp.synth_path(),
                            'step_{}_{}_postnet_{}.wav'.format(
                                current_step, "n", hp.vocoder)))
                    wav_aug = utils.vocoder_infer(
                        mel_aug_torch, vocoder,
                        os.path.join(
                            hp.synth_path(),
                            'step_{}_{}_ground-truth_{}.wav'.format(
                                current_step, "n", hp.vocoder)))

                    # Model duration prediction
                    log_duration_output = log_duration_output[
                        0, :src_len[0].item()].detach().cpu()  # [seg_len]
                    log_duration_output = torch.clamp(torch.round(
                        torch.exp(log_duration_output) - hp.log_offset),
                                                      min=0).int()
                    model_duration = utils.get_alignment_2D(
                        log_duration_output).T  # [seg_len, mel_len]
                    model_duration = utils.plot_alignment([model_duration])

                    # Model mel prediction
                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()
                    mel_predicted = utils.plot_data(
                        [(mel_postnet.numpy(), f0_output, energy_output),
                         (mel_target.numpy(), f0, energy)], [
                             'Synthetized Spectrogram Clean',
                             'Ground-Truth Spectrogram'
                         ],
                        filename=os.path.join(
                            synth_path,
                            'step_{}_{}.png'.format(current_step, "c")))
                    mel_noisy_predicted = utils.plot_data(
                        [(mel_postnet_noisy.numpy(), f0_output, energy_output),
                         (mel_aug.numpy(), f0, energy)],
                        ['Synthetized Spectrogram Noisy', 'Aug Spectrogram'],
                        filename=os.path.join(
                            synth_path,
                            'step_{}_{}.png'.format(current_step, "n")))

                    # Normalize audio for tensorboard logger. See https://github.com/lanpa/tensorboardX/issues/511#issuecomment-537600045
                    wav_ground_truth = wav_ground_truth / np.max(np.abs(wav_ground_truth))
                    wav_mel = wav_mel / np.max(np.abs(wav_mel))
                    wav_mel_postnet = wav_mel_postnet / np.max(np.abs(wav_mel_postnet))
                    wav_aug = wav_aug / np.max(np.abs(wav_aug))
                    wav_mel_noisy = wav_mel_noisy / np.max(np.abs(wav_mel_noisy))
                    wav_mel_postnet_noisy = wav_mel_postnet_noisy / np.max(
                        np.abs(wav_mel_postnet_noisy))

                    train_logger.add_image("model_duration",
                                           model_duration,
                                           current_step,
                                           dataformats='HWC')
                    train_logger.add_image("mel_predicted/Clean",
                                           mel_predicted,
                                           current_step,
                                           dataformats='HWC')
                    train_logger.add_image("mel_predicted/Noisy",
                                           mel_noisy_predicted,
                                           current_step,
                                           dataformats='HWC')
                    train_logger.add_audio("Clean/wav_ground_truth",
                                           wav_ground_truth,
                                           current_step,
                                           sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Clean/wav_mel",
                                           wav_mel,
                                           current_step,
                                           sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Clean/wav_mel_postnet",
                                           wav_mel_postnet,
                                           current_step,
                                           sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Noisy/wav_aug",
                                           wav_aug,
                                           current_step,
                                           sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Noisy/wav_mel_noisy",
                                           wav_mel_noisy,
                                           current_step,
                                           sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Noisy/wav_mel_postnet_noisy",
                                           wav_mel_postnet_noisy,
                                           current_step,
                                           sample_rate=hp.sampling_rate)

                if current_step == 1 or current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, cl_a, cl_a_dat, m_l, m_p_l, m_n_l, m_p_n_l = evaluate(
                            model, current_step)
                        t_l = d_l + f_l + e_l + m_l + m_p_l + m_n_l + m_p_n_l\
                            + hp.dat_weight*(cl_a + cl_a_dat)

                        val_logger.add_scalar('Loss/total_loss', t_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_loss', m_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_noisy_loss', m_n_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_postnet_noisy_loss',
                                              m_p_n_l, current_step)
                        val_logger.add_scalar('Loss/duration_loss', d_l,
                                              current_step)
                        val_logger.add_scalar('Loss/F0_loss', f_l,
                                              current_step)
                        val_logger.add_scalar('Loss/energy_loss', e_l,
                                              current_step)
                        val_logger.add_scalar('Loss/dat_clean_loss', cl_a,
                                              current_step)
                        val_logger.add_scalar('Loss/dat_noisy_loss', cl_a_dat,
                                              current_step)

                    model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the timing history to its mean to keep the array small
                    Time = np.array([np.mean(Time)])
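
The STYLER loop above never touches the optimizer directly; ScheduledOptim wraps it behind step_and_update_lr() and zero_grad(). A sketch of such a wrapper under the common Transformer (Noam) warm-up formulation, scaled by d_model**-0.5 (assumed; the class in this codebase may differ):

import numpy as np

class ScheduledOptim:
    def __init__(self, optimizer, d_model, n_warmup_steps, current_steps):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = current_steps
        self.init_lr = np.power(d_model, -0.5)

    def step_and_update_lr(self):
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        # Linear warm-up for n_warmup_steps, then inverse square-root decay.
        return min(np.power(self.n_current_steps, -0.5),
                   np.power(self.n_warmup_steps, -1.5) * self.n_current_steps)

    def _update_learning_rate(self):
        self.n_current_steps += 1
        lr = self.init_lr * self._get_lr_scale()
        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr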
Example #15
def train(rank, args, hp, hp_str):
    # if hp.train.num_gpus > 1:
    #     init_process_group(backend=hp.dist.dist_backend, init_method=hp.dist.dist_url,
    #                        world_size=hp.dist.world_size * hp.train.num_gpus, rank=rank)

    torch.cuda.manual_seed(hp.train.seed)
    device = torch.device('cuda:{:d}'.format(rank))

    generator = Generator(hp.model.in_channels,
                          hp.model.out_channels).to(device)
    specd = SpecDiscriminator().to(device)
    msd = MultiScaleDiscriminator().to(device)
    stft_loss = MultiResolutionSTFTLoss()

    if rank == 0:
        print(generator)
        os.makedirs(hp.logs.chkpt_dir, exist_ok=True)
        print("checkpoints directory : ", hp.logs.chkpt_dir)

    cp_g, cp_do = None, None  # avoid NameError if the checkpoint dir does not exist yet
    if os.path.isdir(hp.logs.chkpt_dir):
        cp_g = scan_checkpoint(hp.logs.chkpt_dir, 'g_')
        cp_do = scan_checkpoint(hp.logs.chkpt_dir, 'do_')

    steps = 0
    if cp_g is None or cp_do is None:
        state_dict_do = None
        last_epoch = -1
    else:
        state_dict_g = load_checkpoint(cp_g, device)
        state_dict_do = load_checkpoint(cp_do, device)
        generator.load_state_dict(state_dict_g['generator'])
        specd.load_state_dict(state_dict_do['specd'])
        msd.load_state_dict(state_dict_do['msd'])
        steps = state_dict_do['steps'] + 1
        last_epoch = state_dict_do['epoch']

    if hp.train.num_gpus > 1:
        generator = DistributedDataParallel(generator,
                                            device_ids=[rank]).to(device)
        specd = DistributedDataParallel(specd, device_ids=[rank]).to(device)
        msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)

    optim_g = torch.optim.AdamW(
        generator.parameters(),
        hp.train.adamG.lr,
        betas=[hp.train.adamG.beta1, hp.train.adamG.beta2])
    optim_d = torch.optim.AdamW(
        itertools.chain(msd.parameters(), specd.parameters()),
        hp.train.adamD.lr,
        betas=[hp.train.adamD.beta1, hp.train.adamD.beta2])

    if state_dict_do is not None:
        optim_g.load_state_dict(state_dict_do['optim_g'])
        optim_d.load_state_dict(state_dict_do['optim_d'])

    # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hp.train.adam.lr_decay, last_epoch=last_epoch)
    # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hp.train.adam.lr_decay, last_epoch=last_epoch)

    training_filelist, validation_filelist = get_dataset_filelist(args)

    trainset = MelDataset(training_filelist,
                          hp.data.input_wavs,
                          hp.data.output_wavs,
                          hp.audio.segment_length,
                          hp.audio.filter_length,
                          hp.audio.n_mel_channels,
                          hp.audio.hop_length,
                          hp.audio.win_length,
                          hp.audio.sampling_rate,
                          hp.audio.mel_fmin,
                          hp.audio.mel_fmax,
                          n_cache_reuse=0,
                          shuffle=False if hp.train.num_gpus > 1 else True,
                          fmax_loss=None,
                          device=device)

    train_sampler = DistributedSampler(
        trainset) if hp.train.num_gpus > 1 else None

    train_loader = DataLoader(trainset,
                              num_workers=hp.train.num_workers,
                              shuffle=False,
                              sampler=train_sampler,
                              batch_size=hp.train.batch_size,
                              pin_memory=True,
                              drop_last=True)

    if rank == 0:
        validset = MelDataset(validation_filelist,
                              hp.data.input_wavs,
                              hp.data.output_wavs,
                              hp.audio.segment_length,
                              hp.audio.filter_length,
                              hp.audio.n_mel_channels,
                              hp.audio.hop_length,
                              hp.audio.win_length,
                              hp.audio.sampling_rate,
                              hp.audio.mel_fmin,
                              hp.audio.mel_fmax,
                              split=False,
                              shuffle=False,
                              n_cache_reuse=0,
                              fmax_loss=None,
                              device=device)
        validation_loader = DataLoader(validset,
                                       num_workers=1,
                                       shuffle=False,
                                       sampler=None,
                                       batch_size=1,
                                       pin_memory=True,
                                       drop_last=True)

        sw = SummaryWriter(os.path.join(hp.logs.chkpt_dir, 'logs'))

    generator.train()
    specd.train()
    msd.train()
    with_postnet = False
    for epoch in range(max(0, last_epoch), args.training_epochs):
        if rank == 0:
            start = time.time()
            print("Epoch: {}".format(epoch + 1))

        if hp.train.num_gpus > 1:
            train_sampler.set_epoch(epoch)

        for i, batch in enumerate(train_loader):
            if rank == 0:
                start_b = time.time()
            if steps > hp.train.postnet_start_steps:
                with_postnet = True
            x, y, file, _, y_mel_loss = batch
            # torch.autograd.Variable is a deprecated no-op wrapper; plain tensors suffice
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            y_mel_loss = y_mel_loss.to(device, non_blocking=True)
            # y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
            x = x.unsqueeze(1)
            y = y.unsqueeze(1)
            before_y_g_hat, y_g_hat = generator(x, with_postnet)

            # y_g_hat_mel below only exists once the postnet is active; the discriminator
            # block assumes postnet_start_steps <= discriminator_train_start_steps
            if y_g_hat is not None:
                y_g_hat_mel = mel_spectrogram(
                    y_g_hat.squeeze(1), hp.audio.filter_length,
                    hp.audio.n_mel_channels, hp.audio.sampling_rate,
                    hp.audio.hop_length, hp.audio.win_length,
                    hp.audio.mel_fmin, None)

            if steps > hp.train.discriminator_train_start_steps:
                for _ in range(hp.train.rep_discriminator):
                    optim_d.zero_grad()

                    # SpecD
                    y_df_hat_r, y_df_hat_g, _, _ = specd(
                        y_mel_loss, y_g_hat_mel.detach())
                    loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(
                        y_df_hat_r, y_df_hat_g)

                    # MSD
                    y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
                    loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(
                        y_ds_hat_r, y_ds_hat_g)

                    loss_disc_all = loss_disc_s + loss_disc_f

                    loss_disc_all.backward()
                    optim_d.step()

            before_y_g_hat_mel = mel_spectrogram(
                before_y_g_hat.squeeze(1), hp.audio.filter_length,
                hp.audio.n_mel_channels, hp.audio.sampling_rate,
                hp.audio.hop_length, hp.audio.win_length, hp.audio.mel_fmin,
                None)
            # Generator
            optim_g.zero_grad()

            # L1 Mel-Spectrogram Loss
            # before_loss_mel = F.l1_loss(y_mel_loss, before_y_g_hat_mel)
            sc_loss, mag_loss = stft_loss(
                before_y_g_hat[:, :, :y.size(2)].squeeze(1), y.squeeze(1))
            before_loss_mel = sc_loss + mag_loss

            # L1 Sample Loss
            before_loss_sample = F.l1_loss(y, before_y_g_hat)
            loss_gen_all = before_loss_mel + before_loss_sample

            if y_g_hat is not None:
                # L1 Mel-Spectrogram Loss
                # loss_mel = F.l1_loss(y_mel_loss, y_g_hat_mel)
                sc_loss_, mag_loss_ = stft_loss(
                    y_g_hat[:, :, :y.size(2)].squeeze(1), y.squeeze(1))
                loss_mel = sc_loss_ + mag_loss_
                # L1 Sample Loss
                loss_sample = F.l1_loss(y, y_g_hat)
                loss_gen_all += loss_mel + loss_sample

            if steps > hp.train.discriminator_train_start_steps:
                y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = specd(
                    y_mel_loss, y_g_hat_mel)
                y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
                loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
                loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
                loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
                loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
                loss_gen_all += hp.model.lambda_adv * (
                    loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f)

            loss_gen_all.backward()
            optim_g.step()

            if rank == 0:
                # STDOUT logging
                if steps % args.stdout_interval == 0:
                    with torch.no_grad():
                        mel_error = F.l1_loss(y_mel_loss,
                                              before_y_g_hat_mel).item()
                        sample_error = F.l1_loss(y, before_y_g_hat).item()

                    print(
                        'Steps : {:d}, Gen Loss Total : {:4.3f}, Sample Error: {:4.3f}, '
                        'Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.format(
                            steps, loss_gen_all.item(), sample_error, mel_error,
                            time.time() - start_b))

                # checkpointing
                if steps % hp.logs.save_interval == 0 and steps != 0:
                    checkpoint_path = "{}/g_{:08d}".format(
                        hp.logs.chkpt_dir, steps)
                    save_checkpoint(
                        checkpoint_path, {
                            'generator':
                            (generator.module if hp.train.num_gpus > 1 else
                             generator).state_dict()
                        })
                    checkpoint_path = "{}/do_{:08d}".format(
                        hp.logs.chkpt_dir, steps)
                    save_checkpoint(
                        checkpoint_path, {
                            'specd': (specd.module if hp.train.num_gpus > 1
                                      else specd).state_dict(),
                            'msd': (msd.module if hp.train.num_gpus > 1 else
                                    msd).state_dict(),
                            'optim_g':
                            optim_g.state_dict(),
                            'optim_d':
                            optim_d.state_dict(),
                            'steps':
                            steps,
                            'epoch':
                            epoch,
                            'hp_str':
                            hp_str
                        })

                # Tensorboard summary logging
                if steps % hp.logs.summary_interval == 0:
                    sw.add_scalar("training/gen_loss_total", loss_gen_all,
                                  steps)
                    sw.add_scalar("training/mel_spec_error", mel_error, steps)

                # Validation
                if steps % hp.logs.validation_interval == 0:  # and steps != 0:
                    generator.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    with torch.no_grad():
                        for j, batch in enumerate(validation_loader):
                            x, y, file, y_mel, y_mel_loss = batch
                            x = x.unsqueeze(1)
                            y = y.unsqueeze(1).to(device)
                            before_y_g_hat, y_g_hat = generator(x.to(device))
                            y_mel_loss = y_mel_loss.to(device, non_blocking=True)
                            y_g_hat_mel = mel_spectrogram(
                                before_y_g_hat.squeeze(1),
                                hp.audio.filter_length,
                                hp.audio.n_mel_channels,
                                hp.audio.sampling_rate, hp.audio.hop_length,
                                hp.audio.win_length, hp.audio.mel_fmin, None)
                            val_err_tot += F.l1_loss(y_mel_loss,
                                                     y_g_hat_mel).item()
                            val_err_tot += F.l1_loss(y, before_y_g_hat).item()
                            if y_g_hat is not None:
                                val_err_tot += F.l1_loss(y, y_g_hat).item()

                            if j <= 4:
                                if steps == 0:
                                    sw.add_audio('gt_noise/y_{}'.format(j),
                                                 x[0], steps,
                                                 hp.audio.sampling_rate)
                                    sw.add_audio('gt_clean/y_{}'.format(j),
                                                 y[0], steps,
                                                 hp.audio.sampling_rate)
                                    sw.add_figure(
                                        'gt/y_spec_clean_{}'.format(j),
                                        plot_spectrogram(y_mel[0]), steps)

                                sw.add_audio('generated/y_hat_{}'.format(j),
                                             before_y_g_hat[0], steps,
                                             hp.audio.sampling_rate)
                                if y_g_hat is not None:
                                    sw.add_audio(
                                        'generated/y_hat_after_{}'.format(j),
                                        y_g_hat[0], steps,
                                        hp.audio.sampling_rate)
                                y_hat_spec = mel_spectrogram(
                                    before_y_g_hat.squeeze(1),
                                    hp.audio.filter_length,
                                    hp.audio.n_mel_channels,
                                    hp.audio.sampling_rate,
                                    hp.audio.hop_length, hp.audio.win_length,
                                    hp.audio.mel_fmin, None)
                                sw.add_figure(
                                    'generated/y_hat_spec_{}'.format(j),
                                    plot_spectrogram(
                                        y_hat_spec.squeeze(0).cpu().numpy()),
                                    steps)

                        val_err = val_err_tot / (j + 1)
                        sw.add_scalar("validation/mel_spec_error", val_err,
                                      steps)

                    generator.train()

            steps += 1

        # scheduler_g.step()
        # scheduler_d.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(
                epoch + 1, int(time.time() - start)))
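
The adversarial terms above rely on discriminator_loss, generator_loss, and feature_loss helpers. A sketch following the least-squares GAN formulation common to HiFi-GAN-style vocoders, consistent with the call signatures in the loop (assumed, not taken from this repository):

import torch

def feature_loss(fmap_r, fmap_g):
    # L1 distance between real and generated discriminator feature maps.
    loss = 0.
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))
    return loss * 2

def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    # Least-squares objective: push real scores to 1, fake scores to 0.
    loss, r_losses, g_losses = 0., [], []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg ** 2)
        loss += r_loss + g_loss
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())
    return loss, r_losses, g_losses

def generator_loss(disc_outputs):
    # The generator is rewarded for fake scores close to 1.
    loss, gen_losses = 0., []
    for dg in disc_outputs:
        g = torch.mean((1 - dg) ** 2)
        gen_losses.append(g)
        loss += g
    return loss, gen_losses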
Example #16
def main():
    # Reproducibility
    np.random.seed(12345)
    torch.manual_seed(12345)

    # Preparation
    config = get_parameters()

    # Logging configuration
    writer = None
    if config.tensorboard:
        path_tensorboard = f'{config.logging_dir}/{config.experiment_description}'
        if config.debug_mode:  # Clear tensorboard when debugging
            if os.path.exists(path_tensorboard):
                shutil.rmtree(path_tensorboard)
        writer = SummaryWriter(path_tensorboard)

    data_loader_train, data_loader_valid, data_loader_test = get_data(config)

    if config.use_time_freq:
        transforms = get_time_frequency_transform(config)
    else:
        transforms = None

    # =====================================================================
    # Visualize some data
    tmp_audio = None
    tmp_spec = None
    tmp_data, targets, _ = data_loader_train.dataset[
        0]  # audio is [channels, timesteps]

    # Is the data audio or image?
    if len(tmp_data.shape) == 2:
        tmp_audio = tmp_data
    else:
        tmp_spec = tmp_data

    if config.use_time_freq:
        tmp_spec = transforms(
            tmp_audio)  # spec is [channels, freq_bins, frames]

    if tmp_spec is not None:
        utils.show_spectrogram(tmp_spec, config)

    if writer is not None:
        if tmp_audio is not None:
            # Store 5 secs of audio
            ind = min(tmp_audio.shape[-1], 5 * config.original_fs)
            writer.add_audio('input_audio', tmp_audio[:, 0:ind], None,
                             config.original_fs)

            tmp_audios = []
            fnames = []
            for i in range(4):
                aud, _, fn = data_loader_train.dataset.dataset[i]
                fnames.append(fn)
                tmp_audios.append(aud)
            writer.add_figure(
                'input_waveform',
                utils.show_waveforms_batch(tmp_audios, fnames, config), None)

        # Analyze some spectrograms
        if tmp_spec is not None:
            img_tform = tforms_vision.Compose([
                tforms_vision.ToPILImage(),
                tforms_vision.ToTensor(),
            ])

            writer.add_image('input_spec', img_tform(tmp_spec),
                             None)  # Raw tensor
            writer.add_figure('input_spec_single',
                              utils.show_spectrogram(tmp_spec, config),
                              None)  # Librosa

            if config.use_time_freq:
                tmp_specs = []
                fnames = []
                for i in range(4):
                    aud, _, fn = data_loader_train.dataset.dataset[i]
                    tmp_specs.append(transforms(aud))
                    fnames.append(fn)

                writer.add_figure(
                    'input_spec_batch',
                    utils.show_spectrogram_batch(tmp_specs, fnames, config),
                    None)
                writer.add_figure('input_spec_histogram',
                                  utils.get_spectrogram_histogram(tmp_specs),
                                  None)
                del tmp_specs, fnames, aud, fn, i

    # Class Histograms
    if not config.dataset_skip_class_hist:
        fig_classes = utils.get_class_histograms(
            data_loader_train.dataset,
            data_loader_valid.dataset,
            data_loader_test.dataset,
            one_hot_encoder=utils.OneHot
            if config.dataset == 'MNIST' else None,
            data_limit=200 if config.debug_mode else None)
        if writer is not None:
            writer.add_figure('class_histogram', fig_classes, None)

    # =====================================================================
    # Train and Test
    solver = Solver(data_loader_train, data_loader_valid, data_loader_test,
                    config, writer, transforms)
    solver.train()
    scores, true_class, pred_scores = solver.test()

    # =====================================================================
    # Save results

    np.save(open(os.path.join(config.result_dir, 'true_class.npy'), 'wb'),
            true_class)
    np.save(open(os.path.join(config.result_dir, 'pred_scores.npy'), 'wb'),
            pred_scores)

    utils.compare_predictions(true_class, pred_scores, config.result_dir)

    if writer is not None:
        writer.close()
Beispiel #17
class DurationExtractor(nn.Module):
    """The teacher model for duration extraction"""
    def __init__(
            self,
            adam_lr=0.002,
            warmup_epochs=30,
            init_scale=0.25,
            guided_att_sigma=0.3,
            device='cuda'
    ):
        super(DurationExtractor, self).__init__()

        self.txt_encoder = ConvTextEncoder()
        self.audio_encoder = ConvAudioEncoder()
        self.audio_decoder = ConvAudioDecoder()
        self.attention = ScaledDotAttention()
        self.collate = Collate(device=device)

        # optim
        self.optimizer = torch.optim.Adam(self.parameters(), lr=adam_lr)
        self.scheduler = NoamScheduler(self.optimizer, warmup_epochs, init_scale)

        # losses
        self.loss_l1 = l1_masked
        self.loss_att = GuidedAttentionLoss(guided_att_sigma)

        # device
        self.device = device
        self.to(self.device)
        print(f'Model sent to {self.device}')

        # helper vars
        self.checkpoint = None
        self.epoch = 0
        self.step = 0

        repo = git.Repo(search_parent_directories=True)
        self.git_commit = repo.head.object.hexsha

    def to_device(self, device):
        print(f'Sending network to {device}')
        self.device = device
        self.to(device)
        return self

    def save(self):

        if self.checkpoint is not None:
            os.remove(self.checkpoint)
        self.checkpoint = os.path.join(self.logger.log_dir, f'{time.strftime("%Y-%m-%d")}_checkpoint_step{self.step}.pth')
        torch.save(
            {
                'epoch': self.epoch,
                'step': self.step,
                'state_dict': self.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict(),
                'git_commit': self.git_commit
            },
            self.checkpoint)

    def load(self, checkpoint):
        checkpoint = torch.load(checkpoint)
        self.epoch = checkpoint['epoch']
        self.step = checkpoint['step']
        self.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])

        commit = checkpoint['git_commit']
        if commit != self.git_commit:
            print(f'Warning: the loaded checkpoint was trained on commit {commit}, but you are on {self.git_commit}')
        self.checkpoint = None  # prevent overriding old checkpoint
        return self

    def forward(self, phonemes, spectrograms, len_phonemes, training=False):
        """
        :param phonemes: (batch, alphabet, time), padded phonemes
        :param spectrograms: (batch, freq, time), padded spectrograms
        :param len_phonemes: list of phoneme lengths
        :return: decoded_spectrograms, attention_weights
        """
        spectrs = ZeroPad2d((0,0,1,0))(spectrograms)[:, :-1, :]  # move this to encoder?
        keys, values = self.txt_encoder(phonemes)
        queries = self.audio_encoder(spectrs)

        att_mask = mask(shape=(len(keys), queries.shape[1], keys.shape[1]),
                        lengths=len_phonemes,
                        dim=-1).to(self.device)

        if hp.positional_encoding:
            keys += positional_encoding(keys.shape[-1], keys.shape[1], w=hp.w).to(self.device)
            queries += positional_encoding(queries.shape[-1], queries.shape[1], w=1).to(self.device)

        attention, weights = self.attention(queries, keys, values, mask=att_mask)
        decoded = self.audio_decoder(attention + queries)
        return decoded, weights

    def generating(self, mode):
        """Put the module into mode for sequential generation"""
        for module in self.children():
            if hasattr(module, 'generating'):
                module.generating(mode)

    def generate(self, phonemes, len_phonemes, steps=False, window=3, spectrograms=None):
        """Sequentially generate spectrogram from phonemes

        If spectrograms are provided, they are used as input instead of self-generated frames (teacher forcing).
        If steps is provided together with spectrograms, only `steps` frames will be generated in supervised fashion.
        Uses layer-level caching for faster inference.

        :param phonemes: Padded phoneme indices
        :param len_phonemes: Length of each sentence in `phonemes` (list of lengths)
        :param steps: How many steps to generate
        :param window: Window size for attention masking
        :param spectrograms: Padded spectrograms
        :return: Generated spectrograms
        """
        self.generating(True)
        self.train(False)

        assert steps or (spectrograms is not None)
        steps = steps if steps else spectrograms.shape[1]

        with torch.no_grad():
            phonemes = torch.as_tensor(phonemes)
            keys, values = self.txt_encoder(phonemes)

            if hp.positional_encoding:
                keys += positional_encoding(keys.shape[-1], keys.shape[1], w=hp.w).to(self.device)
                pe = positional_encoding(hp.channels, steps, w=1).to(self.device)

            if spectrograms is None:
                dec = torch.zeros(len(phonemes), 1, hp.out_channels, device=self.device)
            else:
                input = ZeroPad2d((0, 0, 1, 0))(spectrograms)[:, :-1, :]

            weights, decoded = None, None

            if window is not None:
                shape = (len(phonemes), 1, phonemes.shape[-1])
                idx = torch.zeros(len(phonemes), 1, phonemes.shape[-1]).to(phonemes.device)
                att_mask = idx_mask(shape, idx, window)
            else:
                att_mask = mask(shape=(len(phonemes), 1, keys.shape[1]),
                                lengths=len_phonemes,
                                dim=-1).to(self.device)

            for i in range(steps):
                if spectrograms is None:
                    queries = self.audio_encoder(dec)
                else:
                    queries = self.audio_encoder(input[:, i:i+1, :])

                if hp.positional_encoding:
                    queries += pe[i]

                att, w = self.attention(queries, keys, values, att_mask)
                dec = self.audio_decoder(att + queries)
                weights = w if weights is None else torch.cat((weights, w), dim=1)
                decoded = dec if decoded is None else torch.cat((decoded, dec), dim=1)
                if window is not None:
                    idx = torch.argmax(w, dim=-1).unsqueeze(2).float()
                    att_mask = idx_mask(shape, idx, window)

        self.generating(False)
        return decoded, weights

    def generate_naive(self, phonemes, len_phonemes, steps=1, window=(0,1)):
        """Naive generation without layer-level caching for testing purposes"""

        self.train(False)

        with torch.no_grad():
            phonemes = torch.as_tensor(phonemes)

            keys, values = self.txt_encoder(phonemes)

            if hp.positional_encoding:
                keys += positional_encoding(keys.shape[-1], keys.shape[1], w=hp.w).to(self.device)
                pe = positional_encoding(hp.channels, steps, w=1).to(self.device)

            dec = torch.zeros(len(phonemes), 1, hp.out_channels, device=self.device)

            weights = None

            att_mask = mask(shape=(len(phonemes), 1, keys.shape[1]),
                            lengths=len_phonemes,
                            dim=-1).to(self.device)

            for i in range(steps):
                print(i)
                queries = self.audio_encoder(dec)
                if hp.positional_encoding:
                    queries += pe[i]

                att, w = self.attention(queries, keys, values, att_mask)
                d = self.audio_decoder(att + queries)
                d = d[:, -1:]
                w = w[:, -1:]
                weights = w if weights is None else torch.cat((weights, w), dim=1)
                dec = torch.cat((dec, d), dim=1)

                if window is not None:
                    att_mask = median_mask(weights, window=window)

        return dec[:, 1:, :], weights

    def fit(self, batch_size, logdir, epochs=1, grad_clip=1, checkpoint_every=10):
        self.grad_clip = grad_clip
        self.logger = SummaryWriter(logdir)

        train_loader = self.train_dataloader(batch_size)
        valid_loader = self.val_dataloader(batch_size)

        # continue training from self.epoch if checkpoint loaded
        for e in range(self.epoch + 1, self.epoch + 1 + epochs):
            self.epoch = e
            train_losses = self._train_epoch(train_loader)
            valid_losses = self._validate(valid_loader)

            self.scheduler.step()
            self.logger.add_scalar('train/learning_rate', self.optimizer.param_groups[0]['lr'], self.epoch)
            if not e % checkpoint_every:
                self.save()

            print(f'Epoch {e} | Train - l1: {train_losses[0]}, guided_att: {train_losses[1]}| '
                  f'Valid - l1: {valid_losses[0]}, guided_att: {valid_losses[1]}|')

    def _train_epoch(self, dataloader):
        self.train()

        t_l1, t_att = 0, 0
        for i, batch in enumerate(Bar(dataloader)):
            self.optimizer.zero_grad()
            spectrs, slen, phonemes, plen, text = batch

            s = add_random_noise(spectrs, hp.noise)
            s = degrade_some(self, s, phonemes, plen, hp.feed_ratio, repeat=hp.feed_repeat)
            s = frame_dropout(s, hp.replace_ratio)

            out, att_weights = self.forward(phonemes, s, plen)

            l1 = self.loss_l1(out, spectrs, slen)
            l_att = self.loss_att(att_weights, slen, plen)

            loss = l1 + l_att
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip)
            self.optimizer.step()
            self.step += 1

            t_l1 += l1.item()
            t_att += l_att.item()

            self.logger.add_scalar(
                'batch/total', loss.item(), self.step
            )

        # report average cost per batch (enumerate starts at 0, so i + 1 batches)
        n_batches = i + 1
        self.logger.add_scalar('train/l1', t_l1 / n_batches, self.epoch)
        self.logger.add_scalar('train/guided_att', t_att / n_batches, self.epoch)
        return t_l1 / n_batches, t_att / n_batches

    def _validate(self, dataloader):
        self.eval()

        t_l1, t_att = 0, 0
        for i, batch in enumerate(dataloader):
            spectrs, slen, phonemes, plen, text = batch
            # generate sequentially
            out, att_weights = self.generate(phonemes, plen, steps=spectrs.shape[1], window=None)

            # generate in supervised fashion - for visualisation only
            with torch.no_grad():
                out_s, att_s = self.forward(phonemes, spectrs, plen)

            l1 = self.loss_l1(out, spectrs, slen)
            l_att = self.loss_att(att_weights, slen, plen)
            t_l1 += l1.item()
            t_att += l_att.item()

            fig = display_spectr_alignment(out[-1, :slen[-1]],
                                           att_weights[-1][:slen[-1], :plen[-1]],
                                           out_s[-1, :slen[-1]], att_s[-1][:slen[-1], :plen[-1]],
                                           text[-1])
            self.logger.add_figure(text[-1], fig, self.epoch)

            if not self.epoch % 10:
                spec = self.collate.norm.inverse(out[-1:]) # TODO: this fails if we do not standardize!
                sound, length = self.collate.stft.spec2wav(spec.transpose(1, 2), slen[-1:])
                sound = sound[0, :length[0]]
                self.logger.add_audio(text[-1], sound.detach().cpu().numpy(), self.epoch, sample_rate=22050) # TODO: parameterize

        # report average cost per batch (enumerate starts at 0, so i + 1 batches)
        n_batches = i + 1
        self.logger.add_scalar('valid/l1', t_l1 / n_batches, self.epoch)
        self.logger.add_scalar('valid/guided_att', t_att / n_batches, self.epoch)
        return t_l1 / n_batches, t_att / n_batches

    def train_dataloader(self, batch_size):
        return DataLoader(AudioDataset(HPText.dataset, start_idx=0, end_idx=HPText.num_train, durations=False), batch_size=batch_size,
                          collate_fn=self.collate,
                          shuffle=True)

    def val_dataloader(self, batch_size):
        dataset = AudioDataset(HPText.dataset, start_idx=HPText.num_train, end_idx=HPText.num_valid, durations=False)
        return DataLoader(dataset, batch_size=batch_size,
                          collate_fn=self.collate,
                          shuffle=False, sampler=SequentialSampler(dataset))
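
# Hedged sketches: `l1_masked` and `GuidedAttentionLoss` above are imported
# from elsewhere in this repository. Plausible minimal versions of both,
# assuming (batch, time, freq) spectrograms and (batch, dec_time, enc_time)
# attention weights; illustrations of the technique, not the original code:

import torch


def l1_masked_sketch(pred, target, lengths):
    # mean absolute error over valid (unpadded) frames only
    mask = (torch.arange(target.shape[1])[None, :]
            < torch.as_tensor(lengths)[:, None])            # (batch, time)
    diff = (pred - target).abs() * mask[..., None]
    return diff.sum() / (mask.sum() * target.shape[-1])


def guided_attention_loss_sketch(att, dec_lens, enc_lens, sigma=0.3):
    # penalize attention mass far from the diagonal (DC-TTS style guide)
    _, t_dec, t_enc = att.shape
    dec_lens = torch.as_tensor(dec_lens, dtype=torch.float)[:, None, None]
    enc_lens = torch.as_tensor(enc_lens, dtype=torch.float)[:, None, None]
    n = torch.arange(t_dec, dtype=torch.float)[None, :, None] / dec_lens
    t = torch.arange(t_enc, dtype=torch.float)[None, None, :] / enc_lens
    guide = 1.0 - torch.exp(-((n - t) ** 2) / (2 * sigma ** 2))
    return (att * guide).mean()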
Beispiel #18
def main(args):
    os.environ['KMP_WARNINGS'] = '0'
    torch.cuda.manual_seed_all(1)
    np.random.seed(0)
    print(args.model_name)
    print(args.alpha)
    # filter array
    num_features = [
        args.features * i
        for i in range(1, args.levels + 2 + args.levels_without_sample)
    ]

    # determine the output size
    target_outputs = int(args.output_size * args.sr)
    # save model configuration parameters only when training

    # set hyperparameters for teacher, student, and student_for_backward

    student_KD = Waveunet(args.channels,
                          num_features,
                          args.channels,
                          levels=args.levels,
                          encoder_kernel_size=args.encoder_kernel_size,
                          decoder_kernel_size=args.decoder_kernel_size,
                          target_output_size=target_outputs,
                          depth=args.depth,
                          strides=args.strides,
                          conv_type=args.conv_type,
                          res=args.res)
    KD_optimizer = Adam(params=student_KD.parameters(), lr=args.lr)
    print(25 * '=' + 'model setting' + 25 * '=')
    print('student_KD: ', student_KD.shapes)
    if args.cuda:
        student_KD = utils.DataParallel(student_KD)
        print("move student_KD to gpu\n")
        student_KD.cuda()

    # np.inf rather than np.Inf: the capitalized alias was removed in NumPy 2.0
    state = {"step": 0, "worse_epochs": 0, "epochs": 0, "best_pesq": -np.inf}
    if args.load_model is not None:
        print("Continuing full model from checkpoint " + str(args.load_model))
        state = utils.load_model(student_KD, KD_optimizer, args.load_model,
                                 args.cuda)
    dataset = get_folds(args.dataset_dir, args.outside_test)
    log_dir, checkpoint_dir, result_dir = utils.mkdir_and_get_path(args)
    # print(model)
    if args.test is False:
        writer = SummaryWriter(log_dir)
        # set hyperparameters
        # print hyperparameter info
        print(25 * '=' + 'printing hyperparameters info' + 25 * '=')

        with open(os.path.join(log_dir, 'config.json'), 'w') as f:
            json.dump(args.__dict__, f, indent=5)
        print('saving commandline_args')
        student_size = sum(p.numel() for p in student_KD.parameters())
        print('student_parameter count: ', str(student_size))
        if args.teacher_model is not None:
            teacher_num_features = [
                24 * i
                for i in range(1, args.levels + 2 + args.levels_without_sample)
            ]
            teacher_model = Waveunet(
                args.channels,
                teacher_num_features,
                args.channels,
                levels=args.levels,
                encoder_kernel_size=args.encoder_kernel_size,
                decoder_kernel_size=args.decoder_kernel_size,
                target_output_size=target_outputs,
                depth=args.depth,
                strides=args.strides,
                conv_type=args.conv_type,
                res=args.res)

            if args.cuda:
                teacher_model = utils.DataParallel(teacher_model)
                teacher_model.cuda()
                # print("move teacher to gpu\n")
            teacher_size = sum(p.numel() for p in teacher_model.parameters())
            print('teacher_model_parameter count: ', str(teacher_size))
            print(f'compression ratio: {100 * (student_size / teacher_size)}%')
            # args.teacher_model is already known to be set in this branch
            print("load teacher model " + str(args.teacher_model))
            _ = utils.load_model(teacher_model, None, args.teacher_model,
                                 args.cuda)
            teacher_model.eval()

        # If not data augmentation, at least crop targets to fit model output shape
        crop_func = partial(crop, shapes=student_KD.shapes)
        ### DATASET
        train_data = SeparationDataset(dataset,
                                       "train",
                                       args.sr,
                                       args.channels,
                                       student_KD.shapes,
                                       False,
                                       args.hdf_dir,
                                       audio_transform=crop_func)
        val_data = SeparationDataset(dataset,
                                     "test",
                                     args.sr,
                                     args.channels,
                                     student_KD.shapes,
                                     False,
                                     args.hdf_dir,
                                     audio_transform=crop_func)
        dataloader = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            worker_init_fn=utils.worker_init_fn,
            pin_memory=True)

        # Set up the loss function
        if args.loss == "L1":
            criterion = nn.L1Loss()
        elif args.loss == "L2":
            criterion = nn.MSELoss()
        else:
            raise NotImplementedError("Couldn't find this loss!")
        My_criterion = customLoss()

        ### TRAINING START
        print('TRAINING START')
        batch_num = (len(train_data) // args.batch_size)
        while state["epochs"] < 100:
            #     if state["epochs"]<10:
            #         args.alpha=1
            #     else:
            #         args.alpha=0
            # print('fix alpha:',args.alpha)
            memory_alpha = []
            print("epoch:" + str(state["epochs"]))
            student_KD.train()
            # monitor_value
            avg_origin_loss = 0
            with tqdm(total=len(dataloader)) as pbar:
                for example_num, (x, targets) in enumerate(dataloader):
                    if args.cuda:
                        x = x.cuda()
                        targets = targets.cuda()
                    if args.teacher_model is not None:
                        # Set LR for this iteration
                        #print('base_model from KD')

                        utils.set_cyclic_lr(KD_optimizer, example_num,
                                            len(train_data) // args.batch_size,
                                            args.cycles, args.min_lr, args.lr)
                        _, avg_student_KD_loss = utils.compute_loss(
                            student_KD,
                            x,
                            targets,
                            criterion,
                            compute_grad=False)

                        KD_optimizer.zero_grad()
                        KD_outputs, KD_hard_loss, KD_loss, KD_soft_loss = utils.KD_compute_loss(
                            student_KD,
                            teacher_model,
                            x,
                            targets,
                            My_criterion,
                            alpha=args.alpha,
                            compute_grad=True,
                            KD_method=args.KD_method)
                        KD_optimizer.step()

                        # calculate backwarded model MSE

                        avg_origin_loss += avg_student_KD_loss / batch_num

                        # add to tensorboard
                        writer.add_scalar("KD_loss", KD_loss, state["step"])
                        writer.add_scalar("KD_hard_loss", KD_hard_loss,
                                          state["step"])
                        writer.add_scalar("KD_soft_loss", KD_soft_loss,
                                          state["step"])
                    else:  # no KD training
                        utils.set_cyclic_lr(KD_optimizer, example_num,
                                            len(train_data) // args.batch_size,
                                            args.cycles, args.min_lr, args.lr)
                        KD_optimizer.zero_grad()
                        KD_outputs, KD_hard_loss = utils.compute_loss(
                            student_KD,
                            x,
                            targets,
                            nn.MSELoss(),
                            compute_grad=True)
                        KD_optimizer.step()
                        avg_origin_loss += KD_hard_loss / batch_num
                        writer.add_scalar("student_KD_loss", KD_hard_loss,
                                          state["step"])

                    ### save wav ####
                    if example_num % args.example_freq == 0:
                        input_centre = torch.mean(
                            x[0, :, student_KD.shapes["output_start_frame"]:
                              student_KD.shapes["output_end_frame"]],
                            0)  # Stereo not supported for logs yet

                        writer.add_audio("input:",
                                         input_centre,
                                         state["step"],
                                         sample_rate=args.sr)
                        writer.add_audio("pred:",
                                         torch.mean(KD_outputs[0], 0),
                                         state["step"],
                                         sample_rate=args.sr)
                        writer.add_audio("target",
                                         torch.mean(targets[0], 0),
                                         state["step"],
                                         sample_rate=args.sr)

                    state["step"] += 1
                    pbar.update(1)
            # VALIDATE
            val_loss, val_metrics = validate(args, student_KD, criterion,
                                             val_data)
            print("ori VALIDATION FINISHED: LOSS: " + str(val_loss))

            writer.add_scalar("avg_origin_loss", avg_origin_loss,
                              state["epochs"])
            writer.add_scalar("val_enhance_pesq", val_metrics[0],
                              state["epochs"])
            writer.add_scalar("val_improve_pesq", val_metrics[1],
                              state["epochs"])
            writer.add_scalar("val_enhance_stoi", val_metrics[2],
                              state["epochs"])
            writer.add_scalar("val_improve_stoi", val_metrics[3],
                              state["epochs"])
            writer.add_scalar("val_enhance_SISDR", val_metrics[4],
                              state["epochs"])
            writer.add_scalar("val_improve_SISDR", val_metrics[5],
                              state["epochs"])
            # writer.add_scalar("val_COPY_pesq",val_metrics_copy[0], state["epochs"])
            writer.add_scalar("val_loss", val_loss, state["epochs"])

            # Set up training state dict that will also be saved into checkpoints
            checkpoint_path = os.path.join(
                checkpoint_dir, "checkpoint_" + str(state["epochs"]))
            if val_metrics[0] < state["best_pesq"]:
                state["worse_epochs"] += 1
            else:
                print("MODEL IMPROVED ON VALIDATION SET!")
                state["worse_epochs"] = 0
                state["best_pesq"] = val_metrics[0]
                state["best_checkpoint"] = checkpoint_path

            # CHECKPOINT
            print("Saving model...")
            utils.save_model(student_KD, KD_optimizer, state, checkpoint_path)
            print('dump alpha_memory')
            with open(os.path.join(log_dir, 'alpha_' + str(state["epochs"])),
                      "wb") as fp:  #Pickling
                pickle.dump(memory_alpha, fp)
            state["epochs"] += 1
        writer.close()
        info = args.model_name
        path = os.path.join(result_dir, info)
    else:
        PATH = args.load_model.split("/")
        info = PATH[-3] + "_" + PATH[-1]
        if args.outside_test:
            info += "_outside_test"
        print(info)
        path = os.path.join(result_dir, info)

    #### TESTING ####
    # Test loss
    print("TESTING")
    # eval metrics
    _ = utils.load_model(student_KD, KD_optimizer, state["best_checkpoint"],
                         args.cuda)
    test_metrics = evaluate(args, dataset["test"], student_KD)
    test_pesq = test_metrics['pesq']
    test_stoi = test_metrics['stoi']
    test_SISDR = test_metrics['SISDR']
    test_noise = test_metrics['noise']

    if not os.path.exists(path):
        os.makedirs(path)
    utils.save_result(test_pesq, path, "pesq")
    utils.save_result(test_stoi, path, "stoi")
    utils.save_result(test_SISDR, path, "SISDR")
    utils.save_result(test_noise, path, "noise")
Beispiel #19
class TacotronTrainer:
    TRAIN_STAGE = 'train'
    VAL_STAGE = 'val'
    VERSION_FORMAT = 'VERSION_{}'
    MODEL_SAVE_FORMAT = 'version_{version:03}_model_{step:010}.pth'

    def __init__(self,
                 batch_size: int = 32,
                 num_epoch: int = 100,
                 train_split: float = 0.9,
                 log_interval: int = 1000,
                 log_audio_factor: int = 5,
                 lr: float = 0.001,
                 num_data: int = None,
                 log_root: str = './tb_logs',
                 save_root: str = './checkpoints',
                 num_workers: int = 4,
                 version: int = None,
                 num_test_samples: int = 5):
        """
        Initialize tacotron trainer
        Args:
            batch_size: batch size
            num_epoch: total number of epochs to train
            train_split: train ratio of train-val split
            log_interval: interval, in global steps, for logging test samples
                to tensorboard
            log_audio_factor: log audio every `log_audio_factor` log intervals,
                since audio logging carries significant overhead
            num_data: number of datapoints to load in the dataset
            log_root: root directory for the tensorboard logging
            save_root: root directory for saving model
            num_workers: number of workers for dataloader
            version: version of training
            num_test_samples: number of test samples to generate
                for each logging
        """
        if not os.path.exists(log_root):
            os.makedirs(log_root)

        if not os.path.exists(save_root):
            os.makedirs(save_root)

        is_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if is_cuda else 'cpu')

        self.train_split = train_split

        self.epoch_num = num_epoch

        self.splitted_dataset = self.__split_dataset(
            TorchLJSpeechDataset(num_data=num_data))

        self.dataloaders = self.__get_dataloaders(
            batch_size, num_workers=num_workers)

        self.tacotron = Tacotron()
        self.tacotron.to(self.device)

        self.loss = TacotronLoss()
        self.optimizer = Adam(self.tacotron.parameters(), lr=lr)
        self.lr_scheduler = StepLR(
            optimizer=self.optimizer,
            step_size=10000,
            gamma=0.9)

        if version is None:
            versions = os.listdir(log_root)
            if not versions:
                self.version = 0
            else:
                # parse the numeric suffix of 'VERSION_{n}'; ver[-1] would
                # break once versions reach two digits
                self.version = max(
                    int(ver.rsplit('_', 1)[-1]) for ver in versions) + 1
        else:
            self.version = version

        log_dir = os.path.join(
            log_root, self.VERSION_FORMAT.format(self.version))
        if os.path.exists(log_dir):
            # os.remove cannot delete a directory (requires `import shutil`)
            shutil.rmtree(log_dir)

        self.logger = SummaryWriter(log_dir)

        self.save_root = save_root

        self.log_interval = log_interval
        self.log_audio_factor = log_audio_factor

        self.global_step = 0
        self.running_count = {self.TRAIN_STAGE: 0,
                              self.VAL_STAGE: 0}
        self.running_loss = {self.TRAIN_STAGE: 0,
                             self.VAL_STAGE: 0}

        self.sample_indices = list(range(num_test_samples))

    def fit_from_checkpoint(self, checkpoint_file: str):
        self.tacotron.load(checkpoint_file, self.device)
        self.fit()

    def fit(self):
        for epoch in tqdm.tqdm(range(self.epoch_num),
                               total=self.epoch_num,
                               desc='Epoch'):
            self.__run_epoch(epoch)

    def __run_epoch(self, epoch: int):
        # reset running loss and count after each epoch
        self.__reset_loss()
        self.__reset_count()

        for stage, dataloader in self.dataloaders.items():
            prog_bar = tqdm.tqdm(dataloader,
                                 desc=f'{stage.capitalize()} in progress',
                                 total=len(dataloader))
            for batch in dataloader:
                self.__run_step(batch, stage, prog_bar)

        # epoch vs global step
        self.logger.add_scalar('epoch', epoch, global_step=self.global_step)

        # add loss to logger
        loss_dict = {stage: self.__calculate_mean_loss(stage)
                     for stage in self.running_loss}
        self.logger.add_scalars('loss', loss_dict, global_step=epoch)


    def __run_step(self, batch: TorchLJSpeechBatch, stage: str,
                   prog_bar: tqdm.tqdm):
        if stage == self.TRAIN_STAGE:
            self.tacotron.train()
            self.optimizer.zero_grad()
        else:
            self.tacotron.eval()

        batch = batch.to(self.device)

        output = self.tacotron.forward_train(batch)
        loss_val = self.loss(batch.mel_spec, output.pred_mel_spec,
                             batch.lin_spec, output.pred_lin_spec)

        self.running_loss[stage] += loss_val.item() * batch.mel_spec.size(0)
        self.running_count[stage] += batch.mel_spec.size(0)

        if stage == self.TRAIN_STAGE:
            loss_val.backward()
            self.optimizer.step()
            self.lr_scheduler.step()

            if self.global_step % self.log_interval == 0:
                self.logger.add_scalar('training_loss',
                                       self.__calculate_mean_loss(stage),
                                       global_step=self.global_step)
                log_audio = False
                if self.global_step % (self.log_interval * self.log_audio_factor) == 0:
                    log_audio = True
                sample_results = self.__get_sample_results()
                for sample_result in sample_results:
                    self.__log_sample_results(
                        self.global_step, sample_result, log_audio=log_audio)

                self.tacotron.train()
                save_file = os.path.join(
                    self.save_root,
                    self.MODEL_SAVE_FORMAT.format(
                        version=self.version, step=self.global_step)
                )
                torch.save(self.tacotron.state_dict(), save_file)

            self.global_step += 1

        prog_bar.update()
        prog_bar.set_postfix(
            {'Running Loss': f'{self.__calculate_mean_loss(stage):.3f}'})

    def __log_sample_results(self, steps: int,
                             sample_result: SampleResult,
                             log_mel: bool = True,
                             log_spec: bool = True,
                             log_attention: bool = True,
                             log_audio: bool = True) -> None:
        """
        Log the sample results into tensorboard
        Args:
            steps: current step
            sample_result: sample result to log
            log_mel: if True, log mel spectrogram
            log_spec: if True, log spectrogram
            log_attention: if True, log attention
            log_audio: if True, log audio

        """
        if log_mel:
            title = f'Log Mel Spectrogram, Step:{steps}, ' \
                    f'Uid: {sample_result.uid}'

            fig = self.__get_spec_plot(
                pred_spec=sample_result.pred_mel_spec,
                truth_spec=sample_result.truth_mel_spec,
                suptitle=title,
                ylabel='Mel')
            img_tensor = self.__get_plot_tensor(fig)
            tag = f'mel_spec/{sample_result.uid}'
            self.logger.add_image(tag, img_tensor, global_step=steps)

        if log_spec:
            title = f'Log Spectrogram, Step:{steps}, ' \
                    f'Uid: {sample_result.uid}'
            fig = self.__get_spec_plot(
                pred_spec=sample_result.pred_lin_spec,
                truth_spec=sample_result.truth_lin_spec,
                suptitle=title,
                ylabel='DFT bins')
            img_tensor = self.__get_plot_tensor(fig)
            tag = f'lin_spec/{sample_result.uid}'
            self.logger.add_image(tag, img_tensor, global_step=steps)

        if log_attention:
            title = f'Attention Weight, Step: {steps}, ' \
                    f'Uid: {sample_result.uid}'
            fig = self.__get_attention_plot(
                title=title,
                attention_weight=sample_result.attention_weight)
            img_tensor = self.__get_plot_tensor(fig)
            tag = f'attention/{sample_result.uid}'
            self.logger.add_image(tag, img_tensor, global_step=steps)

        if log_audio:
            pred_tag = f'audio/{sample_result.uid}_predicted'
            truth_tag = f'audio/{sample_result.uid}_truth'

            self.logger.add_audio(
                tag=pred_tag,
                snd_tensor=torch.from_numpy(
                    sample_result.pred_audio).unsqueeze(1),  # add channel dim
                global_step=steps,
                sample_rate=AudioProcessParam.sr
            )

            self.logger.add_audio(
                tag=truth_tag,
                snd_tensor=torch.from_numpy(
                    sample_result.truth_audio).unsqueeze(1),  # add channel dim
                global_step=steps,
                sample_rate=AudioProcessParam.sr
            )

    def __get_sample_results(self) -> List[SampleResult]:
        """
        Get sample results to show in tensorboard, including
            1. Predicted and ground truth spectrogram pairs
            2. Predicted and ground truth mel spectrogram pairs
            3. Predicted and ground truth audio pairs
            4. Attention weight
        Returns:
            list of sample results

        """
        val_dataset = self.splitted_dataset[self.VAL_STAGE]
        self.tacotron.eval()

        test_insts = []
        with torch.no_grad():
            for subset_i in self.sample_indices:
                datapoint: TorchLJSpeechData = val_dataset[subset_i]
                datapoint: TorchLJSpeechBatch = datapoint.add_batch_dim()
                datapoint = datapoint.to(self.device)

                ds_idx = val_dataset.indices[subset_i]
                uid = val_dataset.dataset.uids[ds_idx]

                # Transcription
                transcription = val_dataset.dataset.uid_to_transcription[uid]

                wav_filepath = os.path.join(
                    val_dataset.dataset.wav_save_dir, f'{uid}.wav')
                truth_audio = AudioProcessingHelper.load_audio(wav_filepath)

                taco_output = self.tacotron.forward_train(datapoint)

                spec = taco_output.pred_lin_spec.squeeze(0).cpu().numpy().T
                pred_audio = AudioProcessingHelper.spec2audio(spec)

                test_insts.append(
                    SampleResult(
                        uid=uid,
                        transcription=transcription,
                        truth_lin_spec=datapoint.lin_spec.squeeze(0).cpu().numpy().T,
                        pred_lin_spec=taco_output.pred_lin_spec.squeeze(0).cpu().numpy().T,
                        truth_mel_spec=datapoint.mel_spec.squeeze(0).cpu().numpy().T,
                        pred_mel_spec=taco_output.pred_mel_spec.squeeze(0).cpu().numpy().T,
                        attention_weight=taco_output.attention_weight.squeeze(0).cpu().numpy(),
                        truth_audio=truth_audio,
                        pred_audio=pred_audio
                    )
                )

        return test_insts

    @staticmethod
    def __get_attention_plot(
            title: str, attention_weight: np.ndarray) -> plt.Figure:
        """
        Get figure handle for attention plot

        Args:
            title: title of the plot
            attention_weight: attention weight to plot

        Returns:
            figure object

        """
        fig = plt.figure(figsize=(6, 5), dpi=80)
        plt.title(title)
        plt.imshow(attention_weight, aspect='auto')
        plt.colorbar()
        plt.xlabel('Encoder seq')
        plt.ylabel('Decoder seq')
        plt.gca().invert_yaxis()  # Let the x, y axis start from the left-bottom corner
        plt.close(fig)
        return fig

    @staticmethod
    def __get_spec_plot(pred_spec: np.ndarray, truth_spec: np.ndarray,
                        suptitle: str, ylabel: str) -> plt.Figure:
        """
        Get a side-by-side plot of two spectrograms with an appropriate title
        Args:
            pred_spec: predicted spectrogram
            truth_spec: ground truth spectrogram
            suptitle: title of the plot
            ylabel: unit of frequency axis of the spectrograms

        Returns:
            figure object

        """
        vmin = min(np.min(truth_spec), np.min(pred_spec))
        vmax = max(np.max(truth_spec), np.max(pred_spec))

        fig = plt.figure(figsize=(11, 5), dpi=80)
        plt.suptitle(suptitle)

        ax1 = plt.subplot(121)
        plt.title('Ground Truth')
        plt.xlabel('Frame')
        plt.ylabel(ylabel)
        plt.imshow(truth_spec, vmin=vmin, vmax=vmax, aspect='auto')
        plt.gca().invert_yaxis()  # let the x, y axis start from the left-bottom corner

        ax2 = plt.subplot(122)
        plt.title('Predicted')
        plt.xlabel('Frame')
        im = plt.imshow(pred_spec, vmin=vmin, vmax=vmax, aspect='auto')
        plt.gca().invert_yaxis()  # let the x, y axis start from the left-bottom corner

        fig.tight_layout()
        fig.colorbar(im, ax=[ax1, ax2])
        plt.close(fig)

        return fig

    @staticmethod
    def __get_plot_tensor(fig) -> torch.Tensor:
        """
        Get tensor for the given figure object
        Args:
            fig: the figure object to convert into tensor

        Returns:
            tensor of the figure

        """
        buf = io.BytesIO()
        fig.savefig(buf, format='jpeg')
        buf.seek(0)
        image = PIL.Image.open(buf)
        image = ToTensor()(image)
        return image

    def __calculate_mean_loss(self, stage: str) -> float:
        """
        Calculate mean loss for given stage (train/val)
        Args:
            stage: train/val

        Returns:
            mean loss

        """
        return self.running_loss[stage] / self.running_count[stage]

    def __reset_loss(self) -> None:
        self.running_loss = {self.TRAIN_STAGE: 0,
                             self.VAL_STAGE: 0}

    def __reset_count(self) -> None:
        self.running_count = {self.TRAIN_STAGE: 0,
                              self.VAL_STAGE: 0}

    def __split_dataset(self, dataset: TorchLJSpeechDataset) -> Dict[str, Subset]:
        """
        Split the dataset into train/validation sets
        Args:
            dataset: dataset to split

        Returns:
            dict mapping stage name ('train'/'val') to the corresponding split

        """
        num_train_data = int(len(dataset) * self.train_split)
        num_val_data = len(dataset) - num_train_data
        train_dataset, val_dataset = random_split(
            dataset, [num_train_data, num_val_data])

        return {self.TRAIN_STAGE: train_dataset,
                self.VAL_STAGE: val_dataset}

    def __get_dataloaders(
            self, batch_size: int, num_workers: int) -> Dict[str, DataLoader]:
        return {stage: DataLoader(
            dataset, shuffle=(stage == self.TRAIN_STAGE),
            collate_fn=TorchLJSpeechDataset.batch_tacotron_input,
            pin_memory=True, batch_size=batch_size,
            num_workers=num_workers)
            for stage, dataset in self.splitted_dataset.items()
        }
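
# Hedged sketch: the trainer's __split_dataset/__get_dataloaders pattern,
# reduced to a self-contained toy with TensorDataset (an illustration of the
# pattern, not the original TorchLJSpeechDataset pipeline):

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


def split_and_load_sketch(train_split=0.9, batch_size=32):
    dataset = TensorDataset(torch.randn(100, 8), torch.randn(100, 1))
    num_train = int(len(dataset) * train_split)
    train_set, val_set = random_split(
        dataset, [num_train, len(dataset) - num_train])
    return {'train': DataLoader(train_set, batch_size=batch_size, shuffle=True),
            'val': DataLoader(val_set, batch_size=batch_size, shuffle=False)}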
Beispiel #20
class Logger:
    _count = 0

    def __init__(self, scrn=True, log_dir='', phase=''):
        super().__init__()
        self._logger = logging.getLogger('logger_{}'.format(Logger._count))
        Logger._count += 1
        self._logger.setLevel(logging.DEBUG)

        if scrn:
            self._scrn_handler = logging.StreamHandler()
            self._scrn_handler.setLevel(logging.INFO)
            self._scrn_handler.setFormatter(
                logging.Formatter(fmt=FORMAT_SHORT))
            self._logger.addHandler(self._scrn_handler)

        if log_dir and phase:
            self.log_path = os.path.join(
                log_dir,
                '{}-{:-4d}-{:02d}-{:02d}-{:02d}-{:02d}-{:02d}.log'.format(
                    phase,
                    *localtime()[:6]))
            self.show_nl("log into {}\n\n".format(self.log_path))
            self._file_handler = logging.FileHandler(filename=self.log_path)
            self._file_handler.setLevel(logging.DEBUG)
            self._file_handler.setFormatter(logging.Formatter(fmt=FORMAT_LONG))
            self._logger.addHandler(self._file_handler)

            self._writer = SummaryWriter(log_dir=os.path.join(
                log_dir, '{}-{:-4d}-{:02d}-{:02d}-{:02d}-{:02d}-{:02d}'.format(
                    phase,
                    *localtime()[:6])))

    def show(self, *args, **kwargs):
        return self._logger.info(*args, **kwargs)

    def show_nl(self, *args, **kwargs):
        self._logger.info("")
        return self.show(*args, **kwargs)

    def dump(self, *args, **kwargs):
        return self._logger.debug(*args, **kwargs)

    def warning(self, *args, **kwargs):
        return self._logger.warning(*args, **kwargs)

    def error(self, *args, **kwargs):
        return self._logger.error(*args, **kwargs)

    # tensorboard
    def add_scalar(self, *args, **kwargs):
        return self._writer.add_scalar(*args, **kwargs)

    def add_scalars(self, *args, **kwargs):
        return self._writer.add_scalars(*args, **kwargs)

    def add_histogram(self, *args, **kwargs):
        return self._writer.add_histogram(*args, **kwargs)

    def add_image(self, *args, **kwargs):
        return self._writer.add_image(*args, **kwargs)

    def add_images(self, *args, **kwargs):
        return self._writer.add_images(*args, **kwargs)

    def add_figure(self, *args, **kwargs):
        return self._writer.add_figure(*args, **kwargs)

    def add_video(self, *args, **kwargs):
        return self._writer.add_video(*args, **kwargs)

    def add_audio(self, *args, **kwargs):
        return self._writer.add_audio(*args, **kwargs)

    def add_text(self, *args, **kwargs):
        return self._writer.add_text(*args, **kwargs)

    def add_graph(self, *args, **kwargs):
        return self._writer.add_graph(*args, **kwargs)

    def add_pr_curve(self, *args, **kwargs):
        return self._writer.add_pr_curve(*args, **kwargs)

    def add_custom_scalars(self, *args, **kwargs):
        return self._writer.add_custom_scalars(*args, **kwargs)

    def add_mesh(self, *args, **kwargs):
        return self._writer.add_mesh(*args, **kwargs)

    # def add_hparams(self, *args, **kwargs):
    #     return self._writer.add_hparams(*args, **kwargs)

    def flush(self):
        return self._writer.flush()

    def close(self):
        return self._writer.close()

    def _grad_hook(self, grad, name=None, grads=None):
        grads.update({name: grad})

    def watch_grad(self, model, layers):
        """
        Add hooks to the specified layers. Gradients of these layers will be saved to self.grads
        :param model: model whose parameters should be watched
        :param layers: expects a list, e.g. layers=[0, -1] watches the gradients of
                        the first and the last parameter of the model
        :return:
        """
        assert layers
        if not hasattr(self, 'grads'):
            self.grads = {}
            self.grad_hooks = {}
        named_parameters = list(model.named_parameters())
        for layer in layers:
            name = named_parameters[layer][0]
            handle = named_parameters[layer][1].register_hook(
                functools.partial(self._grad_hook, name=name,
                                  grads=self.grads))
            # dict(name=handle) would always use the literal key 'name',
            # overwriting the previous handle each iteration
            self.grad_hooks[name] = handle

    def watch_grad_close(self):
        for _, handle in self.grad_hooks.items():
            handle.remove()  # remove the hook

    def add_grads(self, global_step=None, *args, **kwargs):
        """
        Add gradients to tensorboard. You must call the method self.watch_grad before using this method!
        """
        assert hasattr(self, 'grads'), \
            "self.grads is nonexistent! You must call self.watch_grad first!"
        assert self.grads, "self.grads is empty!"
        for (name, grad) in self.grads.items():
            self.add_histogram(tag=name,
                               values=grad,
                               global_step=global_step,
                               *args,
                               **kwargs)

    @staticmethod
    def make_desc(counter, total, *triples):
        desc = "[{}/{}]".format(counter, total)
        # The three elements of each triple are
        # (name to display, AverageMeter object, formatting string)
        for name, obj, fmt in triples:
            desc += (" {} {obj.val:" + fmt + "} ({obj.avg:" + fmt +
                     "})").format(name, obj=obj)
        return desc
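
# Hedged usage sketch for Logger.watch_grad / add_grads above (toy model; the
# 'logs'/'train' arguments are placeholders, and FORMAT_SHORT/FORMAT_LONG are
# assumed to be defined alongside this class):

import torch


def watch_grad_usage_sketch():
    logger = Logger(scrn=True, log_dir='logs', phase='train')
    model = torch.nn.Linear(4, 2)
    logger.watch_grad(model, layers=[0, -1])  # hook first and last parameter
    loss = model(torch.randn(3, 4)).sum()
    loss.backward()                           # hooks populate logger.grads
    logger.add_grads(global_step=0)           # one histogram per watched grad
    logger.watch_grad_close()
    logger.close()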
Beispiel #21
def main():
    args = parse_args()

    root = Path(args.save_path)
    spec_type = args.spectrogram
    inference = args.inference
    load_root = Path(args.load_path) if args.load_path else None
    root.mkdir(parents=True, exist_ok=True)
    ####################################
    # Dump arguments and create logger #
    ####################################
    with open(root / "args.yml", "w") as f:
        yaml.dump(args, f)
    writer = SummaryWriter(str(root))

    #######################
    # Load PyTorch Models #
    #######################
    if (spec_type == "mel"):
        print('Mel spec selected')
        netG = Generator(args.n_mel_channels, args.ngf, args.n_residual_layers
                         ).cuda()  #initialise generator with n mel channels
        fft = Audio2Mel(n_mel_channels=args.n_mel_channels).cuda()
    if (spec_type == "cqt"):
        print('CQT spec selected')
        netG = Generator(
            args.n_bins, args.ngf,
            args.n_residual_layers).cuda()  #initialise generator with n bins
        fft = Audio2Cqt(n_bins=args.n_bins).cuda()
    else:
        print('No spectrogram specified, defaulting to CQT')
        netG = Generator(args.n_bins, args.ngf, args.n_residual_layers).cuda()
        fft = Audio2Cqt(n_bins=args.n_bins).cuda()
    netD = Discriminator(
        args.num_D, args.ndf, args.n_layers_D,
        args.downsamp_factor).cuda()  #initialize discriminator

    print(netG)
    print(netD)

    #####################
    # Create optimizers #
    #####################
    optG = torch.optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optD = torch.optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))

    if load_root and load_root.exists():
        print('Loading weights')
        netG.load_state_dict(torch.load(load_root / "netG.pt"))
        optG.load_state_dict(torch.load(load_root / "optG.pt"))
        netD.load_state_dict(torch.load(load_root / "netD.pt"))
        optD.load_state_dict(torch.load(load_root / "optD.pt"))
        print('Weights loaded')

    #######################
    # Create data loaders #
    #######################
    if inference:
        print('Starting inference')
        st = time.time()
        with torch.no_grad():
            x = torch.load('unseen.pt')
            x = x.cuda()
            pred_audio = netG(x)
            pred_audio = pred_audio.squeeze().cpu()
            save_sample(root / ("generated_sample.wav"), 22050, pred_audio)
            writer.add_audio(
                "generated/sample_test.wav",
                pred_audio,
                global_step=None,
                sample_rate=22050,
            )
            print('Finished inference')
        return

    train_set = AudioDataset(Path(args.data_path) / "train_files.txt",
                             args.seq_len,
                             sampling_rate=22050)

    test_set = AudioDataset(Path(args.data_path) / "test_files.txt",
                            22050 * 4,
                            sampling_rate=22050,
                            augment=False)

    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              num_workers=4)
    test_loader = DataLoader(test_set, batch_size=1)

    ##########################
    # Dumping original audio #
    ##########################
    test_voc = []
    test_audio = []
    for i, x_t in enumerate(test_loader):
        x_t = x_t.cuda()
        s_t = fft(x_t).detach()

        test_voc.append(s_t.cuda())
        test_audio.append(x_t)

        audio = x_t.squeeze().cpu()
        save_sample(root / ("original_%d.wav" % i), 22050, audio)
        writer.add_audio("original/sample_%d.wav" % i,
                         audio,
                         0,
                         sample_rate=22050)

        if i == args.n_test_samples - 1:
            break
    costs = []
    start = time.time()
    # enable cudnn autotuner to speed up training
    torch.backends.cudnn.benchmark = True

    best_mel_reconst = 1000000
    steps = 0
    for epoch in range(1, args.epochs + 1):
        for iterno, x_t in enumerate(train_loader):
            x_t = x_t.cuda()
            s_t = fft(x_t).detach()  # generate spectrogram
            x_pred_t = netG(s_t.cuda())  # generate audio from real spectrogram
            with torch.no_grad():
                # spectrogram of the audio we generated
                s_pred_t = fft(x_pred_t.detach())
                # L1 error between the real audio's spectrogram and the
                # generated audio's spectrogram
                s_error = F.l1_loss(s_t, s_pred_t).item()
            #######################
            # Train Discriminator #
            #######################
            D_fake_det = netD(x_pred_t.cuda().detach())
            D_real = netD(x_t.cuda())

            loss_D = 0
            for scale in D_fake_det:
                loss_D += F.relu(1 + scale[-1]).mean()

            for scale in D_real:
                loss_D += F.relu(1 - scale[-1]).mean()

            netD.zero_grad()
            loss_D.backward()
            optD.step()

            ###################
            # Train Generator #
            ###################
            D_fake = netD(x_pred_t.cuda())

            loss_G = 0
            for scale in D_fake:
                loss_G += -scale[-1].mean()

            loss_feat = 0
            feat_weights = 4.0 / (args.n_layers_D + 1)
            D_weights = 1.0 / args.num_D
            wt = D_weights * feat_weights
            for i in range(args.num_D):
                for j in range(len(D_fake[i]) - 1):
                    loss_feat += wt * F.l1_loss(D_fake[i][j],
                                                D_real[i][j].detach())

            netG.zero_grad()
            (loss_G + args.lambda_feat * loss_feat).backward()
            optG.step()

            ######################
            # Update tensorboard #
            ######################
            costs.append(
                [loss_D.item(),
                 loss_G.item(),
                 loss_feat.item(), s_error])

            writer.add_scalar("loss/discriminator", costs[-1][0], steps)
            writer.add_scalar("loss/generator", costs[-1][1], steps)
            writer.add_scalar("loss/feature_matching", costs[-1][2], steps)
            writer.add_scalar("loss/mel_reconstruction", costs[-1][3], steps)
            steps += 1

            if steps % args.save_interval == 0:
                st = time.time()
                with torch.no_grad():
                    for i, (voc, _) in enumerate(zip(test_voc, test_audio)):
                        pred_audio = netG(voc)
                        pred_audio = pred_audio.squeeze().cpu()
                        save_sample(root / ("generated_%d.wav" % i), 22050,
                                    pred_audio)
                        writer.add_audio(
                            "generated/sample_%d.wav" % i,
                            pred_audio,
                            epoch,
                            sample_rate=22050,
                        )

                torch.save(netG.state_dict(), root / "netG.pt")
                torch.save(optG.state_dict(), root / "optG.pt")

                torch.save(netD.state_dict(), root / "netD.pt")
                torch.save(optD.state_dict(), root / "optD.pt")

                if np.asarray(costs).mean(0)[-1] < best_mel_reconst:
                    best_mel_reconst = np.asarray(costs).mean(0)[-1]
                    torch.save(netD.state_dict(), root / "best_netD.pt")
                    torch.save(netG.state_dict(), root / "best_netG.pt")

                print("Took %5.4fs to generate samples" % (time.time() - st))
                print("-" * 100)

            if steps % args.log_interval == 0:
                print("Epoch {} | Iters {} / {} | ms/batch {:5.2f} | loss {}".
                      format(
                          epoch,
                          iterno,
                          len(train_loader),
                          1000 * (time.time() - start) / args.log_interval,
                          np.asarray(costs).mean(0),
                      ))
                costs = []
                start = time.time()
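
Aside: the loop above implements a hinge GAN objective. The discriminator is pushed to score real audio above +1 and generated audio below -1, while the generator simply maximises the discriminator's score on its output. A minimal self-contained sketch of those two objectives (illustrative names, not part of the example above):

import torch
import torch.nn.functional as F

def hinge_d_loss(d_real, d_fake):
    # penalise real scores below +1 and fake scores above -1
    return F.relu(1 - d_real).mean() + F.relu(1 + d_fake).mean()

def hinge_g_loss(d_fake):
    # raise the discriminator's score on generated audio
    return -d_fake.mean()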
Example #22
def train_fn(args, params):
    # Directory preparation
    exp_dir = makeExpDirs(args.results_dir, args.exp_name)

    # Automatic Mixed-Precision
    if args.optim != "no":
        import apex

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = Vocoder(mel_channels=params["preprocessing"]["num_mels"],
                    conditioning_channels=params["vocoder"]["conditioning_channels"],
                    embedding_dim=params["vocoder"]["embedding_dim"],
                    rnn_channels=params["vocoder"]["rnn_channels"],
                    fc_channels=params["vocoder"]["fc_channels"],
                    bits=params["preprocessing"]["bits"],
                    hop_length=params["preprocessing"]["hop_length"],
                    nc=args.nc,
                    device=device
                    )
    model.to(device)
    print(model)

    optimizer = optim.Adam(model.parameters(), lr=params["vocoder"]["learning_rate"])

    # Automatic Mixed-Precision
    if args.optim != "no":
        model, optimizer = apex.amp.initialize(model, optimizer, opt_level=args.optim)

    scheduler = optim.lr_scheduler.StepLR(optimizer, params["vocoder"]["schedule"]["step_size"], params["vocoder"]["schedule"]["gamma"])

    if args.resume is not None:
        print(f"Resume checkpoint from: {args.resume}:")
        checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        global_step = checkpoint["step"]
    else:
        global_step = 0

    train_dataset = VocoderDataset(meta_file=os.path.join(args.data_dir, "train.txt"),
                                   sample_frames=params["vocoder"]["sample_frames"],
                                   audio_slice_frames=params["vocoder"]["audio_slice_frames"],
                                   hop_length=params["preprocessing"]["hop_length"],
                                   bits=params["preprocessing"]["bits"])

    train_dataloader = DataLoader(train_dataset, batch_size=params["vocoder"]["batch_size"],
                                  shuffle=True, num_workers=1,
                                  pin_memory=True)

    num_epochs = params["vocoder"]["num_steps"] // len(train_dataloader) + 1
    start_epoch = global_step // len(train_dataloader) + 1

    # Logger
    writer = SummaryWriter(exp_dir/"logs")

    # Add original utterance to TensorBoard
    if args.resume is None:
        with open(os.path.join(args.data_dir, "test.txt"), encoding="utf-8") as f:
            test_wavnpy_paths = [line.strip().split("|")[1] for line in f]
        for index, wavnpy_path in enumerate(test_wavnpy_paths):
            muraw_npy = np.load(wavnpy_path)
            wav_npy = mulaw_decode(muraw_npy, 2**params["preprocessing"]["bits"])
            writer.add_audio("orig", torch.from_numpy(wav_npy), global_step=global_step, sample_rate=params["preprocessing"]["sample_rate"])
            break


    for epoch in range(start_epoch, num_epochs + 1):
        running_loss = 0
        
        for i, (audio, mels) in enumerate(tqdm(train_dataloader, leave=False), 1):
            audio, mels = audio.to(device), mels.to(device)

            output = model(audio[:, :-1], mels)
            loss = F.cross_entropy(output.transpose(1, 2), audio[:, 1:])
            optimizer.zero_grad()

            # Automatic Mixed-Precision
            if args.optim != "no":
                with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()
            scheduler.step()

            running_loss += loss.item()
            average_loss = running_loss / i

            global_step += 1

            if global_step % args.save_step == 0:
                save_checkpoint(model, optimizer, scheduler, global_step, exp_dir/"params", False)

            if global_step % params["vocoder"]["checkpoint_interval"] == 0:
                save_checkpoint(model, optimizer, scheduler, global_step, exp_dir/"params", True)

            if global_step % params["vocoder"]["generation_interval"] == 0:
                with open(os.path.join(args.data_dir, "test.txt"), encoding="utf-8") as f:
                    test_mel_paths = [line.strip().split("|")[2] for line in f]

                for index, mel_path in enumerate(test_mel_paths):
                    utterance_id = os.path.basename(mel_path).split(".")[0]
                    # unsqueeze: insert in a batch
                    mel = torch.FloatTensor(np.load(mel_path)).unsqueeze(0).to(device)
                    output = model.generate(mel)
                    path = exp_dir/"samples"/f"gen_{utterance_id}_model_steps_{global_step}.wav"
                    save_wav(str(path), output, params["preprocessing"]["sample_rate"])
                    if index == 0:
                        writer.add_audio("cnvt", torch.from_numpy(output), global_step=global_step, sample_rate=params["preprocessing"]["sample_rate"])
        # end of epoch: log the running average loss
        writer.add_scalar("NLL", average_loss, global_step)
Example #23
class Logger(object):
    def __init__(self, config, rank=0):
        self.rank = rank
        self.summary_writer = None

        if self.rank == 0:
            self.logdir = config.training_config.logdir
            self.continue_training = config.training_config.continue_training
            self.sample_rate = config.data_config.sample_rate

            # already inside the `rank == 0` branch, so no need to re-check the rank
            if not self.continue_training and os.path.exists(self.logdir):
                raise RuntimeError(
                    f"You're trying to run training from scratch, "
                    f"but logdir `{self.logdir}` already exists. Remove it or specify a new one."
                )
            if not self.continue_training:
                os.makedirs(self.logdir)
            self.summary_writer = SummaryWriter(config.training_config.logdir)
            self.save_model_config(config)

    def _log_losses(self, iteration, loss_stats: dict):
        for key, value in loss_stats.items():
            self.summary_writer.add_scalar(key, value, iteration)

    def log_training(self, iteration, stats, verbose=True):
        if self.rank != 0: return
        stats = {f'training/{key}': value for key, value in stats.items()}
        self._log_losses(iteration, loss_stats=stats)
        show_message(
            f'Iteration: {iteration} | Losses: {[value for value in stats.values()]}',
            verbose=verbose)

    def log_test(self, epoch, stats, verbose=True):
        if self.rank != 0: return
        stats = {f'test/{key}': value for key, value in stats.items()}
        self._log_losses(epoch, loss_stats=stats)
        show_message(
            f'Epoch: {epoch} | Losses: {[value for value in stats.values()]}',
            verbose=verbose)

    def log_audios(self, iteration, audios: dict):
        if self.rank != 0: return
        for key, audio in audios.items():
            self.summary_writer.add_audio(key,
                                          audio,
                                          iteration,
                                          sample_rate=self.sample_rate)

    def log_specs(self, iteration, specs: dict):
        if self.rank != 0: return
        for key, image in specs.items():
            self.summary_writer.add_image(key,
                                          plot_tensor_to_numpy(image),
                                          iteration,
                                          dataformats='HWC')

    def save_model_config(self, config):
        if self.rank != 0: return
        with open(f'{self.logdir}/config.json', 'w') as f:
            json.dump(config.to_dict_type(), f)

    def save_checkpoint(self, iteration, model, optimizer=None):
        if self.rank != 0: return
        d = {}
        d['iteration'] = iteration
        d['model'] = model.state_dict()
        if optimizer is not None:
            d['optimizer'] = optimizer.state_dict()
        filename = f'{self.summary_writer.log_dir}/checkpoint_{iteration}.pt'
        torch.save(d, filename)

    def load_latest_checkpoint(self, model, optimizer=None):
        if not self.continue_training:
            raise RuntimeError(
                f"Trying to load the latest checkpoint from logdir {self.logdir}, "
                "but did not set `continue_training=true` in configuration.")
        model, optimizer, iteration = load_latest_checkpoint(
            self.logdir, model, optimizer)
        return model, optimizer, iteration
Example #24
    frame_sr=dataset.frame_sr,
    dtype=dtype,
)

training = Training(model, dataset, batch_size=batch_size)

# plot ref audio and waveform
ref_audio, ref_f0, ref_f0_scaled, ref_lo = dataset.get_sample(0)
ref_audio = ref_audio.numpy()
# max 5 secs
ref_audio = ref_audio[:dataset.audio_sr * 5]
ref_f0 = ref_f0[:dataset.frame_sr * 5]
ref_f0_scaled = ref_f0_scaled[:dataset.frame_sr * 5]
ref_lo = ref_lo[:dataset.frame_sr * 5]

writer.add_audio("Groundtruth", ref_audio, 0, dataset.audio_sr)
fig = plot_wf(ref_audio, dataset.audio_sr)
writer.add_figure("Waveform/Groundtruth", fig, 0, True)
fig = plot_spectrogram(ref_audio, dataset.audio_sr)
writer.add_figure("Spectrogram/Groundtruth", fig, 0, True)
writer.flush()

best_epoch_loss = torch.finfo(dtype).max

while training.epoch < nb_epochs:
    # run training and test epoch
    training.train_epoch()
    training.test_epoch()

    # plot losses
    writer.add_scalar("Loss/Train", training.train_loss, training.epoch)
Example #25
class TrainingProcessHandler(object):
    def __init__(self,
                 data_folder="logs",
                 model_folder="model",
                 enable_iteration_progress_bar=False,
                 model_save_key="loss",
                 mlflow_tags=None,
                 mlflow_parameters=None,
                 enable_mlflow=True,
                 mlflow_experiment_name="DDT"):
        if mlflow_tags is None:
            mlflow_tags = {}
        if mlflow_parameters is None:
            mlflow_parameters = {}
        self._name = None
        self._epoch_count = 0
        self._iteration_count = 0
        self._current_epoch = 0
        self._current_iteration = 0
        self._log_folder = data_folder
        self._iteration_progress_bar = None
        self._enable_iteration_progress_bar = enable_iteration_progress_bar
        self._epoch_progress_bar = None
        self._writer = None
        self._train_metrics = {}
        self._model = None
        self._model_folder = model_folder
        self._run_name = ""
        self.train_history = {}
        self.validation_history = {}
        self._model_save_key = model_save_key
        self._previous_model_save_metric = None
        if not os.path.exists(self._model_folder):
            os.mkdir(self._model_folder)
        if not os.path.exists(self._log_folder):
            os.mkdir(self._log_folder)
        self._audio_configs = {}
        self._global_epoch_step = 0
        self._global_iteration_step = 0
        if enable_mlflow:
            self._mlflow_handler = MlFlowHandler(
                experiment_name=mlflow_experiment_name,
                mlflow_tags=mlflow_tags,
                mlflow_parameters=mlflow_parameters)
        else:
            self._mlflow_handler = None
        self._artifacts = []

    def setup_handler(self, name, model):
        self._name = name
        self._run_name = name + "_" + datetime.datetime.now().strftime(
            '%Y-%m-%d-%H-%M-%S')
        self._writer = SummaryWriter(
            os.path.join(self._log_folder, self._run_name))
        self._model = model
        self._previous_model_save_metric = None
        self.train_history = {}
        self.validation_history = {}
        self._global_epoch_step = 0
        self._global_iteration_step = 0

    def start_callback(self, epoch_count, iteration_count, parameters=None):
        if parameters is None:
            parameters = {}
        self._epoch_count = epoch_count
        self._iteration_count = iteration_count
        self._current_epoch = 0
        self._current_iteration = 0
        self._epoch_progress_bar = tqdm(total=self._epoch_count)
        if self._enable_iteration_progress_bar:
            self._iteration_progress_bar = tqdm(total=self._iteration_count //
                                                self._epoch_count)
        if self._mlflow_handler is not None:
            self._mlflow_handler.start_callback(parameters)

    def epoch_callback(self,
                       metrics,
                       image_batches=None,
                       figures=None,
                       audios=None,
                       texts=None):
        self._artifacts = []
        for key, value in metrics.items():
            self.validation_history.setdefault(key, []).append(value)
        self._write_epoch_metrics(metrics)
        if image_batches is not None:
            self._write_image_batches(image_batches)
        if figures is not None:
            self._write_figures(figures)
        if audios is not None:
            self._write_audios(audios)
        if texts is not None:
            self._write_texts(texts)
        if self._enable_iteration_progress_bar and self._epoch_count != self._current_epoch - 1:
            self._iteration_progress_bar.reset()
        self._epoch_progress_bar.update()
        self._epoch_progress_bar.set_postfix_str(
            self.metric_string("valid", metrics))
        if self.should_save_model(metrics) and self._model is not None:
            torch.save(
                self._model.state_dict(),
                os.path.join(self._model_folder,
                             f"{self._run_name}_checkpoint.pth"))
        self._current_epoch += 1
        self._global_epoch_step += 1
        if self._mlflow_handler is not None:
            self._mlflow_handler.epoch_callback(metrics, self._current_epoch,
                                                self._artifacts)

    def iteration_callback(self, metrics):
        for key, value in metrics.items():
            self.train_history.setdefault(key, []).append(value)
        self._train_metrics = metrics
        self._write_iteration_metrics(metrics)
        if self._enable_iteration_progress_bar:
            self._iteration_progress_bar.set_postfix_str(
                self.metric_string("train", metrics))
            self._iteration_progress_bar.update()
        self._current_iteration += 1
        self._global_iteration_step += 1

    def finish_callback(self, metrics):
        print(self.metric_string("test", metrics))
        self._writer.close()
        if self._enable_iteration_progress_bar:
            self._iteration_progress_bar.close()
        self._epoch_progress_bar.close()
        if self._mlflow_handler is not None:
            self._mlflow_handler.finish_callback()

    @staticmethod
    def metric_string(prefix, metrics):
        result = ""
        for key, value in metrics.items():
            result += "{} {} = {:>3.3f}, ".format(prefix, key, value)
        return result[:-2]

    def _write_epoch_metrics(self, validation_metrics):
        for key, value in validation_metrics.items():
            self._writer.add_scalar(f"epoch/{key}",
                                    value,
                                    global_step=self._global_epoch_step)

    def _write_iteration_metrics(self, train_metrics):
        for key, value in train_metrics.items():
            self._writer.add_scalar(f"iteration/{key}",
                                    value,
                                    global_step=self._global_iteration_step)

    def should_save_model(self, metrics):
        if self._model_save_key not in metrics.keys():
            return True
        if self._previous_model_save_metric is None:
            self._previous_model_save_metric = metrics[self._model_save_key]
            return True
        if self._previous_model_save_metric > metrics[self._model_save_key]:
            self._previous_model_save_metric = metrics[self._model_save_key]
            return True
        return False

    def _write_image_batches(self, image_batches):
        for key, value in image_batches.items():
            self._writer.add_images(key,
                                    value,
                                    self._global_epoch_step,
                                    dataformats="NHWC")

    def _write_figures(self, figures):
        for key, value in figures.items():
            self._writer.add_figure(key, value, self._global_epoch_step)
            artifact_name = f"{self._log_folder}/{key}_{self._global_epoch_step:04d}.png"
            value.savefig(artifact_name)
            self._artifacts.append(artifact_name)

    def _write_audios(self, audios):
        for key, value in audios.items():
            self._writer.add_audio(key, value, self._global_epoch_step,
                                   **self._audio_configs)

    def set_audio_configs(self, configs):
        self._audio_configs = configs

    def _write_texts(self, texts):
        for key, value in texts.items():
            self._writer.add_text(key, value, self._global_epoch_step)
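
Aside: a hypothetical usage sketch of the handler above. set_audio_configs forwards its dictionary as keyword arguments to SummaryWriter.add_audio, so the sample rate only needs to be configured once:

import torch

handler = TrainingProcessHandler(enable_mlflow=False)
handler.setup_handler("demo", model=None)
handler.set_audio_configs({"sample_rate": 16000})
handler.start_callback(epoch_count=1, iteration_count=10)
handler.iteration_callback({"loss": 0.9})
handler.epoch_callback({"loss": 0.5},
                       audios={"val/sample": torch.zeros(1, 16000)})
handler.finish_callback({"loss": 0.5})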
Example #26
def trainG(args):
    root = Path(args['logging']['save_path'])
    load_root = Path(
        args['logging']['load_path']) if args['logging']['load_path'] else None
    root.mkdir(parents=True, exist_ok=True)
    ####################################
    # Dump arguments and create logger #
    ####################################
    with open(root / "args.yml", "w") as f:
        yaml.dump(args, f)
    writer = SummaryWriter(str(root))

    #######################
    # Load PyTorch Models #
    #######################
    netG = Generator(args['fft']['n_mel_channels'],
                     args['Generator']['ngf'],
                     args['Generator']['n_residual_layers'],
                     ratios=args['Generator']['ratios']).cuda()

    if 'G_path' in args['Generator'] and args['Generator'][
            'G_path'] is not None:
        netG.load_state_dict(
            torch.load(args['Generator']['G_path'] / "netG.pt"))
    fft = Audio2Mel(n_mel_channels=args['fft']['n_mel_channels'],
                    n_fft=args['fft']['n_fft'],
                    hop_length=args['fft']['hop_length'],
                    win_length=args['fft']['win_length'],
                    sampling_rate=args['data']['sampling_rate'],
                    mel_fmin=args['fft']['mel_fmin'],
                    mel_fmax=args['fft']['mel_fmax']).cuda()

    print(netG)
    #####################
    # Create optimizers #
    #####################
    optG = torch.optim.Adam(netG.parameters(),
                            lr=args['optimizer']['lrG'],
                            betas=args['optimizer']['betasG'])

    if load_root and load_root.exists():
        netG.load_state_dict(torch.load(load_root / "netG.pt"))
        optG.load_state_dict(torch.load(load_root / "optG.pt"))
        print('checkpoints loaded')

    #######################
    # Create data loaders #
    #######################
    train_set = AudioDataset(Path(args['data']['data_path']) /
                             "train_files_inv.txt",
                             segment_length=args['data']['seq_len'],
                             sampling_rate=args['data']['sampling_rate'],
                             augment=['amp', 'flip', 'neg'])

    test_set = AudioDataset(Path(args['data']['data_path']) /
                            "test_files_inv.txt",
                            segment_length=args['data']['sampling_rate'] * 4,
                            sampling_rate=args['data']['sampling_rate'],
                            augment=None)

    train_loader = DataLoader(train_set,
                              batch_size=args['dataloader']['batch_size'],
                              num_workers=4,
                              pin_memory=True,
                              shuffle=True)
    test_loader = DataLoader(test_set, batch_size=1)

    ##########################
    # Dumping original audio #
    ##########################
    test_voc = []
    test_audio = []
    for i, x_t in enumerate(test_loader):
        x_t = x_t.cuda()
        s_t = fft(x_t).detach()

        test_voc.append(s_t.cuda())
        test_audio.append(x_t)

        audio = x_t.squeeze().cpu()
        save_sample(root / ("original_%d.wav" % i),
                    args['data']['sampling_rate'], audio)
        writer.add_audio("original/sample_%d.wav" % i,
                         audio,
                         0,
                         sample_rate=args['data']['sampling_rate'])

        if i == args['logging']['n_test_samples'] - 1:
            break

    costs = []
    start = time.time()

    # enable cudnn autotuner to speed up training
    torch.backends.cudnn.benchmark = True

    best_mel_reconst = 1000000
    steps = 0
    mr_stft_loss = MultiResolutionSTFTLoss().cuda()
    for epoch in range(1, args['train']['epochs'] + 1):
        for iterno, x_t in enumerate(train_loader):
            x_t = x_t.cuda()
            s_t = fft(x_t).detach()
            x_pred_t = netG(s_t.cuda())

            with torch.no_grad():
                s_pred_t = fft(x_pred_t.detach())
                s_error = F.l1_loss(s_t, s_pred_t).item()

            ###################
            # Train Generator #
            ###################
            sc, sm = mr_stft_loss(x_pred_t, x_t)
            loss_G = (args['losses']['lambda_sc'] * sc +
                      args['losses']['lambda_sm'] * sm)
            netG.zero_grad()
            loss_G.backward()
            optG.step()

            ######################
            # Update tensorboard #
            ######################
            costs.append([loss_G.item(), sc.item(), sm.item(), s_error])

            writer.add_scalar("loss/generator", costs[-1][0], steps)
            writer.add_scalar("loss/convergence", costs[-1][1], steps)
            writer.add_scalar("loss/logmag", costs[-1][2], steps)
            writer.add_scalar("loss/mel_reconstruction", costs[-1][3], steps)
            steps += 1

            if steps % args['logging']['save_interval'] == 0:
                st = time.time()
                with torch.no_grad():
                    for i, (voc, _) in enumerate(zip(test_voc, test_audio)):
                        pred_audio = netG(voc)
                        pred_audio = pred_audio.squeeze().cpu()
                        save_sample(root / ("generated_%d.wav" % i),
                                    args['data']['sampling_rate'], pred_audio)
                        writer.add_audio(
                            "generated/sample_%d.wav" % i,
                            pred_audio,
                            epoch,
                            sample_rate=args['data']['sampling_rate'],
                        )

                torch.save(netG.state_dict(), root / "netG.pt")
                torch.save(optG.state_dict(), root / "optG.pt")

                if np.asarray(costs).mean(0)[-1] < best_mel_reconst:
                    best_mel_reconst = np.asarray(costs).mean(0)[-1]
                    torch.save(netG.state_dict(), root / "best_netG.pt")

                print("Took %5.4fs to generate samples" % (time.time() - st))
                print("-" * 100)

            if steps % args['logging']['log_interval'] == 0:
                print("Epoch {} | Iters {} / {} | ms/batch {:5.2f} | loss {}".
                      format(
                          epoch,
                          iterno,
                          len(train_loader),
                          1000 * (time.time() - start) /
                          args['logging']['log_interval'],
                          np.asarray(costs).mean(0),
                      ))
                costs = []
                start = time.time()
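
Aside: this variant drops the discriminator and trains the generator purely with a multi-resolution STFT loss. The two terms it returns, spectral convergence (sc) and log STFT magnitude (sm), can be sketched at a single resolution as follows (assumed to follow the Parallel WaveGAN formulation; the actual MultiResolutionSTFTLoss may differ in detail):

import torch
import torch.nn.functional as F

def stft_loss_sketch(x_pred, x_real, n_fft=1024, hop_length=256):
    window = torch.hann_window(n_fft, device=x_real.device)
    mag_pred = torch.stft(x_pred, n_fft, hop_length, window=window,
                          return_complex=True).abs()
    mag_real = torch.stft(x_real, n_fft, hop_length, window=window,
                          return_complex=True).abs()
    # spectral convergence: relative Frobenius distance between magnitudes
    sc = torch.norm(mag_real - mag_pred, p="fro") / torch.norm(mag_real, p="fro")
    # log STFT magnitude: L1 distance in the log domain
    sm = F.l1_loss(torch.log(mag_pred + 1e-7), torch.log(mag_real + 1e-7))
    return sc, sm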
Example #27
def train(rank, a, h):
    if h.num_gpus > 1:
        init_process_group(backend=h.dist_config['dist_backend'],
                           init_method=h.dist_config['dist_url'],
                           world_size=h.dist_config['world_size'] * h.num_gpus,
                           rank=rank)

    torch.cuda.manual_seed(h.seed)
    torch.cuda.set_device(rank)
    device = torch.device('cuda:{:d}'.format(rank))

    generator = Generator(h).to(device)
    mpd = MultiPeriodDiscriminator().to(device)
    msd = MultiScaleDiscriminator().to(device)

    if rank == 0:
        print(generator)
        os.makedirs(a.checkpoint_path, exist_ok=True)
        print("checkpoints directory : ", a.checkpoint_path)

    cp_g, cp_do = None, None  # avoid NameError when the checkpoint dir is missing
    if os.path.isdir(a.checkpoint_path):
        cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
        cp_do = scan_checkpoint(a.checkpoint_path, 'do_')

    steps = 0
    if cp_g is None or cp_do is None:
        state_dict_do = None
        last_epoch = -1
    else:
        state_dict_g = load_checkpoint(cp_g, device)
        state_dict_do = load_checkpoint(cp_do, device)
        generator.load_state_dict(state_dict_g['generator'])
        mpd.load_state_dict(state_dict_do['mpd'])
        msd.load_state_dict(state_dict_do['msd'])
        steps = state_dict_do['steps'] + 1
        last_epoch = state_dict_do['epoch']

    if h.num_gpus > 1:
        generator = DistributedDataParallel(generator,
                                            device_ids=[rank]).to(device)
        mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
        msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)

    optim_g = torch.optim.AdamW(generator.parameters(),
                                h.learning_rate,
                                betas=[h.adam_b1, h.adam_b2])
    optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(),
                                                mpd.parameters()),
                                h.learning_rate,
                                betas=[h.adam_b1, h.adam_b2])

    if state_dict_do is not None:
        optim_g.load_state_dict(state_dict_do['optim_g'])
        optim_d.load_state_dict(state_dict_do['optim_d'])

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g,
                                                         gamma=h.lr_decay,
                                                         last_epoch=last_epoch)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d,
                                                         gamma=h.lr_decay,
                                                         last_epoch=last_epoch)

    training_filelist, validation_filelist = get_dataset_filelist(a)

    trainset = MelDataset(training_filelist,
                          h.segment_size,
                          h.n_fft,
                          h.num_mels,
                          h.hop_size,
                          h.win_size,
                          h.sampling_rate,
                          h.fmin,
                          h.fmax,
                          n_cache_reuse=0,
                          shuffle=False if h.num_gpus > 1 else True,
                          fmax_loss=h.fmax_for_loss,
                          device=device,
                          fine_tuning=a.fine_tuning,
                          base_mels_path=a.input_mels_dir)

    train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None

    train_loader = DataLoader(trainset,
                              num_workers=h.num_workers,
                              shuffle=False,
                              sampler=train_sampler,
                              batch_size=h.batch_size,
                              pin_memory=True,
                              drop_last=True)

    if rank == 0:
        validset = MelDataset(validation_filelist,
                              h.segment_size,
                              h.n_fft,
                              h.num_mels,
                              h.hop_size,
                              h.win_size,
                              h.sampling_rate,
                              h.fmin,
                              h.fmax,
                              False,
                              False,
                              n_cache_reuse=0,
                              fmax_loss=h.fmax_for_loss,
                              device=device,
                              fine_tuning=a.fine_tuning,
                              base_mels_path=a.input_mels_dir)
        validation_loader = DataLoader(validset,
                                       num_workers=1,
                                       shuffle=False,
                                       sampler=None,
                                       batch_size=1,
                                       pin_memory=True,
                                       drop_last=True)

        sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))

    generator.train()
    mpd.train()
    msd.train()
    for epoch in range(max(0, last_epoch), a.training_epochs):
        if rank == 0:
            start = time.time()
            print("Epoch: {}".format(epoch + 1))

        if h.num_gpus > 1:
            train_sampler.set_epoch(epoch)

        for i, batch in enumerate(train_loader):
            if rank == 0:
                start_b = time.time()
            x, y, _, y_mel = batch
            x = torch.autograd.Variable(x.to(device, non_blocking=True))
            y = torch.autograd.Variable(y.to(device, non_blocking=True))
            y_mel = torch.autograd.Variable(y_mel.to(device,
                                                     non_blocking=True))
            y = y.unsqueeze(1)

            y_g_hat = generator(x)
            y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft,
                                          h.num_mels, h.sampling_rate,
                                          h.hop_size, h.win_size, h.fmin,
                                          h.fmax_for_loss)

            optim_d.zero_grad()

            # MPD
            y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
            loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(
                y_df_hat_r, y_df_hat_g)

            # MSD
            y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
            loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(
                y_ds_hat_r, y_ds_hat_g)

            loss_disc_all = loss_disc_s + loss_disc_f

            loss_disc_all.backward()
            optim_d.step()

            # Generator
            optim_g.zero_grad()

            # L1 Mel-Spectrogram Loss
            loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45

            y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
            y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
            loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
            loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
            loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
            loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
            loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel

            loss_gen_all.backward()
            optim_g.step()

            if rank == 0:
                # STDOUT logging
                if steps % a.stdout_interval == 0:
                    with torch.no_grad():
                        mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()

                    print(
                        'Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'
                        .format(steps, loss_gen_all, mel_error,
                                time.time() - start_b))

                # checkpointing
                if steps % a.checkpoint_interval == 0 and steps != 0:
                    checkpoint_path = "{}/g_{:08d}".format(
                        a.checkpoint_path, steps)
                    save_checkpoint(
                        checkpoint_path, {
                            'generator': (generator.module if h.num_gpus > 1
                                          else generator).state_dict()
                        })
                    checkpoint_path = "{}/do_{:08d}".format(
                        a.checkpoint_path, steps)
                    save_checkpoint(
                        checkpoint_path, {
                            'mpd': (mpd.module
                                    if h.num_gpus > 1 else mpd).state_dict(),
                            'msd': (msd.module
                                    if h.num_gpus > 1 else msd).state_dict(),
                            'optim_g':
                            optim_g.state_dict(),
                            'optim_d':
                            optim_d.state_dict(),
                            'steps':
                            steps,
                            'epoch':
                            epoch
                        })

                # Tensorboard summary logging
                if steps % a.summary_interval == 0:
                    sw.add_scalar("training/gen_loss_total", loss_gen_all,
                                  steps)
                    sw.add_scalar("training/mel_spec_error", mel_error, steps)

                # Validation
                if steps % a.validation_interval == 0:  # and steps != 0:
                    generator.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    with torch.no_grad():
                        for j, batch in enumerate(validation_loader):
                            x, y, _, y_mel = batch
                            y_g_hat = generator(x.to(device))
                            y_mel = torch.autograd.Variable(
                                y_mel.to(device, non_blocking=True))
                            y_g_hat_mel = mel_spectrogram(
                                y_g_hat.squeeze(1), h.n_fft, h.num_mels,
                                h.sampling_rate, h.hop_size, h.win_size,
                                h.fmin, h.fmax_for_loss)
                            val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()

                            if j <= 4:
                                if steps == 0:
                                    sw.add_audio('gt/y_{}'.format(j), y[0],
                                                 steps, h.sampling_rate)
                                    sw.add_figure('gt/y_spec_{}'.format(j),
                                                  plot_spectrogram(x[0]),
                                                  steps)

                                sw.add_audio('generated/y_hat_{}'.format(j),
                                             y_g_hat[0], steps,
                                             h.sampling_rate)
                                y_hat_spec = mel_spectrogram(
                                    y_g_hat.squeeze(1), h.n_fft, h.num_mels,
                                    h.sampling_rate, h.hop_size, h.win_size,
                                    h.fmin, h.fmax)
                                sw.add_figure(
                                    'generated/y_hat_spec_{}'.format(j),
                                    plot_spectrogram(
                                        y_hat_spec.squeeze(0).cpu().numpy()),
                                    steps)

                        val_err = val_err_tot / (j + 1)
                        sw.add_scalar("validation/mel_spec_error", val_err,
                                      steps)

                    generator.train()

            steps += 1

        scheduler_g.step()
        scheduler_d.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(
                epoch + 1, int(time.time() - start)))
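
Aside: the helpers discriminator_loss, generator_loss and feature_loss are not shown in this example; in the reference HiFi-GAN implementation they are least-squares (LSGAN) objectives plus an L1 feature-matching term, roughly:

import torch

def discriminator_loss_sketch(real_outputs, fake_outputs):
    # push real scores towards 1 and fake scores towards 0, per sub-discriminator
    return sum(torch.mean((1 - d_r) ** 2) + torch.mean(d_g ** 2)
               for d_r, d_g in zip(real_outputs, fake_outputs))

def generator_loss_sketch(fake_outputs):
    # push the discriminator's scores on generated audio towards 1
    return sum(torch.mean((1 - d_g) ** 2) for d_g in fake_outputs)

def feature_loss_sketch(fmap_real, fmap_fake):
    # L1 distance between discriminator feature maps of real and generated audio
    loss = 0
    for layers_r, layers_g in zip(fmap_real, fmap_fake):
        for f_r, f_g in zip(layers_r, layers_g):
            loss = loss + torch.mean(torch.abs(f_r - f_g))
    return loss * 2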
Example #28
def generate(output_directory, tensorboard_directory,
             num_samples,
             ckpt_path, ckpt_iter):
    """
    Generate audio based on ground truth mel spectrogram

    Parameters:
    output_directory (str):         save generated speeches to this path
    tensorboard_directory (str):    save tensorboard events to this path
    num_samples (int):              number of samples to generate, default is 4
    ckpt_path (str):                checkpoint path
    ckpt_iter (int or 'max'):       the pretrained checkpoint to be loaded;
                                    automatically selects the maximum iteration if 'max' is selected
    """

    # generate experiment (local) path
    local_path = "ch{}_T{}_betaT{}".format(wavenet_config["res_channels"], 
                                           diffusion_config["T"], 
                                           diffusion_config["beta_T"])
    
    # Get shared output_directory ready
    output_directory = os.path.join('exp', local_path, output_directory)
    if not os.path.isdir(output_directory):
        os.makedirs(output_directory)
        os.chmod(output_directory, 0o775)
    print("output directory", output_directory, flush=True)

    # map diffusion hyperparameters to gpu
    for key in diffusion_hyperparams:
        if key is not "T":
            diffusion_hyperparams[key] = diffusion_hyperparams[key].cuda()

    # predefine model
    net = WaveNet(**wavenet_config).cuda()
    print_size(net)

    # load checkpoint
    ckpt_path = os.path.join('exp', local_path, ckpt_path)
    if ckpt_iter == 'max':
        ckpt_iter = find_max_epoch(ckpt_path)
    model_path = os.path.join(ckpt_path, '{}.pkl'.format(ckpt_iter))
    try:
        checkpoint = torch.load(model_path, map_location='cpu')
        net.load_state_dict(checkpoint['model_state_dict'])
        print('Successfully loaded model at iteration {}'.format(ckpt_iter))
    except Exception as e:
        raise Exception('No valid model found') from e

    # predefine audio shape
    audio_length = trainset_config["segment_length"]  # 16000
    print('begin generating audio of length %s' % audio_length)

    # inference
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()

    generated_audio = sampling(net, (num_samples,1,audio_length), 
                               diffusion_hyperparams)
    
    end.record()
    torch.cuda.synchronize()
    print('generated {} utterances of random_digit at iteration {} in {} seconds'.format(
        num_samples, ckpt_iter, int(start.elapsed_time(end) / 1000)))

    # save audio to .wav and to tensorboard (one shared writer for all samples,
    # rather than re-creating a SummaryWriter on every loop iteration)
    tb = SummaryWriter(os.path.join('exp', local_path, tensorboard_directory))
    for i in range(num_samples):
        outfile = '{}_{}_{}k_{}.wav'.format(wavenet_config["res_channels"],
                                            diffusion_config["T"],
                                            ckpt_iter // 1000,
                                            i)
        wavwrite(os.path.join(output_directory, outfile),
                 trainset_config["sampling_rate"],
                 generated_audio[i].squeeze().cpu().numpy())
        tb.add_audio(tag=outfile, snd_tensor=generated_audio[i], sample_rate=trainset_config["sampling_rate"])
    tb.close()

    print('saved generated samples at iteration %s' % ckpt_iter)
Example #29
class BaseTrainer:
    def __init__(self, dist, rank, config, resume, only_validation, model,
                 loss_function, optimizer):
        self.color_tool = colorful
        self.color_tool.use_style("solarized")

        model = DistributedDataParallel(model.to(rank), device_ids=[rank])
        self.model = model
        self.optimizer = optimizer
        self.loss_function = loss_function

        # DistributedDataParallel (DDP)
        self.rank = rank
        self.dist = dist

        # Automatic mixed precision (AMP)
        self.use_amp = config["meta"]["use_amp"]
        self.scaler = GradScaler(enabled=self.use_amp)

        # Acoustics
        self.acoustic_config = config["acoustics"]

        # Supported STFT
        n_fft = self.acoustic_config["n_fft"]
        hop_length = self.acoustic_config["hop_length"]
        win_length = self.acoustic_config["win_length"]

        self.torch_stft = partial(stft,
                                  n_fft=n_fft,
                                  hop_length=hop_length,
                                  win_length=win_length)
        self.torch_istft = partial(istft,
                                   n_fft=n_fft,
                                   hop_length=hop_length,
                                   win_length=win_length)
        self.librosa_stft = partial(librosa.stft,
                                    n_fft=n_fft,
                                    hop_length=hop_length,
                                    win_length=win_length)
        self.librosa_istft = partial(librosa.istft,
                                     hop_length=hop_length,
                                     win_length=win_length)

        # Trainer.train in the config
        self.train_config = config["trainer"]["train"]
        self.epochs = self.train_config["epochs"]
        self.save_checkpoint_interval = self.train_config[
            "save_checkpoint_interval"]
        self.clip_grad_norm_value = self.train_config["clip_grad_norm_value"]
        assert self.save_checkpoint_interval >= 1, "Check the 'save_checkpoint_interval' parameter in the config. It should be at least one."

        # Trainer.validation in the config
        self.validation_config = config["trainer"]["validation"]
        self.validation_interval = self.validation_config[
            "validation_interval"]
        self.save_max_metric_score = self.validation_config[
            "save_max_metric_score"]
        assert self.validation_interval >= 1, "Check the 'validation_interval' parameter in the config. It should be at least one."

        # Trainer.visualization in the config
        self.visualization_config = config["trainer"]["visualization"]

        # If 'resume' is True in 'train.py', the following attributes are updated from the checkpoint:
        self.start_epoch = 1
        self.best_score = -np.inf if self.save_max_metric_score else np.inf
        self.save_dir = Path(config["meta"]["save_dir"]).expanduser().absolute(
        ) / config["meta"]["experiment_name"]
        self.checkpoints_dir = self.save_dir / "checkpoints"
        self.logs_dir = self.save_dir / "logs"

        if resume:
            self._resume_checkpoint()

        # Debug validation, which skips training
        self.only_validation = only_validation

        if config["meta"]["preloaded_model_path"]:
            self._preload_model(Path(config["preloaded_model_path"]))

        if self.rank == 0:
            prepare_empty_dir([self.checkpoints_dir, self.logs_dir],
                              resume=resume)

            self.writer = SummaryWriter(self.logs_dir.as_posix(),
                                        max_queue=5,
                                        flush_secs=30)
            self.writer.add_text(
                tag="Configuration",
                text_string=f"<pre>  \n{toml.dumps(config)}  \n</pre>",
                global_step=1)

            print(self.color_tool.cyan("The configurations are as follows: "))
            print(self.color_tool.cyan("=" * 40))
            print(self.color_tool.cyan(toml.dumps(config)[:-1]))  # except "\n"
            print(self.color_tool.cyan("=" * 40))

            with open(
                (self.save_dir /
                 f"{time.strftime('%Y-%m-%d %H:%M:%S')}.toml").as_posix(),
                    "w") as handle:
                toml.dump(config, handle)

            self._print_networks([self.model])

    def _preload_model(self, model_path):
        """
        Preload model parameters (in "*.tar" format) at the start of the experiment.

        Args:
            model_path (Path): The file path of the *.tar file
        """
        model_path = model_path.expanduser().absolute()
        assert model_path.exists(
        ), f"The file {model_path.as_posix()} does not exist. Please check the path."

        model_checkpoint = torch.load(model_path.as_posix(),
                                      map_location="cpu")
        self.model.load_state_dict(model_checkpoint["model"], strict=False)
        self.model.to(self.rank)

        if self.rank == 0:
            print(
                f"Model preloaded successfully from {model_path.as_posix()}.")

    def _resume_checkpoint(self):
        """
        Resume the experiment from the latest checkpoint.
        """
        latest_model_path = self.checkpoints_dir.expanduser().absolute(
        ) / "latest_model.tar"
        assert latest_model_path.exists(
        ), f"{latest_model_path} does not exist, can not load latest checkpoint."

        # Make sure all processes (GPUs) do not start loading before the saving is finished.
        # see https://stackoverflow.com/questions/59760328/how-does-torch-distributed-barrier-work
        self.dist.barrier()

        # Load on the CPU first and move to the device afterwards.
        # This may be slightly slower than using map_location="cuda:<...>".
        # https://stackoverflow.com/questions/61642619/pytorch-distributed-data-parallel-confusion
        checkpoint = torch.load(latest_model_path.as_posix(),
                                map_location="cpu")

        self.start_epoch = checkpoint["epoch"] + 1
        self.best_score = checkpoint["best_score"]
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.scaler.load_state_dict(checkpoint["scaler"])

        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            self.model.module.load_state_dict(checkpoint["model"])
        else:
            self.model.load_state_dict(checkpoint["model"])

        # self.model.to(self.rank)

        if self.rank == 0:
            print(
                f"Model checkpoint loaded. Training will begin at {self.start_epoch} epoch."
            )

    def _save_checkpoint(self, epoch, is_best_epoch=False):
        """
        Save a checkpoint to the "<save_dir>/<config name>/checkpoints" directory, consisting of:
            - epoch
            - the best metric score over all epochs so far
            - optimizer parameters
            - model parameters

        Args:
            is_best_epoch (bool): if the model achieved the best metric score in the current epoch
                                (is_best_epoch=True), the checkpoint is also saved as
                                "<save_dir>/checkpoints/best_model.tar".
        """
        print(f"\t Saving {epoch} epoch model checkpoint...")

        state_dict = {
            "epoch": epoch,
            "best_score": self.best_score,
            "optimizer": self.optimizer.state_dict(),
            "scaler": self.scaler.state_dict()
        }

        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            state_dict["model"] = self.model.module.state_dict()
        else:
            state_dict["model"] = self.model.state_dict()

        # Saved in "latest_model.tar"
        # Contains all checkpoint information, including the optimizer parameters, the model parameters, etc.
        # New checkpoint will overwrite the older one.
        torch.save(state_dict,
                   (self.checkpoints_dir / "latest_model.tar").as_posix())

        # "model_{epoch_number}.pth"
        # Contains only model.
        torch.save(state_dict["model"],
                   (self.checkpoints_dir /
                    f"model_{str(epoch).zfill(4)}.pth").as_posix())

        # If the model achieved the best metric score in the current epoch ("is_best_epoch=True"),
        # the checkpoint is also saved as "best_model.tar".
        # A newer best-scored checkpoint overwrites the older one.
        if is_best_epoch:
            print(
                self.color_tool.red(
                    f"\t Found a best score in the {epoch} epoch, saving..."))
            torch.save(state_dict,
                       (self.checkpoints_dir / "best_model.tar").as_posix())

    def _is_best_epoch(self, score, save_max_metric_score=True):
        """
        Check if the current model got the best metric score
        """
        if save_max_metric_score and score >= self.best_score:
            self.best_score = score
            return True
        elif not save_max_metric_score and score <= self.best_score:
            self.best_score = score
            return True
        else:
            return False

    @staticmethod
    def _print_networks(models: list):
        print(
            f"This project contains {len(models)} models with the following parameter counts:"
        )

        params_of_all_networks = 0
        for idx, model in enumerate(models, start=1):
            params_of_network = 0
            for param in model.parameters():
                params_of_network += param.numel()

            print(f"\tNetwork {idx}: {params_of_network / 1e6} million.")
            params_of_all_networks += params_of_network

        print(
            f"The total number of parameters in the project is {params_of_all_networks / 1e6} million."
        )

    def _set_models_to_train_mode(self):
        self.model.train()

    def _set_models_to_eval_mode(self):
        self.model.eval()

    def spec_audio_visualization(self,
                                 noisy,
                                 enhanced,
                                 clean,
                                 name,
                                 epoch,
                                 mark=""):
        self.writer.add_audio(f"{mark}_Speech/{name}_Noisy",
                              noisy,
                              epoch,
                              sample_rate=16000)
        self.writer.add_audio(f"{mark}_Speech/{name}_Enhanced",
                              enhanced,
                              epoch,
                              sample_rate=16000)
        self.writer.add_audio(f"{mark}_Speech/{name}_Clean",
                              clean,
                              epoch,
                              sample_rate=16000)

        # Visualize the spectrogram of noisy speech, clean speech, and enhanced speech
        noisy_mag, _ = librosa.magphase(
            self.librosa_stft(noisy, n_fft=320, hop_length=160,
                              win_length=320))
        enhanced_mag, _ = librosa.magphase(
            self.librosa_stft(enhanced,
                              n_fft=320,
                              hop_length=160,
                              win_length=320))
        clean_mag, _ = librosa.magphase(
            self.librosa_stft(clean, n_fft=320, hop_length=160,
                              win_length=320))
        fig, axes = plt.subplots(3, 1, figsize=(6, 6))
        for k, mag in enumerate([noisy_mag, enhanced_mag, clean_mag]):
            axes[k].set_title(f"mean: {np.mean(mag):.3f}, "
                              f"std: {np.std(mag):.3f}, "
                              f"max: {np.max(mag):.3f}, "
                              f"min: {np.min(mag):.3f}")
            librosa.display.specshow(librosa.amplitude_to_db(mag),
                                     cmap="magma",
                                     y_axis="linear",
                                     ax=axes[k],
                                     sr=16000)
        plt.tight_layout()
        self.writer.add_figure(f"{mark}_Spectrogram/{name}", fig, epoch)

    def metrics_visualization(self,
                              noisy_list,
                              clean_list,
                              enhanced_list,
                              metrics_list,
                              epoch,
                              num_workers=10,
                              mark=""):
        """
        Compute metrics on the validation dataset in parallel.

        Notes:
            1. You can register other metrics, but the STOI and WB_PESQ metrics must be present;
               they are used to decide whether the current epoch is a "best epoch."
            2. To use a new metric, register it in the "util.metrics" file first.
            3. A sketch of the assumed `transform_pesq_range` scaling appears after this class.
        """
        assert "STOI" in metrics_list and "WB_PESQ" in metrics_list, "Both 'STOI' and 'WB_PESQ' must be present."

        # Check if the metric is registered in "util.metrics" file.
        for i in metrics_list:
            assert i in metrics.REGISTERED_METRICS.keys(
            ), f"{i} is not registered, please check 'util.metrics' file."

        stoi_mean = 0.0
        wb_pesq_mean = 0.0
        for metric_name in metrics_list:
            score_on_noisy = Parallel(n_jobs=num_workers)(
                delayed(metrics.REGISTERED_METRICS[metric_name])(ref, est)
                for ref, est in zip(clean_list, noisy_list))
            score_on_enhanced = Parallel(n_jobs=num_workers)(
                delayed(metrics.REGISTERED_METRICS[metric_name])(ref, est)
                for ref, est in zip(clean_list, enhanced_list))

            # Add the mean value of the metric to tensorboard
            mean_score_on_noisy = np.mean(score_on_noisy)
            mean_score_on_enhanced = np.mean(score_on_enhanced)
            self.writer.add_scalars(f"{mark}_Validation/{metric_name}", {
                "Noisy": mean_score_on_noisy,
                "Enhanced": mean_score_on_enhanced
            }, epoch)

            if metric_name == "STOI":
                stoi_mean = mean_score_on_enhanced

            if metric_name == "WB_PESQ":
                wb_pesq_mean = transform_pesq_range(mean_score_on_enhanced)

        return (stoi_mean + wb_pesq_mean) / 2

    def train(self):
        for epoch in range(self.start_epoch, self.epochs + 1):
            if self.rank == 0:
                print(
                    self.color_tool.yellow(
                        f"{'=' * 15} {epoch} epoch {'=' * 15}"))
                print("[0 seconds] Begin training...")

            # [debug validation] Only run validation (only use the first GPU (process))
            # inference + calculating metrics + saving checkpoints
            if self.only_validation and self.rank == 0:
                self._set_models_to_eval_mode()
                metric_score = self._validation_epoch(epoch)

                if self._is_best_epoch(
                        metric_score,
                        save_max_metric_score=self.save_max_metric_score):
                    self._save_checkpoint(epoch, is_best_epoch=True)

                # Skip the following regular training, saving checkpoints, and validation
                continue

            # Regular training
            timer = ExecutionTime()
            self._set_models_to_train_mode()
            self._train_epoch(epoch)

            #  Regular save checkpoints
            if self.rank == 0 and self.save_checkpoint_interval != 0 and (
                    epoch % self.save_checkpoint_interval == 0):
                self._save_checkpoint(epoch)

            # Regular validation
            if self.rank == 0 and (epoch % self.validation_interval == 0):
                print(
                    f"[{timer.duration()} seconds] Training has finished, validation is in progress..."
                )

                self._set_models_to_eval_mode()
                metric_score = self._validation_epoch(epoch)

                if self._is_best_epoch(
                        metric_score,
                        save_max_metric_score=self.save_max_metric_score):
                    self._save_checkpoint(epoch, is_best_epoch=True)

            print(f"[{timer.duration()} seconds] This epoch is finished.")

    def _train_epoch(self, epoch):
        raise NotImplementedError

    def _validation_epoch(self, epoch):
        raise NotImplementedError
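

# The class above depends on two helpers from "util.metrics" that are not shown
# here: REGISTERED_METRICS and transform_pesq_range. The sketch below shows what
# they plausibly look like; the bodies are assumptions, not the original
# implementation.
from pesq import pesq    # assumed dependency (pip install pesq)
from pystoi import stoi  # assumed dependency (pip install pystoi)


def transform_pesq_range(pesq_score):
    # WB-PESQ scores lie roughly in [-0.5, 4.5]; rescale them to [0, 1] so the
    # result can be averaged with STOI, which is already in [0, 1].
    return (pesq_score + 0.5) / 5


# Registry mapping metric names to callables with the signature (ref, est) -> float,
# where ref/est are 1-D waveforms sampled at 16 kHz.
REGISTERED_METRICS = {
    "STOI": lambda ref, est: stoi(ref, est, 16000, extended=False),
    "WB_PESQ": lambda ref, est: pesq(16000, ref, est, "wb"),
}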
Beispiel #30
0
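
# The training loop below uses MultiResolutionSTFTLoss, which sums a spectral
# convergence term and a log-STFT magnitude term over several STFT resolutions.
# A minimal single-resolution sketch (an illustration under assumed defaults,
# not the implementation used here; it expects (batch, samples) float input):
import torch
import torch.nn.functional as F


def single_resolution_stft_loss(x_pred, x_true, n_fft=1024, hop=256, win_len=1024):
    window = torch.hann_window(win_len, device=x_pred.device)
    mag_pred = torch.stft(x_pred, n_fft, hop, win_len, window,
                          return_complex=True).abs()
    mag_true = torch.stft(x_true, n_fft, hop, win_len, window,
                          return_complex=True).abs()
    # Spectral convergence: Frobenius norm of the magnitude difference,
    # normalised by the norm of the reference magnitudes
    sc_loss = torch.norm(mag_true - mag_pred, p="fro") / torch.norm(mag_true, p="fro")
    # Log-STFT magnitude loss: L1 distance between log magnitudes
    mag_loss = F.l1_loss(torch.log(mag_true.clamp(min=1e-7)),
                         torch.log(mag_pred.clamp(min=1e-7)))
    return sc_loss, mag_loss
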
def main():
    args = parse_args()
    root = Path(args.save_path)
    load_root = Path(args.load_path) if args.load_path else None
    print(load_root)
    root.mkdir(parents=True, exist_ok=True)

    ####################################
    # Dump arguments and create logger #
    ####################################
    with open(root / "args.yml", "w") as f:
        yaml.dump(args, f)
    writer = SummaryWriter(str(root))

    #######################
    # Load PyTorch Models #
    #######################
    netG = Generator(args.n_mel_channels).cuda()
    fft = Audio2Mel(n_mel_channels=args.n_mel_channels, mel_fmin=40, mel_fmax=None, sampling_rate=22050).cuda()

    print(netG)

    #####################
    # Create optimizers #
    #####################
    optG = torch.optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))

    if load_root and load_root.exists():
        netG.load_state_dict(torch.load(load_root / "netG.pt"))
        optG.load_state_dict(torch.load(load_root / "optG.pt"))        
        print('checkpoints loaded')

    #######################
    # Create data loaders #
    #######################
    train_set = AudioDataset(
        Path(args.data_path) / "train_files.txt", args.seq_len, sampling_rate=22050
    )
    test_set = AudioDataset(
        Path(args.data_path) / "test_files.txt",
        # ~4 s at 22050 Hz, rounded down so the mel frame count (hop 256) is a multiple of 32
        ((22050 * 4 // 256) // 32) * 32 * 256,
        sampling_rate=22050,
        augment=False,
    )

    train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=4, shuffle=True, pin_memory=True)
    test_loader = DataLoader(test_set, batch_size=1)

    mr_stft_loss = MultiResolutionSTFTLoss().cuda()

    ##########################
    # Dumping original audio #
    ##########################
    test_voc = []
    test_audio = []
    for i, x_t in enumerate(test_loader):
        x_t = x_t.cuda()
        s_t = fft(x_t).detach()

        test_voc.append(s_t)  # already on the GPU
        test_audio.append(x_t.cpu())

        audio = x_t.squeeze().cpu()
        save_sample(root / ("original_%d.wav" % i), 22050, audio)
        writer.add_audio("original/sample_%d.wav" % i, audio, 0, sample_rate=22050)

        if i == args.n_test_samples - 1:
            break

    costs = []
    start = time.time()

    # enable cudnn autotuner to speed up training
    torch.backends.cudnn.benchmark = True
    best_mel_reconst = float("inf")  # lowest mean mel-reconstruction error seen so far
    steps = 0
    for epoch in range(1, args.epochs + 1):
        for iterno, x_t in enumerate(train_loader):
            x_t = x_t.cuda()
            s_t = fft(x_t).detach()
            # Per-example latent noise (128 channels, one time step) for the generator
            n = torch.randn(x_t.shape[0], 128, 1).cuda()
            x_pred_t = netG(s_t, n)
            
            ###################
            # Train Generator #
            ###################
            # Mel-reconstruction error, tracked for monitoring only (no gradients)
            with torch.no_grad():
                s_pred_t = fft(x_pred_t.detach())
                s_error = F.l1_loss(s_t, s_pred_t).item()
                
            # Multi-resolution STFT loss (see the single-resolution sketch above
            # main()): spectral convergence + log-STFT magnitude terms
            sc_loss, mag_loss = mr_stft_loss(x_pred_t, x_t)

            loss_G = sc_loss + mag_loss
            
            netG.zero_grad()
            loss_G.backward()
            optG.step()

            ######################
            # Update tensorboard #
            ######################
            costs.append([loss_G.item(), sc_loss.item(), mag_loss.item(), s_error])
            
            writer.add_scalar("loss/generator", costs[-1][0], steps)
            writer.add_scalar("loss/spectral_convergence", costs[-1][1], steps)
            writer.add_scalar("loss/log_spectrum", costs[-1][2], steps)
            writer.add_scalar("loss/mel_reconstruction", costs[-1][3], steps)
            steps += 1

            if steps % args.save_interval == 0:
                st = time.time()
                with torch.no_grad():
                    # Re-synthesize the cached test mel-spectrograms with fresh noise
                    for i, (voc, _) in enumerate(zip(test_voc, test_audio)):
                        n = torch.randn(1, 128, 10).cuda()
                        pred_audio = netG(voc, n)
                        pred_audio = pred_audio.squeeze().cpu()
                        save_sample(root / ("generated_%d.wav" % i), 22050, pred_audio)
                        writer.add_audio(
                            "generated/sample_%d.wav" % i,
                            pred_audio,
                            epoch,
                            sample_rate=22050,
                        )

                torch.save(netG.state_dict(), root / "netG.pt")
                torch.save(optG.state_dict(), root / "optG.pt")
                                
                # Track the best generator by mean mel-reconstruction error
                mel_reconst = np.asarray(costs).mean(0)[-1]
                if mel_reconst < best_mel_reconst:
                    best_mel_reconst = mel_reconst
                    torch.save(netG.state_dict(), root / "best_netG.pt")

                print("Took %5.4fs to generate samples" % (time.time() - st))
                print("-" * 100)

            if steps % args.log_interval == 0:
                print(
                    "Epoch {} | Iters {} / {} | ms/batch {:5.2f} | loss {}".format(
                        epoch,
                        iterno,
                        len(train_loader),
                        1000 * (time.time() - start) / args.log_interval,
                        np.asarray(costs).mean(0),
                    )
                )
                costs = []
                start = time.time()
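

# "save_sample" is used above but defined elsewhere in this repository. A
# plausible implementation (a sketch, assuming 16-bit PCM output via SciPy;
# the real helper may differ):
from scipy.io import wavfile


def save_sample(file_path, sampling_rate, audio):
    """Write a float tensor in [-1, 1] to file_path as a 16-bit PCM wav."""
    audio = (audio.numpy() * 32767).astype("int16")
    wavfile.write(file_path, sampling_rate, audio)


if __name__ == "__main__":
    main()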