Esempio n. 1
0
def train(log_dir, args):
  """Build the training graph and run the training loop until stopped.

  Args:
    log_dir: Directory that receives checkpoints (`model.ckpt-<step>`),
      summaries, and per-checkpoint `step-<N>-audio.wav` /
      `step-<N>-align.png` samples.
    args: Parsed command-line arguments. Attributes read here: `git`,
      `base_dir`, `input`, `model`, `restore_step`, `summary_interval`,
      `checkpoint_interval`.

  The loop runs until `coord.should_stop()` is true. Any exception raised
  inside the session (including the deliberate 'Loss Exploded' one) is
  logged, its traceback printed and handed to `coord.request_stop`; it is
  not re-raised, so the function simply returns None.
  """
  # Identifier shown in logs and plot captions; the literal string 'None'
  # when args.git is falsy.
  commit = get_git_commit() if args.git else 'None'
  checkpoint_path = os.path.join(log_dir, 'model.ckpt')
  input_path = os.path.join(args.base_dir, args.input)
  log('Checkpoint path: %s' % checkpoint_path)
  log('Loading training data from: %s' % input_path)
  log('Using model: %s' % args.model)
  log(hparams_debug_string())

  # Set up DataFeeder:
  # `coord` is shared with the feeder and is what ends the training loop below.
  coord = tf.train.Coordinator()
  with tf.variable_scope('datafeeder') as scope:
    feeder = DataFeeder(coord, input_path, hparams)

  # Set up model:
  # Non-trainable step counter handed to add_optimizer, which presumably
  # increments it on every optimize run — confirm in the model code.
  global_step = tf.Variable(0, name='global_step', trainable=False)
  with tf.variable_scope('model') as scope:
    model = create_model(args.model, hparams)
    model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.linear_targets)
    model.add_loss()
    model.add_optimizer(global_step)
    stats = add_stats(model)

  # Bookkeeping:
  step = 0
  # ValueWindow(100) presumably keeps the last 100 values; their `.average`
  # is what the per-step log line reports.
  time_window = ValueWindow(100)
  loss_window = ValueWindow(100)
  # Keep the 5 most recent checkpoints, plus one every 2 hours.
  saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

  # Train!
  with tf.Session() as sess:
    try:
      summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
      # Initialize first; a restore below then overwrites the fresh values.
      sess.run(tf.global_variables_initializer())

      if args.restore_step:
        # Restore from a checkpoint if the user requested it.
        # Matches the '<checkpoint_path>-<step>' names written by saver.save below.
        restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
        saver.restore(sess, restore_path)
        log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True)
      else:
        log('Starting new training run at commit: %s' % commit, slack=True)

      # Start the input pipeline only after variables are initialized/restored.
      feeder.start_in_session(sess)

      while not coord.should_stop():
        start_time = time.time()
        # One optimization step; the `opt` result itself is unused.
        step, loss, opt = sess.run([global_step, model.loss, model.optimize])
        time_window.append(time.time() - start_time)
        loss_window.append(loss)
        message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
          step, time_window.average, loss, loss_window.average)
        # Only the progress lines of checkpoint steps are also sent to Slack.
        log(message, slack=(step % args.checkpoint_interval == 0))

        # Treat a NaN or very large loss as divergence; the raise is caught by
        # the handler at the bottom, which stops training.
        if loss > 100 or math.isnan(loss):
          log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True)
          raise Exception('Loss Exploded')

        if step % args.summary_interval == 0:
          log('Writing summary at step: %d' % step)
          summary_writer.add_summary(sess.run(stats), step)

        if step % args.checkpoint_interval == 0:
          log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
          saver.save(sess, checkpoint_path, global_step=step)
          log('Saving audio and alignment...')
          # NOTE(review): this is a separate sess.run, so it presumably pulls a
          # new batch from the feeder rather than the one just trained on — confirm.
          input_seq, spectrogram, alignment = sess.run([
            model.inputs[0], model.linear_outputs[0], model.alignments[0]])
          # Transposed before inversion; presumably (frames, bins) -> (bins, frames) — confirm.
          waveform = audio.inv_spectrogram(spectrogram.T)
          audio.save_wav(waveform, os.path.join(log_dir, 'step-%d-audio.wav' % step))
          plot.plot_alignment(alignment, os.path.join(log_dir, 'step-%d-align.png' % step),
            info='%s, %s, %s, step=%d, loss=%.5f' % (args.model, commit, time_string(), step, loss))
          log('Input: %s' % sequence_to_text(input_seq))

    except Exception as e:
      # Top-level boundary: report, then ask every coordinated thread to stop.
      log('Exiting due to exception: %s' % e, slack=True)
      traceback.print_exc()
      coord.request_stop(e)
