Code example #1
File: train_evaluate.py, Project: yanivbl6/advoc
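All four examples are TensorFlow 1.x graph-mode code. A plausible import block for this file, hedged because the project-local module paths are assumptions (only the names Advoc, AdvocSmall, Modes, decode_extract_and_batch, override_model_attrs, and SpectralUtil are taken from the code itself):

import os
import time

import numpy as np
import tensorflow as tf

# Assumed module paths within the advoc project; only the imported names
# are confirmed by the code below.
from model import Advoc, AdvocSmall, Modes
from loader import decode_extract_and_batch
from util import override_model_attrs
from spectral_util import SpectralUtil

EPS = 1e-12  # assumed value for the constant used in example #4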
def train(fps, args):
    # Initialize model
    if args.model_type == "regular":
        model = Advoc(Modes.TRAIN)
    elif args.model_type == "small":
        model = AdvocSmall(Modes.TRAIN)
    else:
        raise NotImplementedError()

    model, summary = override_model_attrs(model, args.model_overrides)
    model.audio_fs = args.data_sample_rate

    print('-' * 80)
    print(summary)
    print('-' * 80)

    # Load data
    with tf.name_scope('loader'):
        x_magspec, x_wav = decode_extract_and_batch(
            fps,
            batch_size=model.train_batch_size,
            slice_len=model.subseq_len,
            audio_fs=model.audio_fs,
            audio_mono=True,
            audio_normalize=args.data_normalize,
            decode_fastwav=args.data_fastwav,
            decode_parallel_calls=4,
            extract_type='magspec',
            extract_parallel_calls=8,
            repeat=True,
            shuffle=True,
            shuffle_buffer_size=512,
            slice_first_only=args.data_slice_first_only,
            slice_randomize_offset=args.data_slice_randomize_offset,
            slice_overlap_ratio=args.data_slice_overlap_ratio,
            slice_pad_end=args.data_slice_pad_end,
            prefetch_size=model.train_batch_size * 8,
            prefetch_gpu_num=0)

    # Create model
    spectral = SpectralUtil(n_mels=model.n_mels, fs=model.audio_fs)

    x_melspec = spectral.mag_to_mel_linear_spec(x_magspec)
    x_inverted_magspec = spectral.mel_linear_to_mag_spec(x_melspec,
                                                         transform='inverse')

    model(x_inverted_magspec, x_magspec, x_wav, x_melspec)

    # Train
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_ckpt_every_nsecs,
            save_summaries_secs=args.train_summary_every_nsecs) as sess:

        _step = 0
        while not sess.should_stop() and _step < args.max_steps:
            _step = model.train_loop(sess)

    print("Done!")
Code example #2
File: train_evaluate.py, Project: yanivbl6/advoc
def eval(fps, args):
    if args.eval_dataset_name is not None:
        eval_dir = os.path.join(args.train_dir,
                                'eval_{}'.format(args.eval_dataset_name))
    else:
        eval_dir = os.path.join(args.train_dir, 'eval_valid')
    if not os.path.isdir(eval_dir):
        os.makedirs(eval_dir)

    if args.model_type == "regular":
        model = Advoc(Modes.EVAL)
    elif args.model_type == "small":
        model = AdvocSmall(Modes.EVAL)
    else:
        raise NotImplementedError()

    model, summary = override_model_attrs(model, args.model_overrides)
    model.audio_fs = args.data_sample_rate

    print('-' * 80)
    print(summary)
    print('-' * 80)

    with tf.name_scope('loader'):
        x_magspec, x_wav = decode_extract_and_batch(
            fps,
            batch_size=model.eval_batch_size,
            slice_len=model.subseq_len,
            audio_fs=model.audio_fs,
            audio_mono=True,
            audio_normalize=args.data_normalize,
            decode_fastwav=args.data_fastwav,
            decode_parallel_calls=4,
            extract_type='magspec',
            extract_parallel_calls=8,
            repeat=False,
            shuffle=False,
            shuffle_buffer_size=None,
            slice_first_only=args.data_slice_first_only,
            slice_randomize_offset=False,
            slice_overlap_ratio=0.,
            slice_pad_end=True,
            prefetch_size=None,
            prefetch_gpu_num=None)

    spectral = SpectralUtil(n_mels=model.n_mels, fs=model.audio_fs)
    x_melspec = spectral.mag_to_mel_linear_spec(x_magspec)
    x_inverted_magspec = spectral.mel_linear_to_mag_spec(x_melspec,
                                                         transform='inverse')

    with tf.variable_scope("generator") as vs:
        if model.generator_type == "pix2pix":
            gen_magspec = model.build_generator(x_inverted_magspec)
        elif model.generator_type == "linear":
            gen_magspec = model.build_linear_generator(x_inverted_magspec)
        elif model.generator_type == "linear+pix2pix":
            _temp_spec = model.build_linear_generator(x_melspec)
            gen_magspec = model.build_linear_generator(_temp_spec)
        elif model.generator_type == "interp+pix2pix":
            _temp_spec = tf.image.resize_images(x_melspec,
                                                [model.subseq_len, 513])
            gen_magspec = model.build_linear_generator(_temp_spec)
        else:
            raise NotImplementedError()

        G_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                   scope=vs.name)

    gen_loss_L1 = tf.reduce_mean(tf.abs(x_magspec - gen_magspec))
    gan_step = tf.train.get_or_create_global_step()
    gan_saver = tf.train.Saver(var_list=G_vars + [gan_step], max_to_keep=1)

    all_gen_loss_L1 = tf.placeholder(tf.float32, [None])

    summaries = [
        tf.summary.scalar('gen_loss_L1', tf.reduce_mean(all_gen_loss_L1)),
    ]
    summaries = tf.summary.merge(summaries)

    # Create summary writer
    summary_writer = tf.summary.FileWriter(eval_dir)
    ckpt_fp = None
    _best_gen_loss_l1 = np.inf

    while True:
        latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir)
        if latest_ckpt_fp != ckpt_fp:
            ckpt_fp = latest_ckpt_fp
            print('Evaluating {}'.format(ckpt_fp))

            with tf.Session() as sess:
                gan_saver.restore(sess, latest_ckpt_fp)
                _step = sess.run(gan_step)
                _all_gen_loss_L1 = []

                while True:
                    try:
                        _gen_loss_L1, _gen_magspec, _x_magspec = sess.run(
                            [gen_loss_L1, gen_magspec, x_magspec])
                    except tf.errors.OutOfRangeError:
                        break

                    _all_gen_loss_L1.append(_gen_loss_L1)

                _all_gen_loss_L1 = np.array(_all_gen_loss_L1)

                _summaries = sess.run(summaries, {
                    all_gen_loss_L1: _all_gen_loss_L1,
                })
                summary_writer.add_summary(_summaries, _step)
                _gen_loss_L1_np = np.mean(_all_gen_loss_L1)

                if _gen_loss_L1_np < _best_gen_loss_l1:
                    # Track the best loss so far; without this update every
                    # checkpoint would overwrite the "best" one.
                    _best_gen_loss_l1 = _gen_loss_L1_np
                    gan_saver.save(sess,
                                   os.path.join(eval_dir, 'best_gen_loss_l1'),
                                   _step)
                    print("Saved best gen loss l1!")
            print('Done!')
        time.sleep(1)
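The eval loop above uses a placeholder-fed summary so that the scalar logged to TensorBoard is the mean over the whole validation set rather than a single batch. A minimal, self-contained sketch of that pattern:

import numpy as np
import tensorflow as tf

values_ph = tf.placeholder(tf.float32, [None])
mean_summary = tf.summary.scalar('mean_value', tf.reduce_mean(values_ph))

with tf.Session() as sess:
    # Stand-in for per-batch losses collected during one pass over the data.
    batch_values = np.array([0.5, 0.25, 0.75], dtype=np.float32)
    proto = sess.run(mean_summary, {values_ph: batch_values})
    # proto is a serialized Summary, ready for FileWriter.add_summary(proto, step)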
Code example #3
File: train_evaluate.py, Project: yanivbl6/advoc
def infer(fps, args):
    if args.infer_dataset_name is not None:
        infer_dir = os.path.join(args.train_dir,
                                 'infer_{}'.format(args.infer_dataset_name))
    else:
        infer_dir = os.path.join(args.train_dir, 'infer_valid')
    if not os.path.isdir(infer_dir):
        os.makedirs(infer_dir)

    if args.model_type == "regular":
        model = Advoc(Modes.INFER)
    elif args.model_type == "small":
        model = AdvocSmall(Modes.INFER)
    else:
        raise NotImplementedError()

    model, summary = override_model_attrs(model, args.model_overrides)
    model.audio_fs = args.data_sample_rate

    print('-' * 80)
    print(summary)
    print('-' * 80)

    with tf.name_scope('loader'):
        x_magspec, x_wav = decode_extract_and_batch(
            fps,
            batch_size=args.infer_batch_size,
            slice_len=model.subseq_len,
            audio_fs=model.audio_fs,
            audio_mono=True,
            audio_normalize=args.data_normalize,
            decode_fastwav=args.data_fastwav,
            decode_parallel_calls=4,
            extract_type='magspec',
            extract_parallel_calls=8,
            repeat=False,
            shuffle=False,
            shuffle_buffer_size=None,
            slice_randomize_offset=False,
            slice_overlap_ratio=0.,
            slice_pad_end=True,
            prefetch_size=None,
            prefetch_gpu_num=None)

    spectral = SpectralUtil(n_mels=model.n_mels, fs=model.audio_fs)
    x_melspec = spectral.mag_to_mel_linear_spec(x_magspec)
    x_inverted_magspec = spectral.mel_linear_to_mag_spec(x_melspec,
                                                         transform='inverse')

    with tf.variable_scope("generator") as vs:
        if model.generator_type == "pix2pix":
            gen_magspec = model.build_generator(x_inverted_magspec)
        elif model.generator_type == "linear":
            gen_magspec = model.build_linear_generator(x_inverted_magspec)
        elif model.generator_type == "linear+pix2pix":
            _temp_spec = model.build_linear_generator(x_melspec)
            gen_magspec = model.build_linear_generator(_temp_spec)
        elif model.generator_type == "interp+pix2pix":
            _temp_spec = tf.image.resize_images(x_melspec,
                                                [model.subseq_len, 513])
            gen_magspec = model.build_linear_generator(_temp_spec)
        else:
            raise NotImplementedError()
        G_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                   scope=vs.name)

    step = tf.train.get_or_create_global_step()
    gan_saver = tf.train.Saver(var_list=G_vars + [step], max_to_keep=1)

    input_audio = tf.py_func(spectral.audio_from_mag_spec,
                             [x_inverted_magspec[0]],
                             tf.float32,
                             stateful=False)
    target_audio = tf.py_func(spectral.audio_from_mag_spec, [x_magspec[0]],
                              tf.float32,
                              stateful=False)
    gen_audio = tf.py_func(spectral.audio_from_mag_spec, [gen_magspec[0]],
                           tf.float32,
                           stateful=False)

    # Reshape to [1, time, 1, 1] so the [:, :, 0, :] slices below yield the
    # [batch, frames, channels] layout that tf.summary.audio expects.
    input_audio = tf.reshape(input_audio, [1, -1, 1, 1])
    target_audio = tf.reshape(target_audio, [1, -1, 1, 1])
    gen_audio = tf.reshape(gen_audio, [1, -1, 1, 1])

    summaries = [
        tf.summary.audio('infer_x_wav', x_wav[:, :, 0, :], model.audio_fs),
        tf.summary.audio('infer_gen_audio', gen_audio[:, :, 0, :],
                         model.audio_fs),
        tf.summary.audio('target_audio', target_audio[:, :, 0, :],
                         model.audio_fs),
        tf.summary.audio('infer_input_audio', input_audio[:, :, 0, :],
                         model.audio_fs)
    ]

    summaries = tf.summary.merge(summaries)
    # Create summary writer (the saver was created above)
    summary_writer = tf.summary.FileWriter(infer_dir)

    if args.infer_ckpt_path is not None:
        # Inferring from a particular checkpoint
        ckpt_fp = args.infer_ckpt_path
        print('Inferring from {}'.format(ckpt_fp))

        with tf.Session() as sess:
            gan_saver.restore(sess, ckpt_fp)
            _step = sess.run(step)
            # Just one batch at a time
            while True:
                try:
                    _summaries, mel_np, est_np, act_np, gen_np = sess.run([
                        summaries, x_melspec, x_inverted_magspec, x_magspec,
                        gen_magspec
                    ])
                    summary_writer.add_summary(_summaries, _step)

                except tf.errors.OutOfRangeError:
                    break
            print('Done!')

    else:
        # Continuous Inference
        ckpt_fp = None
        while True:
            with tf.Session() as sess:
                latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir)
                if latest_ckpt_fp != ckpt_fp:
                    ckpt_fp = latest_ckpt_fp
                    print('Inferring from {}'.format(ckpt_fp))
                    gan_saver.restore(sess, ckpt_fp)
                    _step = sess.run(step)

                    while True:
                        try:
                            _summaries, mel_np, est_np, act_np, gen_np = sess.run(
                                [
                                    summaries, x_melspec, x_inverted_magspec,
                                    x_magspec, gen_magspec
                                ])
                            summary_writer.add_summary(_summaries, _step)

                        except tf.errors.OutOfRangeError:
                            break
                    print("Done!")
                time.sleep(1)

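The three tf.py_func calls above all follow the same pattern: wrap a NumPy-level reconstruction routine so it can run inside the graph, then restore the static shape information that tf.py_func discards. A minimal sketch of the pattern, with a placeholder function standing in for spectral.audio_from_mag_spec:

import numpy as np
import tensorflow as tf

def _to_audio(mag):
    # Hypothetical stand-in: any NumPy function mapping a magnitude
    # spectrogram to a float32 waveform fits this slot.
    return np.zeros((mag.shape[0] * 256,), dtype=np.float32)

mag = tf.placeholder(tf.float32, [None, 513])
audio = tf.py_func(_to_audio, [mag], tf.float32, stateful=False)
audio.set_shape([None])  # py_func loses static shapes; redeclare what is known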
Code example #4
    def __call__(self, x, target, x_wav, x_mel_spec):

        self.spectral = SpectralUtil(n_mels=self.n_mels, fs=self.audio_fs)

        try:
            # Prefer the static batch dimension when it is known at graph time.
            batch_size = int(x.get_shape()[0])
        except TypeError:
            # Static shape unknown (int(None) raises TypeError); fall back to
            # the dynamic shape tensor.
            batch_size = tf.shape(x)[0]

        with tf.variable_scope("generator"):
            if self.generator_type == "pix2pix":
                gen_mag_spec = self.build_generator(x)
            elif self.generator_type == "linear":
                gen_mag_spec = self.build_linear_generator(x)
            elif self.generator_type == "linear+pix2pix":
                temp_spec = self.build_linear_generator(x_mel_spec)
                gen_mag_spec = self.build_linear_generator(temp_spec)
            elif self.generator_type == "interp+pix2pix":
                _temp_spec = tf.image.resize_images(x_mel_spec,
                                                    [self.subseq_len, 513])
                gen_mag_spec = self.build_linear_generator(_temp_spec)
            else:
                raise NotImplementedError()

        with tf.name_scope("real_discriminator"):
            with tf.variable_scope("discriminator"):
                predict_real = self.build_discriminator(x, target)

        with tf.name_scope("fake_discriminator"):
            with tf.variable_scope("discriminator", reuse=True):
                predict_fake = self.build_discriminator(x, gen_mag_spec)

        discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) +
                                        tf.log(1 - predict_fake + EPS)))
        gen_loss_GAN = tf.reduce_mean(-tf.log(predict_fake + EPS))
        gen_loss_L1 = tf.reduce_mean(tf.abs(target - gen_mag_spec))

        if self.gan_weight > 0:
            gen_loss = gen_loss_GAN * self.gan_weight + gen_loss_L1 * self.l1_weight
        else:
            gen_loss = gen_loss_L1 * self.l1_weight

        self.D_vars = D_vars = [
            var for var in tf.trainable_variables()
            if var.name.startswith("discriminator")
        ]
        self.G_vars = G_vars = [
            var for var in tf.trainable_variables()
            if var.name.startswith("generator")
        ]

        # Adam with lr=2e-4, beta1=0.5, the standard pix2pix settings
        D_opt = tf.train.AdamOptimizer(0.0002, 0.5)
        G_opt = tf.train.AdamOptimizer(0.0002, 0.5)

        self.step = step = tf.train.get_or_create_global_step()
        self.G_train_op = G_opt.minimize(gen_loss,
                                         var_list=G_vars,
                                         global_step=self.step)

        self.D_train_op = D_opt.minimize(discrim_loss, var_list=D_vars)

        input_audio = tf.py_func(self.spectral.audio_from_mag_spec, [x[0]],
                                 tf.float32,
                                 stateful=False)
        target_audio = tf.py_func(self.spectral.audio_from_mag_spec,
                                  [target[0]],
                                  tf.float32,
                                  stateful=False)
        gen_audio = tf.py_func(self.spectral.audio_from_mag_spec,
                               [gen_mag_spec[0]],
                               tf.float32,
                               stateful=False)

        input_audio = tf.reshape(input_audio, [1, -1, 1, 1])
        target_audio = tf.reshape(target_audio, [1, -1, 1, 1])
        gen_audio = tf.reshape(gen_audio, [1, -1, 1, 1])

        tf.summary.audio('input_audio', input_audio[:, :, 0, :], self.audio_fs)
        tf.summary.audio('target_audio', target_audio[:, :, 0, :],
                         self.audio_fs)
        tf.summary.audio('target_x_wav', x_wav[:, :, 0, :], self.audio_fs)
        tf.summary.audio('gen_audio', gen_audio[:, :, 0, :], self.audio_fs)
        tf.summary.scalar('gen_loss_total', gen_loss)
        tf.summary.scalar('gen_loss_L1', gen_loss_L1)
        tf.summary.scalar('gen_loss_GAN', gen_loss_GAN)
        tf.summary.scalar('disc_loss', discrim_loss)

        # Image summaries
        tf.summary.image('input_melspec', tf.image.rot90(x_mel_spec))
        tf.summary.image('input_magspec', tf.image.rot90(x))
        tf.summary.image('generated_magspec', tf.image.rot90(gen_mag_spec))
        tf.summary.image('target_magspec', tf.image.rot90(target))
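Example #1 drives training through model.train_loop(sess), which is not shown in these excerpts. A minimal reconstruction consistent with the ops defined in __call__, assuming one discriminator update followed by one generator update per call (the real schedule may differ):

    def train_loop(self, sess):
        # Skip the discriminator update when training with L1 loss only.
        if self.gan_weight > 0:
            sess.run(self.D_train_op)
        # The generator update advances the global step, which the caller
        # in example #1 uses as its loop counter.
        _, _step = sess.run([self.G_train_op, self.step])
        return _step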