def train(log_dir, args):
    """Build the training graph and run the training loop until stopped.

    Variant that reports through ``tf.logging.info`` (no Slack, no git commit).

    Args:
        log_dir: Directory that receives checkpoints (``model.ckpt-<step>``),
            summaries, and per-checkpoint ``step-<N>-audio.wav`` /
            ``step-<N>-align.png`` samples.
        args: Parsed command-line arguments. Attributes read here:
            ``base_dir``, ``input``, ``model``, ``restore_step``,
            ``summary_interval``, ``checkpoint_interval``.

    Any exception raised inside the session (including the deliberate
    'Loss Exploded' one) is logged, its traceback printed and handed to
    ``coord.request_stop``; it is not re-raised.
    """
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    input_path = os.path.join(args.base_dir, args.input)
    tf.logging.info('Checkpoint path: %s' % checkpoint_path)
    tf.logging.info('Loading training data from: %s' % input_path)
    tf.logging.info('Using model: %s' % args.model)
    tf.logging.info(hparams_debug_string())

    # Set up DataFeeder:
    # `coord` is shared with the feeder and is what ends the training loop below.
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = DataFeeder(coord, input_path, hparams)

    # Set up model:
    # Non-trainable step counter handed to add_optimizer, which presumably
    # increments it on every optimize run — confirm in the model code.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model(args.model, hparams)
        model.initialize(feeder.inputs, feeder.input_lengths,
                         feeder.mel_targets, feeder.linear_targets)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    # Bookkeeping:
    step = 0
    # ValueWindow(100) presumably keeps the last 100 values; their `.average`
    # is what the per-step log line reports.
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    # Keep the 5 most recent checkpoints, plus one every 2 hours.
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

    # Train!
    with tf.Session() as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            # Initialize first; a restore below then overwrites the fresh values.
            sess.run(tf.global_variables_initializer())

            if args.restore_step:
                # Restore from a checkpoint if the user requested it.
                # Matches the '<checkpoint_path>-<step>' names written by saver.save below.
                restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
                saver.restore(sess, restore_path)
                tf.logging.info('Resuming from checkpoint: %s' % restore_path)
            else:
                tf.logging.info('Starting new training run')

            # Start the input pipeline only after variables are initialized/restored.
            feeder.start_in_session(sess)

            while not coord.should_stop():
                start_time = time.time()
                # One optimization step; the `opt` result itself is unused.
                step, loss, opt = sess.run(
                    [global_step, model.loss, model.optimize])
                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                    step, time_window.average, loss, loss_window.average)
                tf.logging.info(message)

                # Treat a NaN or very large loss as divergence; the raise is
                # caught by the handler at the bottom, which stops training.
                if loss > 100 or math.isnan(loss):
                    tf.logging.info('Loss exploded to %.05f at step %d!' %
                                    (loss, step))
                    raise Exception('Loss Exploded')

                if step % args.summary_interval == 0:
                    tf.logging.info('Writing summary at step: %d' % step)
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.checkpoint_interval == 0:
                    tf.logging.info('Saving checkpoint to: %s-%d' %
                                    (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)
                    tf.logging.info('Saving audio and alignment...')
                    # NOTE(review): a separate sess.run, so it presumably pulls a
                    # new batch rather than the one just trained on — confirm.
                    input_seq, spectrogram, alignment = sess.run([
                        model.inputs[0], model.linear_outputs[0],
                        model.alignments[0]
                    ])
                    # Transposed before inversion; presumably
                    # (frames, bins) -> (bins, frames) — confirm.
                    waveform = audio.inv_spectrogram(spectrogram.T)
                    # Teacher forcing results.
                    audio.save_wav(
                        waveform,
                        os.path.join(log_dir, 'step-%d-audio.wav' % step))
                    plot.plot_alignment(
                        alignment,
                        os.path.join(log_dir, 'step-%d-align.png' % step),
                        info='%s, %s, step=%d, loss=%.5f' %
                        (args.model, time_string(), step, loss))
                    tf.logging.info('Input: %s' % sequence_to_text(input_seq))

        except Exception as e:
            # Top-level boundary: report, then ask every coordinated thread to stop.
            tf.logging.info('Exiting due to exception: %s' % e)
            traceback.print_exc()
            coord.request_stop(e)
Esempio n. 3
0
    def log_validation(self, reduced_loss, model, y, y_pred, iteration, x):
        """Log validation loss, parameter histograms and one random sample.

        Args:
            reduced_loss: Validation loss; logged as the ``validation.loss``
                scalar and, converted with ``float()``, in the text summary.
            model: Module whose ``named_parameters()`` are logged as
                histograms (dots in parameter names become ``/`` in tags).
            y: ``(mel_targets, gate_targets)`` pair.
            y_pred: 4-tuple whose last three items are
                ``(mel_outputs, gate_outputs, alignments)``; the first is ignored.
            iteration: Step value attached to every logged entry.
            x: Batch tuple; ``x[0]`` holds the text/phoneme id sequences and
                ``x[5]`` the speaker ids.

        One batch index is picked at random. For it, the alignment, the
        predicted and target mels, the gate curves and audio reconstructed from
        both mels are logged, plus a JSON metadata summary wrapped in ``<pre>``.
        """
        self.add_scalar("validation.loss", reduced_loss, iteration)
        _, mel_outputs, gate_outputs, alignments = y_pred
        mel_targets, gate_targets = y
        text_inputs = x[0]
        speaker_ids = x[5]

        # Plot the distribution of every parameter.
        for tag, value in model.named_parameters():
            tag = tag.replace('.', '/')
            self.add_histogram(tag, value.data.cpu().numpy(), iteration)

        # Pick one example and copy each of its tensors to host memory once,
        # rather than re-copying them for every plot / statistic below.
        idx = random.randint(0, alignments.size(0) - 1)
        alignment = alignments[idx].data.cpu().numpy()
        mel_predicted = mel_outputs[idx].data.cpu().numpy()
        mel_target = mel_targets[idx].data.cpu().numpy()

        # Plot alignment, mel target and predicted, gate target and predicted.
        self.add_image(
            "alignment",
            plot_alignment_to_numpy(alignment.T),
            iteration, dataformats='HWC')
        self.add_image(
            "mel_predicted",
            plot_spectrogram_to_numpy(mel_predicted),
            iteration, dataformats='HWC')
        self.add_image(
            "gate",
            plot_gate_outputs_to_numpy(
                gate_targets[idx].data.cpu().numpy(),
                torch.sigmoid(gate_outputs[idx]).data.cpu().numpy()),
            iteration, dataformats='HWC')

        # Record the synthesized speech so it can be listened to.
        # NOTE(review): a *mel* output is fed to inv_linearspectrogram; confirm
        # that helper expects mel input despite its name.
        audio_predicted = inv_linearspectrogram(mel_predicted)
        self.add_audio(
            'audio_predicted',
            torch.from_numpy(audio_predicted),
            iteration, sample_rate=default_hparams.sample_rate
        )
        self.add_image(
            "mel_target",
            plot_spectrogram_to_numpy(mel_target),
            iteration, dataformats='HWC')
        audio_target = inv_linearspectrogram(mel_target)
        self.add_audio(
            'audio_target',
            torch.from_numpy(audio_target),
            iteration, sample_rate=default_hparams.sample_rate
        )

        spk = int(speaker_ids[idx].data.cpu().numpy().flatten()[0])
        ph_ids = text_inputs[idx].data.cpu().numpy().flatten()
        phs_text = sequence_to_text(ph_ids)
        phs_size = len(ph_ids)
        reduced_loss = float(reduced_loss)
        # Audio length in milliseconds: sample count / samples-per-millisecond.
        samples_per_ms = default_hparams.sample_rate / 1000
        audt_duration = int(len(audio_target) / samples_per_ms)
        audp_duration = int(len(audio_predicted) / samples_per_ms)
        out_text = dict(speaker_id=spk, phonemes=phs_text, phonemes_size=phs_size, validation_loss=reduced_loss,
                        audio_target_ms=audt_duration, audio_predicted_ms=audp_duration,
                        spectrogram_target_shape=str(mel_target.shape),
                        spectrogram_predicted_shape=str(mel_predicted.shape),
                        alignment_shape=str(alignment.T.shape))
        out_text = json.dumps(out_text, indent=4, ensure_ascii=False)
        # HTML tags are supported in the text view, so <pre> keeps the JSON layout.
        out_text = f'<pre>{out_text}</pre>'
        self.add_text(
            'text',
            out_text,
            iteration
        )
Esempio n. 4
0
def train(log_dir, args):
    """Build the training graph, optionally resume, and run the training loop.

    Args:
        log_dir: Directory for checkpoints (``model.ckpt-<step>``), summaries,
            the exported ``tacotron_model.pbtxt`` / ``tacotron_model.pb`` graph
            files and the per-checkpoint ``step-<N>-audio.wav``,
            ``step-<N>-align.png`` and ``step-<N>-eval-mel-spectrogram.png``.
        args: Parsed command-line arguments. Attributes read here: ``git``,
            ``base_dir``, ``input``, ``model``, ``restore_step``,
            ``summary_interval``, ``checkpoint_interval``.

    When ``args.restore_step`` is truthy, the *latest* checkpoint recorded in
    ``log_dir`` is restored; the step number itself is not used. If no
    checkpoint is found, that is logged and training starts from freshly
    initialized variables. Exceptions raised inside the session are logged,
    their traceback printed and forwarded to ``coord.request_stop``; they are
    not re-raised.
    """
    # Identifier shown in logs and plot captions; the literal string 'None'
    # when args.git is falsy.
    commit = get_git_commit() if args.git else 'None'
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    input_path = os.path.join(args.base_dir, args.input)
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)
    log('Using model: %s' % args.model)
    log(hparams_debug_string())

    # Set up DataFeeder:
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = DataFeeder(coord, input_path, hparams)

    # Set up model:
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model(args.model, hparams)
        model.initialize(feeder.inputs, feeder.input_lengths,
                         feeder.mel_targets, feeder.linear_targets,
                         feeder.stop_token_targets, global_step)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    # Bookkeeping:
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    # Keep the 5 most recent checkpoints, plus one every 2 hours.
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

    # Train!
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    # Set to tf.RunOptions(report_tensor_allocations_upon_oom=True) to debug OOMs.
    run_options = None

    with tf.Session(config=config) as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            # Initialize first; a restore below then overwrites the fresh values.
            sess.run(tf.global_variables_initializer())

            if args.restore_step:
                # Resume from the latest checkpoint recorded in log_dir; the
                # value of args.restore_step only acts as an on/off switch.
                checkpoint_state = tf.train.get_checkpoint_state(log_dir)
                if checkpoint_state is not None:
                    saver.restore(sess, checkpoint_state.model_checkpoint_path)
                    log('Resuming from checkpoint: %s at commit: %s' %
                        (checkpoint_state.model_checkpoint_path, commit),
                        slack=True)
                else:
                    # Make it visible that a requested resume found nothing to
                    # restore and training is starting from scratch.
                    log('No checkpoint found in %s; starting new training run '
                        'at commit: %s' % (log_dir, commit),
                        slack=True)
            else:
                log('Starting new training run at commit: %s' % commit,
                    slack=True)

            # Export the graph definition in text and binary form.
            tf.train.write_graph(sess.graph.as_graph_def(),
                                 '.',
                                 os.path.join(log_dir, 'tacotron_model.pbtxt'),
                                 as_text=True)
            tf.train.write_graph(sess.graph.as_graph_def(),
                                 '.',
                                 os.path.join(log_dir, 'tacotron_model.pb'),
                                 as_text=False)

            feeder.start_in_session(sess)

            while not coord.should_stop():
                start_time = time.time()
                step, loss, opt = sess.run(
                    [global_step, model.loss, model.optimize],
                    options=run_options)
                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                    step, time_window.average, loss, loss_window.average)
                # Only the progress lines of checkpoint steps go to Slack.
                log(message, slack=(step % args.checkpoint_interval == 0))

                # Divergence guard; the raise is caught by the handler below.
                if loss > 100 or math.isnan(loss):
                    log('Loss exploded to %.05f at step %d!' % (loss, step),
                        slack=True)
                    raise Exception('Loss Exploded')

                if step % args.summary_interval == 0:
                    log('Writing summary at step: %d' % step)
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' %
                        (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)
                    log('Saving audio and alignment...')
                    input_seq, spectrogram, mel_outputs, mel_t, alignment = sess.run(
                        [
                            model.inputs[0], model.linear_outputs[0],
                            model.mel_outputs[0], model.mel_targets[0],
                            model.alignments[0]
                        ])
                    waveform = audio.inv_spectrogram(spectrogram.T)
                    audio.save_wav(
                        waveform,
                        os.path.join(log_dir, 'step-%d-audio.wav' % step))
                    plot.plot_alignment(
                        alignment,
                        os.path.join(log_dir, 'step-%d-align.png' % step),
                        info='%s, %s, %s, step=%d, loss=%.5f' %
                        (args.model, commit, time_string(), step, loss))
                    plot.plot_spectrogram(
                        mel_outputs,
                        os.path.join(
                            log_dir,
                            'step-{}-eval-mel-spectrogram.png'.format(step)),
                        title='{}, {}, step={}, loss={:.5f}'.format(
                            args.model, time_string(), step, loss),
                        target_spectrogram=mel_t)

                    log('Input: %s' % sequence_to_text(input_seq))

        except Exception as e:
            # Top-level boundary: report, then ask every coordinated thread to stop.
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)
Esempio n. 5
0
def train(log_dir, args):
    """Build a data-parallel multi-GPU training graph and run the training loop.

    Each DataFeeder batch tensor is split evenly along axis 0 across the GPUs
    listed in ``args.GPUs_id``. Every GPU builds its own model tower, sharing
    variables via ``reuse_variables``. Tower gradients are averaged with
    ``average_gradients`` and applied once with a single Adam optimizer.

    Args:
        log_dir: Directory for checkpoints (``model.ckpt-<step>``), summaries
            and the per-checkpoint ``step-<N>-audio.wav`` /
            ``step-<N>-align.png`` files.
        args: Parsed command-line arguments. Attributes read here: ``git``,
            ``base_dir``, ``input``, ``model``, ``datasets``,
            ``prenet_layer1``, ``prenet_layer2``, ``gru_size``,
            ``attention_size``, ``rnn_size``, ``enable_fv1``, ``enable_fv2``,
            ``batch_size``, ``GPUs_id``, ``weight_decay``, ``restore_step``,
            ``summary_interval``, ``checkpoint_interval``.

    Side effects: overwrites several attributes of the module-level
    ``hparams`` with values taken from ``args``. Exceptions raised inside the
    session are logged, their traceback printed and forwarded to
    ``coord.request_stop``; they are not re-raised.
    """
    commit = get_git_commit() if args.git else 'None'
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    input_path = os.path.join(args.base_dir, args.input)
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)
    log('Using model: %s' % args.model)
    log(hparams_debug_string())

    # graph
    with tf.Graph().as_default(), tf.device('/cpu:0'):

        # Copy architecture options from the command line into hparams.
        # NOTE(security): eval() runs arbitrary Python taken from the command
        # line; ast.literal_eval would be safer if only literals are expected.
        hparams.datasets = eval(args.datasets)
        hparams.prenet_layer1 = args.prenet_layer1
        hparams.prenet_layer2 = args.prenet_layer2
        hparams.gru_size = args.gru_size
        hparams.attention_size = args.attention_size
        hparams.rnn_size = args.rnn_size
        hparams.enable_fv1 = args.enable_fv1
        hparams.enable_fv2 = args.enable_fv2

        if args.batch_size:
            hparams.batch_size = args.batch_size

        # Multi-GPU settings (same eval() caveat as above).
        GPUs_id = eval(args.GPUs_id)
        hparams.num_GPU = len(GPUs_id)
        tower_grads = []
        tower_loss = []
        models = []

        # NOTE(review): starts at -1 rather than 0 — confirm the intent.
        global_step = tf.Variable(-1, name='global_step', trainable=False)
        if hparams.decay_learning_rate:
            learning_rate = _learning_rate_decay(hparams.initial_learning_rate,
                                                 global_step, hparams.num_GPU)
        else:
            learning_rate = tf.convert_to_tensor(hparams.initial_learning_rate)
        # Set up DataFeeder:
        coord = tf.train.Coordinator()
        with tf.variable_scope('datafeeder') as scope:
            feeder = DataFeeder(coord, input_path, hparams)
            # One equal slice per GPU along axis 0 (tf.split requires the batch
            # size to be divisible by num_GPU).
            inputs = tf.split(feeder.inputs, hparams.num_GPU, 0)
            input_lengths = tf.split(feeder.input_lengths, hparams.num_GPU, 0)
            mel_targets = tf.split(feeder.mel_targets, hparams.num_GPU, 0)
            linear_targets = tf.split(feeder.linear_targets, hparams.num_GPU, 0)

        # Set up model:
        with tf.variable_scope('model') as scope:
            optimizer = tf.train.AdamOptimizer(learning_rate,
                                               hparams.adam_beta1,
                                               hparams.adam_beta2)
            for i, GPU_id in enumerate(GPUs_id):
                with tf.device('/gpu:%d' % GPU_id):
                    with tf.name_scope('GPU_%d' % GPU_id):

                        if hparams.enable_fv1 or hparams.enable_fv2:
                            net = ResCNN(data=mel_targets[i],
                                         batch_size=hparams.batch_size,
                                         hyparam=hparams)
                            net.inference()

                            voice_print_feature = tf.reduce_mean(
                                net.features, 0)
                        else:
                            voice_print_feature = None

                        models.append(create_model(args.model, hparams))
                        models[i].initialize(
                            inputs=inputs[i],
                            input_lengths=input_lengths[i],
                            mel_targets=mel_targets[i],
                            linear_targets=linear_targets[i],
                            voice_print_feature=voice_print_feature)
                        models[i].add_loss()
                        # L2 weight decay over every trainable variable.
                        if args.weight_decay > 0:
                            costs = [tf.nn.l2_loss(var)
                                     for var in tf.trainable_variables()]
                            weight_decay = tf.cast(args.weight_decay,
                                                   tf.float32)
                            cost = models[i].loss
                            models[i].loss += tf.multiply(
                                weight_decay, tf.add_n(costs))
                            cost_pure_wd = tf.multiply(weight_decay,
                                                       tf.add_n(costs))
                        else:
                            cost = models[i].loss
                            # Scalar float so it formats like the other losses
                            # with '%.05f' in the per-step log message.
                            cost_pure_wd = tf.constant(0.0)

                        tower_loss.append(models[i].loss)

                        # Later towers reuse the variables created by the first.
                        tf.get_variable_scope().reuse_variables()
                        models[i].add_optimizer(global_step, optimizer)

                        tower_grads.append(models[i].gradients)

            # calculate average gradient
            gradients = average_gradients(tower_grads)

            stats = add_stats(models[0], gradients, learning_rate)
            # NOTE(review): unexplained 10 s pause during graph construction;
            # kept as-is — confirm whether it is still needed.
            time.sleep(10)

        # apply average gradient
        with tf.control_dependencies(tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)):
            apply_gradient_op = optimizer.apply_gradients(
                gradients, global_step=global_step)

        # Bookkeeping:
        step = 0
        time_window = ValueWindow(100)
        loss_window = ValueWindow(100)
        # Keep the 5 most recent checkpoints, plus one every 2 hours.
        saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

        # Train!
        config = tf.ConfigProto(log_device_placement=False,
                                allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            try:
                summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
                sess.run(tf.global_variables_initializer())

                if args.restore_step:
                    # Restore from a checkpoint if the user requested it.
                    restore_path = '%s-%d' % (checkpoint_path,
                                              args.restore_step)
                    saver.restore(sess, restore_path)
                    log('Resuming from checkpoint: %s at commit: %s' %
                        (restore_path, commit),
                        slack=True)
                else:
                    log('Starting new training run at commit: %s' % commit,
                        slack=True)

                feeder.start_in_session(sess)

                # Towers share variables; tower 0 is used for sampling/logging.
                model = models[0]

                while not coord.should_stop():
                    start_time = time.time()

                    # NOTE(review): `cost` is the pre-weight-decay loss of the
                    # *last* tower built above, while `model.loss` is tower 0's
                    # loss including weight decay — confirm this mix is intended.
                    step, loss, opt, loss_wd, loss_pure_wd = sess.run([
                        global_step, cost, apply_gradient_op, model.loss,
                        cost_pure_wd
                    ])
                    # NOTE(review): adjusts a private DataFeeder counter.
                    feeder._batch_in_queue -= 1
                    # Per-step debug output stays in the log; posting it to
                    # Slack on every single step floods the channel.
                    log('feed._batch_in_queue: %s' %
                        str(feeder._batch_in_queue),
                        slack=False)

                    time_window.append(time.time() - start_time)
                    loss_window.append(loss)
                    message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, loss_wd=%.05f, loss_pure_wd=%.05f]' % (
                        step, time_window.average, loss, loss_window.average,
                        loss_wd, loss_pure_wd)
                    log(message, slack=(step % args.checkpoint_interval == 0))

                    # On a loss spike (or NaN), restore the checkpoint at the
                    # last checkpoint-interval multiple at or before step - 10
                    # and keep training.
                    # NOTE(review): the offending value has already been
                    # appended to loss_window, so a NaN presumably keeps
                    # avg_loss NaN (disabling the spike test) until it leaves
                    # the window.
                    if loss > 2 * loss_window.average or math.isnan(loss):
                        log('recover to the previous checkpoint')
                        restore_step = int(
                            (step - 10) / args.checkpoint_interval
                        ) * args.checkpoint_interval
                        restore_path = '%s-%d' % (checkpoint_path,
                                                  restore_step)
                        saver.restore(sess, restore_path)
                        continue

                    # NaN never reaches here (handled by the branch above).
                    if loss > 100 or math.isnan(loss):
                        log('Loss exploded to %.05f at step %d!' %
                            (loss, step),
                            slack=True)
                        raise Exception('Loss Exploded')

                    # Summaries are best-effort, but failures are now reported
                    # instead of being silently discarded.
                    try:
                        if step % args.summary_interval == 0:
                            log('Writing summary at step: %d' % step)
                            summary_writer.add_summary(sess.run(stats), step)
                    except Exception as e:
                        log('Failed to write summary at step %d: %s' %
                            (step, e))

                    if step % args.checkpoint_interval == 0:
                        log('Saving checkpoint to: %s-%d' %
                            (checkpoint_path, step))
                        saver.save(sess, checkpoint_path, global_step=step)
                        log('Saving audio and alignment...')
                        input_seq, spectrogram, alignment = sess.run([
                            model.inputs[0], model.linear_outputs[0],
                            model.alignments[0]
                        ])
                        waveform = audio.inv_spectrogram(spectrogram.T)
                        audio.save_wav(
                            waveform,
                            os.path.join(log_dir, 'step-%d-audio.wav' % step))
                        plot.plot_alignment(
                            alignment,
                            os.path.join(log_dir, 'step-%d-align.png' % step),
                            info='%s, %s, %s, step=%d, loss=%.5f' %
                            (args.model, commit, time_string(), step, loss))
                        log('Input: %s' % sequence_to_text(input_seq))

            except Exception as e:
                # Top-level boundary: report, then ask every coordinated
                # thread to stop.
                log('Exiting due to exception: %s' % e, slack=True)
                traceback.print_exc()
                coord.request_stop(e)
Esempio n. 6
0
    collate_fn = TextMelCollate(config.n_frames_per_step)

    # Build both splits and report how many items each contains.
    train_dataset = TextMelLoader(config.training_files, config)
    print('len(train_dataset):', len(train_dataset))

    valid_dataset = TextMelLoader(config.validation_files, config)
    print('len(valid_dataset):', len(valid_dataset))

    text_lengths = []
    mel_lengths = []

    # Peek at the first validation item to show which container holds the mel.
    text, mel = valid_dataset[0]
    print('type(mel):', type(mel))

    # Walk the validation split: decode each id sequence back to text, convert
    # the mel to numpy, and record the text length and the mel's element count.
    for text, mel in valid_dataset:
        text = ''.join(sequence_to_text(text.numpy().tolist()))
        mel = mel.numpy()

        print('text:', text)
        print('mel.size:', mel.size)
        text_lengths.append(len(text))
        mel_lengths.append(mel.size)

    print('np.mean(text_lengths):', np.mean(text_lengths))
    print('np.mean(mel_lengths):', np.mean(mel_lengths))
Esempio n. 7
0
def train(log_dir, args):
    """Build the training graph and run the training loop, syncing with Drive.

    When ``args.pid`` is set, checkpoint files are downloaded from Drive
    before training. At every checkpoint step, the checkpoint, log, event and
    sample files are uploaded back to Drive.

    Args:
        log_dir: Directory for checkpoints (``model.ckpt-<step>``), summaries,
            ``train.log`` and per-checkpoint ``step-<N>-audio.wav`` /
            ``step-<N>-align.png`` samples.
        args: Parsed command-line arguments. Attributes read here: ``git``,
            ``base_dir``, ``input``, ``pid``, ``model``, ``restore_step``,
            ``summary_interval``, ``checkpoint_interval``.

    Exceptions raised inside the session are logged, their traceback printed
    and forwarded to ``coord.request_stop``; they are not re-raised. Failures
    while saving samples or uploading are reported but do not stop training.
    """
    # Identifier shown in logs and plot captions; the literal string 'None'
    # when args.git is falsy.
    commit = get_git_commit() if args.git else 'None'
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    input_path = os.path.join(args.base_dir, args.input)
    # Drive parent id; falsy disables both the download and the uploads.
    parent_id = args.pid
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)
    log('Using model: %s' % args.model)
    log(hparams_debug_string())

    if parent_id:
        log('Downloading model files from drive')
        download_checkpoints(parent_id)
    # Set up DataFeeder:
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = DataFeeder(coord, input_path, hparams)

    # Set up model:
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model(args.model, hparams)
        model.initialize(feeder.inputs, feeder.input_lengths,
                         feeder.mel_targets, feeder.linear_targets)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    # Bookkeeping:
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    # Keep the 5 most recent checkpoints, plus one every hour.
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)

    # Train!
    with tf.Session() as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            # Initialize first; a restore below then overwrites the fresh values.
            sess.run(tf.global_variables_initializer())

            if args.restore_step:
                # Restore from a checkpoint if the user requested it.
                restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s at commit: %s' %
                    (restore_path, commit),
                    slack=True)
            else:
                log('Starting new training run at commit: %s' % commit,
                    slack=True)

            feeder.start_in_session(sess)

            while not coord.should_stop():
                start_time = time.time()
                step, loss, opt = sess.run(
                    [global_step, model.loss, model.optimize])
                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                # Progress line prefixed with the wall-clock time.
                message = '%s |Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                    time.asctime(), step, time_window.average, loss,
                    loss_window.average)
                log(message, slack=(step % args.checkpoint_interval == 0))

                # Divergence guard; the raise is caught by the outer handler.
                if loss > 100 or math.isnan(loss):
                    log('Loss exploded to %.05f at step %d!' % (loss, step),
                        slack=True)
                    raise Exception('Loss Exploded')

                if step % args.summary_interval == 0:
                    log('Writing summary at step: %d' % step)
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.checkpoint_interval == 0:
                    list_files = [
                        os.path.join(log_dir, 'checkpoint'),
                        os.path.join(log_dir, 'train.log')
                    ]  #files to be uploaded to drive
                    log('Saving checkpoint to: %s-%d' %
                        (checkpoint_path, step))
                    # saver.save returns the checkpoint prefix; upload every
                    # file sharing it, plus all summary event files.
                    prefix = saver.save(sess,
                                        checkpoint_path,
                                        global_step=step)
                    list_files.extend(glob.glob(prefix + '.*'))
                    list_files.extend(
                        glob.glob(os.path.join(log_dir, 'events.*')))
                    # Best-effort sampling: a failure here is reported but
                    # does not stop training or the upload below.
                    try:
                        log('Saving audio and alignment...')
                        input_seq, spectrogram, alignment = sess.run([
                            model.inputs[0], model.linear_outputs[0],
                            model.alignments[0]
                        ])
                        waveform = audio.inv_spectrogram(spectrogram.T)
                        # Plot caption: input text plus run info, wrapped at 70 columns.
                        info = '\n'.join(
                            textwrap.wrap(
                                '%s, %s, %s, %s, step=%d, loss=%.5f' %
                                (sequence_to_text(input_seq), args.model,
                                 commit, time_string(), step, loss),
                                70,
                                break_long_words=False))
                        audio.save_wav(
                            waveform,
                            os.path.join(log_dir, 'step-%d-audio.wav' % step))
                        plot.plot_alignment(alignment,
                                            os.path.join(
                                                log_dir,
                                                'step-%d-align.png' % step),
                                            info=info)
                        log('Input: %s' % sequence_to_text(input_seq))

                        # Only queue the samples for upload once both were written.
                        list_files.append(
                            os.path.join(log_dir, 'step-%d-audio.wav' % step))
                        list_files.append(
                            os.path.join(log_dir, 'step-%d-align.png' % step))
                    except Exception as e:
                        log(str(e))
                        print(e)
                    if parent_id:
                        try:
                            upload_to_drive(list_files, parent_id)
                        except Exception as e:
                            # Upload failures are appended to drive_log.txt,
                            # a path relative to the current working directory.
                            print(e)
                            with open('drive_log.txt', 'a') as ferr:
                                ferr.write('\n\n\n' + time.asctime())
                                ferr.write('\n' + ', '.join(list_files))
                                ferr.write('\n' + str(e))

        except Exception as e:
            # Top-level boundary: report, then ask every coordinated thread to stop.
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)
Esempio n. 8
0
def _restore_pretrained_decoder(sess, pretrain_dir='pretrain'):
    """Restore decoder-side weights into ``sess`` from a pretrained checkpoint.

    The checkpoint prefix is taken from the first ``*.meta`` file (in sorted
    order) found in ``pretrain_dir``. Only global variables whose names start
    with one of the decoder-side scopes below are restored. Variables whose
    names contain 'attention' or 'Attention' are skipped and keep whatever
    values they already hold in ``sess``.

    Raises:
      IOError: if ``pretrain_dir`` contains no ``*.meta`` file.
    """
    # Sorted so the chosen checkpoint does not depend on os.listdir() order.
    meta_files = sorted(f for f in os.listdir(pretrain_dir) if '.meta' in f)
    if not meta_files:
        raise IOError('No pretrained checkpoint (*.meta) found in %s' %
                      pretrain_dir)
    # Saver.restore expects the checkpoint prefix, not the .meta file name.
    decoder_ckpt_path = os.path.join(pretrain_dir,
                                     meta_files[0].replace('.meta', ''))

    decoder_scopes = ('model/inference/decoder', 'model/inference/post_cbhg',
                      'model/inference/dense', 'model/inference/memory_layer')
    var_list = [
        v for v in tf.global_variables()
        if 'attention' not in v.name and 'Attention' not in v.name
        and v.name.startswith(decoder_scopes)
    ]
    decoder_saver = tf.train.Saver(var_list)
    decoder_saver.restore(sess, decoder_ckpt_path)
    print('restore pretrained decoder ...')


def train(log_dir, args):
    """Train the model named by ``args.model`` to predict LPC features.

    Training data is read from
    ``<args.base_dir>/DATA/<dataset dir>/training/train.txt``. Checkpoints
    (``model.ckpt-<step>``), TensorBoard summaries, alignment plots and
    ``.npy`` dumps of the predicted LPC features are written to ``log_dir``.

    Args:
      log_dir: output directory for checkpoints, summaries and plots.
      args: parsed command-line namespace. Fields read here: ``git``,
        ``dataset`` ('bznsyp' or 'ljspeech'), ``base_dir``, ``model``,
        ``restore_step``, ``restore_decoder``, ``summary_interval`` and
        ``checkpoint_interval``.

    Raises:
      KeyError: if ``args.dataset`` is not a supported dataset name.

    Exceptions raised inside the session are logged, and the coordinator is
    told to stop; they are not re-raised. This includes a loss explosion.
    """
    commit = get_git_commit() if args.git else 'None'
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    dataset_dir = {'bznsyp': "BZNSYP", 'ljspeech': "LJSpeech-1.1"}[args.dataset]
    input_path = os.path.join(args.base_dir, 'DATA', dataset_dir, 'training',
                              'train.txt')
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)
    log('Using model: %s' % args.model)
    log(hparams_debug_string())

    # Set up DataFeeder:
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder'):
        feeder = DataFeeder(coord, input_path, hparams)

    # Set up model:
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.variable_scope('model'):
        model = create_model(args.model, hparams)
        model.initialize(feeder.inputs, feeder.input_lengths,
                         feeder.lpc_targets, feeder.stop_token_targets)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    # Bookkeeping:
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=999, keep_checkpoint_every_n_hours=2)

    # Train! Let TF grow GPU memory on demand instead of grabbing it all.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            if args.restore_step:
                # restore_step only acts as a flag. The most recent checkpoint
                # recorded in log_dir is restored, whatever step was requested.
                checkpoint_state = tf.train.get_checkpoint_state(log_dir)
                if checkpoint_state is not None:
                    saver.restore(sess, checkpoint_state.model_checkpoint_path)
                    log('Resuming from checkpoint: %s at commit: %s' %
                        (checkpoint_state.model_checkpoint_path, commit),
                        slack=True)
                else:
                    # Make it visible that a requested resume found nothing
                    # and training starts from freshly initialized weights.
                    log('No checkpoint found in %s; starting new training run '
                        'at commit: %s' % (log_dir, commit),
                        slack=True)
            else:
                log('Starting new training run at commit: %s' % commit,
                    slack=True)

            if args.restore_decoder:
                _restore_pretrained_decoder(sess)

            feeder.start_in_session(sess)

            while not coord.should_stop():
                start_time = time.time()
                step, loss, _ = sess.run(
                    [global_step, model.loss, model.optimize])
                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                    step, time_window.average, loss, loss_window.average)
                log(message, slack=(step % args.checkpoint_interval == 0))

                if loss > 100 or math.isnan(loss):
                    log('Loss exploded to %.05f at step %d!' % (loss, step),
                        slack=True)
                    raise Exception('Loss Exploded')

                if step % args.summary_interval == 0:
                    log('Writing summary at step: %d' % step)
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' %
                        (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)
                    log('Saving audio and alignment...')
                    # First example of the current batch.
                    input_seq, lpc_outputs, alignment = sess.run([
                        model.inputs[0], model.lpc_outputs[0],
                        model.alignments[0]
                    ])
                    plot.plot_alignment(
                        alignment,
                        os.path.join(log_dir, 'step-%d-align.png' % step),
                        info='%s, %s, %s, step=%d, loss=%.5f' %
                        (args.model, commit, time_string(), step, loss))
                    np.save(os.path.join(log_dir, 'step-%d-lpc.npy' % step),
                            lpc_outputs)
                    log('Input: %s' % sequence_to_text(input_seq))

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)
Esempio n. 9
0
def train(model,
          optimizer,
          scheduler,
          dataset,
          num_epochs,
          batch_size=1,
          save_interval=50,
          exp_name='melnet',
          device=1,
          step=0):
    """Teacher-forced training loop with a spectrogram loss plus a stop-token loss.

    Args:
      model: module called as ``model(text, teacher)`` and returning
        ``(outputs, stop_tokens, attention)``.
      optimizer: optimizer over the model parameters.
      scheduler: learning-rate scheduler exposing ``step()`` and an ``lr``
        attribute (it is logged every step).
      dataset: dataset whose batches are built by ``collate_fn``.
      num_epochs: number of passes over ``dataset``.
      batch_size: batch size handed to ``RandomBatchSampler``.
      save_interval: every ``save_interval`` steps (including step 0), save
        ``checkpoints/{exp_name}_{step}.pt`` and log images and parameter
        histograms to TensorBoard.
      exp_name: run name, used for ``runs/{exp_name}`` and checkpoint files.
      device: CUDA device index that inputs and targets are moved to.
      step: initial global step, e.g. when resuming.
    """
    model.train()
    writer = SummaryWriter(f'runs/{exp_name}')
    sampler = SequentialSampler(dataset)
    batch_sampler = RandomBatchSampler(sampler, batch_size)
    loader = DataLoader(dataset,
                        batch_sampler=batch_sampler,
                        collate_fn=collate_fn,
                        pin_memory=True,
                        num_workers=6)
    tacoteacher = TacoTeacher()
    try:
        for _ in tqdm(range(num_epochs), total=num_epochs, unit=' epochs'):
            pbar = tqdm(loader, total=len(loader), unit=' batches')
            for b, (text_batch, audio_batch, text_lengths,
                    audio_lengths) in enumerate(pbar):

                # update loop
                text = Variable(text_batch).cuda(device)
                targets = Variable(audio_batch,
                                   requires_grad=False).cuda(device)
                stop_targets = make_stop_targets(targets, audio_lengths)
                tacoteacher.set_targets(targets)
                outputs, stop_tokens, attention = model(text, tacoteacher)
                spec_loss = F.mse_loss(outputs, targets)
                stop_loss = F.binary_cross_entropy_with_logits(
                    stop_tokens, stop_targets)
                loss = spec_loss + stop_loss
                optimizer.zero_grad()
                loss.backward()
                # clip_grad_norm(model.parameters(), hp.max_grad_norm, norm_type=2)  # prevent exploding grads
                # NOTE(review): the scheduler is stepped before the optimizer.
                # Presumably this custom scheduler sets the lr for the current
                # step; confirm before reordering.
                scheduler.step()
                optimizer.step()

                # logging
                # Tensor.item() is required for 0-dim losses (PyTorch >= 0.4,
                # where ``.data[0]`` raises IndexError). ``.data[0]`` remains
                # as a fallback for older releases without ``item``.
                loss_value = (loss.item() if hasattr(loss, 'item')
                              else loss.data[0])
                pbar.set_description(f'loss: {loss_value:.4f}')
                writer.add_scalar('loss', loss_value, step)
                writer.add_scalar('lr', scheduler.lr, step)
                if step % save_interval == 0:
                    torch.save(model.state_dict(),
                               f'checkpoints/{exp_name}_{str(step)}.pt')

                    # plot the first sample in the batch
                    first_text = sequence_to_text(text.data[0])
                    attention_plot = show_attention(attention[0],
                                                    return_array=True)
                    output_plot = show_spectrogram(
                        outputs.data.permute(1, 2, 0)[0],
                        first_text,
                        return_array=True)
                    target_plot = show_spectrogram(
                        targets.data.permute(1, 2, 0)[0],
                        first_text,
                        return_array=True)
                    writer.add_image('attention', attention_plot, step)
                    writer.add_image('output', output_plot, step)
                    writer.add_image('target', target_plot, step)
                    for name, param in model.named_parameters():
                        writer.add_histogram(name,
                                             param.clone().cpu().data.numpy(),
                                             step,
                                             bins='doane')
                step += 1
    finally:
        # Flush pending events and release the event file even on errors.
        writer.close()
Esempio n. 10
0
def train(log_dir, args):
    """Adversarially train ``args.model`` on paired positive/negative data.

    Each iteration runs one discriminator update (``model.d_optimize``)
    followed by one generator update (``model.g_optimize``). The step time and
    the reconstruction, style, discriminator and generator losses are logged.
    Every ``args.checkpoint_interval`` steps a checkpoint is saved. Waveforms
    and alignment plots for the first positive and negative example are also
    written to ``log_dir`` at that point.

    Args:
      log_dir: output directory for checkpoints, audio and alignment plots.
      args: parsed command-line namespace. Fields read here: ``git``,
        ``base_dir``, ``input_pos``, ``input_neg``, ``model``,
        ``restore_step`` and ``checkpoint_interval``.

    Exceptions raised inside the session are logged, and the coordinator is
    told to stop; they are not re-raised.
    """
    commit = get_git_commit() if args.git else 'None'
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    # Training data comes from two metadata files: positive and negative examples.
    input_path_pos = os.path.join(args.base_dir, args.input_pos)
    input_path_neg = os.path.join(args.base_dir, args.input_neg)

    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading positive training data from: %s' % input_path_pos)
    log('Loading negative training data from: %s' % input_path_neg)
    log('Using model: %s' % args.model)
    log(hparams_debug_string())

    # Set up DataFeeder:
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder'):
        feeder = DataFeeder(coord, input_path_pos, input_path_neg, hparams)

    # Set up model:
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.variable_scope('model'):
        model = create_model(args.model, hparams)
        model.initialize(feeder.inputs_pos, feeder.input_lengths_pos,
                         feeder.mel_targets_pos, feeder.linear_targets_pos,
                         feeder.mel_targets_neg, feeder.linear_targets_neg,
                         feeder.labels_pos, feeder.labels_neg)
        model.add_loss()
        model.add_optimizer(global_step)

    # Bookkeeping:
    step = 0
    time_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

    # Train!
    with tf.Session() as sess:
        try:
            # NOTE: this variant writes no TensorBoard summaries.
            sess.run(tf.global_variables_initializer())

            if args.restore_step:
                # Restore from a checkpoint if the user requested it.
                restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s at commit: %s' %
                    (restore_path, commit),
                    slack=True)
            else:
                log('Starting new training run at commit: %s' % commit,
                    slack=True)

            feeder.start_in_session(sess)

            while not coord.should_stop():
                start_time = time.time()
                # train d
                sess.run(model.d_optimize)
                # train g (the losses are fetched in the same run)
                step, rec_loss, style_loss, d_loss, g_loss, _ = sess.run([
                    global_step, model.rec_loss, model.style_loss,
                    model.d_loss, model.g_loss, model.g_optimize
                ])
                # The timing covers both the D and the G update.
                time_window.append(time.time() - start_time)
                message = 'Step %-7d [%.03f sec/step, rec_loss=%.05f, style_loss=%.05f, d_loss=%.05f, g_loss=%.05f]' % (
                    step, time_window.average, rec_loss, style_loss, d_loss,
                    g_loss)
                log(message, slack=(step % args.checkpoint_interval == 0))

                if step % args.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' %
                        (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)
                    log('Saving audio and alignment...')
                    # First example of the current batch, for each branch.
                    input_seq, spectrogram_pos, spectrogram_neg, alignment_pos, alignment_neg = sess.run(
                        [
                            model.inputs[0], model.linear_outputs_pos[0],
                            model.linear_outputs_neg[0],
                            model.alignments_pos[0], model.alignments_neg[0]
                        ])

                    waveform_pos = audio.inv_spectrogram(spectrogram_pos.T)
                    waveform_neg = audio.inv_spectrogram(spectrogram_neg.T)
                    audio.save_wav(
                        waveform_pos,
                        os.path.join(log_dir, 'step-%d-audio_pos.wav' % step))
                    audio.save_wav(
                        waveform_neg,
                        os.path.join(log_dir, 'step-%d-audio_neg.wav' % step))
                    # Both plots are annotated with the reconstruction loss.
                    info = '%s, %s, %s, step=%d, loss=%.5f' % (
                        args.model, commit, time_string(), step, rec_loss)
                    plot.plot_alignment(
                        alignment_pos,
                        os.path.join(log_dir, 'step-%d-align_pos.png' % step),
                        info=info)
                    plot.plot_alignment(
                        alignment_neg,
                        os.path.join(log_dir, 'step-%d-align_neg.png' % step),
                        info=info)
                    log('Input: %s' % sequence_to_text(input_seq))

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)