Example #1
def load_deconv_stack(hparams, batch_size=1, mel_length=320, num_mel=80):
    wn = wavenet.Wavenet(hparams)
    mel_ph = tf.placeholder(tf.float32,
                            shape=[batch_size, mel_length, num_mel])
    ds_dict = wn.deconv_stack({'mel': mel_ph})
    ds_dict.update({'mel_in': mel_ph})
    return ds_dict
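
A minimal usage sketch for the helper above, assuming `hparams` has already been loaded from a config JSON (as in the later examples) and that the `wavenet` module is importable; it feeds a dummy mel batch through the `mel_in` placeholder.

import numpy as np
import tensorflow as tf

# Build the deconv (upsampling) stack graph with the default shapes above.
ds_dict = load_deconv_stack(hparams, batch_size=1, mel_length=320, num_mel=80)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    dummy_mel = np.zeros([1, 320, 80], dtype=np.float32)
    # Fetch everything except the placeholder itself.
    fetches = {k: v for k, v in ds_dict.items() if k != 'mel_in'}
    outputs = sess.run(fetches, feed_dict={ds_dict['mel_in']: dummy_mel})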
Example #2
def train(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    tf.logging.set_verbosity(args.log)
    clone_on_cpu = args.gpu_id == ''
    num_clones = len(args.gpu_id.split(','))

    if args.log_root:
        if args.config is None:
            raise RuntimeError('No config json specified.')
        tf.logging.info('using config from {}'.format(args.config))
        with open(args.config, 'rt') as F:
            configs = json.load(F)
        hparams = Namespace(**configs)
        logdir_name = config_str.get_config_time_str(hparams, 'wavenet', EXP_TAG)
        logdir = os.path.join(args.log_root, logdir_name)
        os.makedirs(logdir, exist_ok=True)
        shutil.copy(args.config, logdir)
    else:
        logdir = args.logdir
        config_json = glob.glob(os.path.join(logdir, '*.json'))[0]
        tf.logging.info('using config from {}'.format(config_json))
        with open(config_json, 'rt') as F:
            configs = json.load(F)
        hparams = Namespace(**configs)
    tf.logging.info('Saving to {}'.format(logdir))

    wn = wavenet.Wavenet(hparams, os.path.abspath(os.path.expanduser(args.train_path)))
    # During training, the wavenet never sees the output values of the parallel
    # wavenet when clip_quant_scale is not used in the parallel wavenet.
    # Add noise to the wavenet input to remedy this.
    add_noise = getattr(hparams, 'add_noise', False)

    def _data_dep_init():
        # slim.learning.train runs init_fn before it starts the queue runners,
        # so this function would deadlock if it used the `inputs_dict` built in
        # the training graph below as its input.
        inputs_val = reader.get_init_batch(
            wn.train_path, batch_size=args.total_batch_size, seq_len=wn.wave_length)
        wave_data = inputs_val['wav']
        mel_data = inputs_val['mel']

        _inputs_dict = {
            'wav': tf.placeholder(dtype=tf.float32, shape=wave_data.shape),
            'mel': tf.placeholder(dtype=tf.float32, shape=mel_data.shape)}

        encode_dict = wn.encode_signal(_inputs_dict)
        _inputs_dict.update(encode_dict)
        init_ff_dict = wn.feed_forward(_inputs_dict, init=True)

        def callback(session):
            tf.logging.info('Calculating initial statistics.')
            init_out = session.run(
                init_ff_dict, feed_dict={_inputs_dict['wav']: wave_data,
                                         _inputs_dict['mel']: mel_data})
            init_out_params = init_out['out_params']
            if wn.loss_type == 'mol':
                _, mean, log_scale = np.split(init_out_params, 3, axis=2)
                scale = np.exp(np.maximum(log_scale, -7.0))
                _init_logging(mean, 'mean')
                _init_logging(scale, 'scale')
            elif wn.loss_type == 'gauss':
                mean, log_std = np.split(init_out_params, 2, axis=2)
                std = np.exp(np.maximum(log_std, -7.0))
                _init_logging(mean, 'mean')
                _init_logging(std, 'std')
            tf.logging.info('Done calculating initial statistics.')
        return callback

    def _model_fn(_inputs_dict):
        encode_dict = wn.encode_signal(_inputs_dict)
        _inputs_dict.update(encode_dict)
        ff_dict = wn.feed_forward(_inputs_dict)
        ff_dict.update(encode_dict)
        loss_dict = wn.calculate_loss(ff_dict)
        loss = loss_dict['loss']
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)

    with tf.Graph().as_default():
        total_batch_size = args.total_batch_size
        assert total_batch_size % num_clones == 0
        clone_batch_size = int(total_batch_size / num_clones)

        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones, clone_on_cpu=clone_on_cpu,
            num_ps_tasks=0,
            worker_job_name='localhost', ps_job_name='localhost')

        with tf.device(deploy_config.inputs_device()):
            inputs_dict = wn.get_batch(clone_batch_size)

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, _model_fn, [inputs_dict])
        first_clone_scope = deploy_config.clone_scope(0)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

        summaries.update(tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        with tf.device(deploy_config.variables_device()):
            global_step = tf.get_variable(
                "global_step", [],
                tf.int32,
                initializer=tf.constant_initializer(0),
                trainable=False)

        with tf.device(deploy_config.optimizer_device()):
            lr = tf.constant(wn.learning_rate_schedule[0])
            for key, value in wn.learning_rate_schedule.items():
                lr = tf.cond(
                    tf.less(global_step, key), lambda: lr, lambda: tf.constant(value))
            summaries.add(tf.summary.scalar("learning_rate", lr))

            optimizer = tf.train.AdamOptimizer(lr, epsilon=1e-8)
            ema = tf.train.ExponentialMovingAverage(decay=0.9999, num_updates=global_step)

            loss, clone_grads_vars = model_deploy.optimize_clones(
                clones, optimizer, var_list=tf.trainable_variables())
            update_ops.append(
                optimizer.apply_gradients(clone_grads_vars, global_step=global_step))
            update_ops.append(ema.apply(tf.trainable_variables()))

            summaries.add(tf.summary.scalar("train_loss", loss))

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(loss, name='train_op')

        session_config = tf.ConfigProto(allow_soft_placement=True)
        session_config.gpu_options.allow_growth = True
        summary_op = tf.summary.merge(list(summaries), name='summary_op')
        data_dep_init_fn = _data_dep_init()

        slim.learning.train(
            train_tensor,
            logdir=logdir,
            number_of_steps=wn.num_iters,
            summary_op=summary_op,
            global_step=global_step,
            log_every_n_steps=100,
            save_summaries_secs=600,
            save_interval_secs=3600,
            session_config=session_config,
            init_fn=data_dep_init_fn)
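
The init callback above reports statistics through an `_init_logging` helper that is not shown in this listing. A minimal sketch of what such a helper might look like; this is a hypothetical stand-in, not the repo's actual implementation.

import numpy as np
import tensorflow as tf

def _init_logging(array, name):
    # Log simple summary statistics of an initialization output (hypothetical helper).
    array = np.asarray(array)
    tf.logging.info(
        '{}: shape={}, mean={:.5f}, std={:.5f}, min={:.5f}, max={:.5f}'.format(
            name, array.shape, array.mean(), array.std(),
            array.min(), array.max()))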
Example #3
def train(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    tf.logging.set_verbosity(args.log)
    clone_on_cpu = args.gpu_id == ''
    num_clones = len(args.gpu_id.split(','))

    ###
    # get teacher info.
    ###
    teacher_dir = utils.shell_path(args.teacher_dir)
    assert tf.gfile.IsDirectory(teacher_dir)
    json_in_dir = glob.glob(os.path.join(teacher_dir, '*.json'))
    assert len(json_in_dir) == 1
    te_json = json_in_dir[0]
    te_ckpt = tf.train.latest_checkpoint(teacher_dir)
    assert tf.train.checkpoint_exists(te_ckpt)

    with open(te_json, 'rt') as F:
        configs = json.load(F)
    te_hparams = Namespace(**configs)
    teacher = wavenet.Wavenet(te_hparams)

    ###
    # get student info.
    ###
    if args.config is None:
        raise RuntimeError('No config json specified.')
    with open(args.config, 'rt') as F:
        configs = json.load(F)
    st_hparams = Namespace(**configs)
    pwn = parallel_wavenet.ParallelWavenet(st_hparams, teacher,
                                           args.train_path)

    def _data_dep_init():
        inputs_val = reader.get_init_batch(pwn.train_path,
                                           batch_size=args.total_batch_size,
                                           seq_len=pwn.wave_length)
        mel_data = inputs_val['mel']

        _inputs_dict = {
            'mel': tf.placeholder(dtype=tf.float32, shape=mel_data.shape)
        }

        init_ff_dict = pwn.feed_forward(_inputs_dict, init=True)

        def callback(session):
            tf.logging.info('Running data dependent initialization '
                            'for weight normalization')
            init_out = session.run(init_ff_dict,
                                   feed_dict={_inputs_dict['mel']: mel_data})
            new_x = init_out['x']
            mean = init_out['mean_tot']
            scale = init_out['scale_tot']
            _init_logging(new_x, 'new_x')
            _init_logging(mean, 'mean')
            _init_logging(scale, 'scale')
            tf.logging.info('Done with data dependent initialization '
                            'for weight normalization')

        return callback

    def _model_fn(_inputs_dict):
        ff_dict = pwn.feed_forward(_inputs_dict)
        ff_dict.update(_inputs_dict)
        loss_dict = pwn.calculate_loss(ff_dict)
        loss = loss_dict['loss']
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)

        tf.summary.scalar("kl_loss", loss_dict['kl_loss'])
        tf.summary.scalar("H_Ps", loss_dict['H_Ps'])
        tf.summary.scalar("H_Ps_Pt", loss_dict['H_Ps_Pt'])
        if 'power_loss' in loss_dict:
            tf.summary.scalar('power_loss', loss_dict['power_loss'])
        if 'contrastive_loss' in loss_dict:
            tf.summary.scalar('contrastive_loss',
                              loss_dict['contrastive_loss'])

    if args.log_root:
        logdir_name = config_str.get_config_time_str(st_hparams,
                                                     'parallel_wavenet',
                                                     EXP_TAG)
        logdir = os.path.join(args.log_root, logdir_name)
    else:
        logdir = args.logdir
    tf.logging.info('Saving to {}'.format(logdir))

    os.makedirs(logdir, exist_ok=True)
    shutil.copy(args.config, logdir)

    with tf.Graph().as_default():
        total_batch_size = args.total_batch_size
        assert total_batch_size % num_clones == 0
        clone_batch_size = int(total_batch_size / num_clones)

        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            num_ps_tasks=0,
            worker_job_name='localhost',
            ps_job_name='localhost')

        with tf.device(deploy_config.inputs_device()):
            inputs_dict = pwn.get_batch(clone_batch_size)
            # Get a mel batch that does not correspond to the wave batch.
            # If the contrastive loss is not used, this input op will never be evaluated.
            inputs_dict['mel_rand'] = pwn.get_batch(clone_batch_size)['mel']

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, _model_fn,
                                            [inputs_dict])
        first_clone_scope = deploy_config.clone_scope(0)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        summaries.update(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        with tf.device(deploy_config.variables_device()):
            global_step = tf.get_variable(
                "global_step", [],
                tf.int32,
                initializer=tf.constant_initializer(0),
                trainable=False)

        ###
        # variables to train
        ###
        st_vars = [
            var for var in tf.trainable_variables() if 'iaf' in var.name
        ]

        with tf.device(deploy_config.optimizer_device()):
            lr = tf.constant(pwn.learning_rate_schedule[0])
            for key, value in pwn.learning_rate_schedule.items():
                lr = tf.cond(tf.less(global_step, key), lambda: lr,
                             lambda: tf.constant(value))
            summaries.add(tf.summary.scalar("learning_rate", lr))

            optimizer = tf.train.AdamOptimizer(lr, epsilon=1e-8)
            ema = tf.train.ExponentialMovingAverage(decay=0.9999,
                                                    num_updates=global_step)
            loss, clone_grads_vars = model_deploy.optimize_clones(
                clones, optimizer, var_list=st_vars)
            update_ops.append(
                optimizer.apply_gradients(clone_grads_vars,
                                          global_step=global_step))
            update_ops.append(ema.apply(st_vars))

            summaries.add(tf.summary.scalar("train_loss", loss))

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(loss, name='train_op')

        ###
        # restore teacher
        ###
        te_vars = [
            var for var in tf.trainable_variables() if 'iaf' not in var.name
        ]
        # the teacher is restored from its EMA shadow variables
        te_vars = {
            '{}/ExponentialMovingAverage'.format(tv.name[:-2]): tv
            for tv in te_vars
        }
        restore_init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
            te_ckpt, te_vars)
        data_dep_init_fn = _data_dep_init()

        def group_init_fn(session):
            restore_init_fn(session)
            data_dep_init_fn(session)

        session_config = tf.ConfigProto(allow_soft_placement=True)
        session_config.gpu_options.allow_growth = True
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        slim.learning.train(train_tensor,
                            logdir=logdir,
                            number_of_steps=pwn.num_iters,
                            summary_op=summary_op,
                            global_step=global_step,
                            log_every_n_steps=100,
                            save_summaries_secs=600,
                            save_interval_secs=3600,
                            session_config=session_config,
                            init_fn=group_init_fn)
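
Restoring the teacher relies on the naming convention TensorFlow's ExponentialMovingAverage uses for its shadow variables: the checkpoint stores `<variable_name>/ExponentialMovingAverage`, and `tv.name[:-2]` strips the `:0` suffix from the in-graph name. A small sketch of that mapping with a hypothetical variable and checkpoint path:

import tensorflow as tf

# Hypothetical teacher variable; tv.name is 'wavenet/conv/W:0'.
tv = tf.get_variable('wavenet/conv/W', shape=[3, 3], dtype=tf.float32)
shadow_name = '{}/ExponentialMovingAverage'.format(tv.name[:-2])
# shadow_name == 'wavenet/conv/W/ExponentialMovingAverage'

# Map the checkpoint's shadow names to the in-graph variables, as in the example:
te_vars = {shadow_name: tv}
# restore_init_fn = tf.contrib.framework.assign_from_checkpoint_fn(te_ckpt, te_vars)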
Example #4
def train(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    tf.logging.set_verbosity(args.log)

    ###
    # get teacher info.
    ###
    teacher_dir = utils.shell_path(args.teacher_dir)
    assert tf.gfile.IsDirectory(teacher_dir)
    json_in_dir = glob.glob(os.path.join(teacher_dir, '*.json'))
    assert len(json_in_dir) == 1
    te_json = json_in_dir[0]
    te_ckpt = tf.train.latest_checkpoint(teacher_dir)
    assert tf.train.checkpoint_exists(te_ckpt)

    with open(te_json, 'rt') as F:
        configs = json.load(F)
    te_hparams = Namespace(**configs)
    teacher = wavenet.Wavenet(te_hparams)

    ###
    # get student info.
    ###
    if args.config is None:
        raise RuntimeError('No config json specified.')
    with open(args.config, 'rt') as F:
        configs = json.load(F)
    st_hparams = Namespace(**configs)
    pwn = parallel_wavenet.ParallelWavenet(st_hparams, teacher,
                                           args.train_path)

    logdir = args.logdir
    tf.logging.info('Saving to {}'.format(logdir))

    os.makedirs(logdir, exist_ok=True)
    shutil.copy(args.config, logdir)

    with tf.Graph().as_default():
        total_batch_size = args.total_batch_size
        worker_replicas = args.worker_replicas
        assert total_batch_size % worker_replicas == 0
        worker_batch_size = int(total_batch_size / worker_replicas)

        # Run the Reader on the CPU
        cpu_device = "/job:localhost/replica:0/task:0/cpu:0"
        if args.ps_tasks:
            cpu_device = "/job:worker/cpu:0"

        with tf.device(cpu_device):
            inputs_dict = pwn.get_batch(worker_batch_size)

        with tf.device(
                tf.train.replica_device_setter(ps_tasks=args.ps_tasks,
                                               merge_devices=True)):
            global_step = tf.get_variable(
                "global_step", [],
                tf.int32,
                initializer=tf.constant_initializer(0),
                trainable=False)

            lr = tf.constant(pwn.learning_rate_schedule[0])
            for key, value in pwn.learning_rate_schedule.items():
                lr = tf.cond(tf.less(global_step, key), lambda: lr,
                             lambda: tf.constant(value))

            tf.summary.scalar("learning_rate", lr)
            # build the model graph
            ff_dict = pwn.feed_forward(inputs_dict, use_log_scale=False)
            ff_dict.update(inputs_dict)
            loss_dict = pwn.calculate_loss(ff_dict)
            loss = loss_dict['loss']
            tf.summary.scalar("train_loss", loss)
            tf.summary.scalar("kl_loss", loss_dict['kl_loss'])
            tf.summary.scalar("H_Ps", loss_dict['H_Ps'])
            tf.summary.scalar("H_Ps_Pt", loss_dict['H_Ps_Pt'])
            if 'power_loss' in loss_dict:
                tf.summary.scalar('power_loss', loss_dict['power_loss'])

            ###
            # restore teacher
            ###
            te_vars = slim.get_variables_to_restore(
                exclude=['global_step', 'iaf'])
            # the teacher is restored from its EMA shadow variables
            te_vars = {
                '{}/ExponentialMovingAverage'.format(tv.name[:-2]): tv
                for tv in te_vars
            }
            restore_init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
                te_ckpt, te_vars)

            ###
            # variables to train
            ###
            st_vars = [
                var for var in tf.trainable_variables() if 'iaf' in var.name
            ]

            ema = tf.train.ExponentialMovingAverage(decay=0.9999,
                                                    num_updates=global_step)
            opt = tf.train.SyncReplicasOptimizer(
                tf.train.AdamOptimizer(lr, epsilon=1e-8),
                worker_replicas,
                total_num_replicas=worker_replicas,
                variable_averages=ema,
                variables_to_average=st_vars)

            train_op = slim.learning.create_train_op(
                total_loss=loss,
                variables_to_train=st_vars,
                optimizer=opt,
                global_step=global_step,
                colocate_gradients_with_ops=True)

            session_config = tf.ConfigProto(allow_soft_placement=True)

            is_chief = (args.task == 0)
            local_init_op = opt.chief_init_op if is_chief else opt.local_step_init_op

            slim.learning.train(train_op=train_op,
                                logdir=logdir,
                                is_chief=is_chief,
                                master=args.master,
                                number_of_steps=pwn.num_iters,
                                global_step=global_step,
                                log_every_n_steps=250,
                                local_init_op=local_init_op,
                                save_interval_secs=3600,
                                sync_optimizer=opt,
                                session_config=session_config,
                                init_fn=restore_init_fn)
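
The distributed variant above reads several extra attributes from `args` (`worker_replicas`, `ps_tasks`, `task`, `master`) on top of the flags used by the single-machine trainers. A hedged sketch of an argparse definition that would provide the attributes these functions access; the flag names and defaults are assumptions, not the repo's actual CLI.

import argparse

parser = argparse.ArgumentParser()
# Flags read by the single-machine trainers.
parser.add_argument('--gpu_id', default='0')
parser.add_argument('--log', default='INFO')
parser.add_argument('--config', default=None)
parser.add_argument('--train_path', default=None)
parser.add_argument('--logdir', default='./logdir')
parser.add_argument('--log_root', default='')
parser.add_argument('--teacher_dir', default='')
parser.add_argument('--total_batch_size', type=int, default=8)
# Extra flags read by the distributed (SyncReplicasOptimizer) variant.
parser.add_argument('--worker_replicas', type=int, default=1)
parser.add_argument('--ps_tasks', type=int, default=0)
parser.add_argument('--task', type=int, default=0)
parser.add_argument('--master', default='')

args = parser.parse_args()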
Example #5
def train(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    tf.logging.set_verbosity(args.log)

    if args.config is None:
        raise RuntimeError('No config json specified.')
    with open(args.config, 'rt') as F:
        configs = json.load(F)
    hparams = Namespace(**configs)

    wn = wavenet.Wavenet(hparams, args.train_path)

    logdir = args.logdir
    tf.logging.info('Saving to {}'.format(logdir))

    os.makedirs(logdir, exist_ok=True)
    shutil.copy(args.config, logdir)

    with tf.Graph().as_default():
        total_batch_size = args.total_batch_size
        worker_replicas = args.worker_replicas
        assert total_batch_size % worker_replicas == 0
        worker_batch_size = int(total_batch_size / worker_replicas)

        # Run the Reader on the CPU
        cpu_device = "/job:localhost/replica:0/task:0/cpu:0"
        if args.ps_tasks:
            cpu_device = "/job:worker/cpu:0"

        with tf.device(cpu_device):
            inputs_dict = wn.get_batch(worker_batch_size)

        with tf.device(
                tf.train.replica_device_setter(ps_tasks=args.ps_tasks,
                                               merge_devices=True)):
            global_step = tf.get_variable(
                "global_step", [],
                tf.int32,
                initializer=tf.constant_initializer(0),
                trainable=False)

            lr = tf.constant(wn.learning_rate_schedule[0])
            for key, value in wn.learning_rate_schedule.items():
                lr = tf.cond(tf.less(global_step, key), lambda: lr,
                             lambda: tf.constant(value))
            tf.summary.scalar("learning_rate", lr)

            # build the model graph
            ff_dict = wn.feed_forward(inputs_dict)
            loss_dict = wn.calculate_loss(ff_dict)
            loss = loss_dict['loss']
            tf.summary.scalar("train_loss", loss)

            ema = tf.train.ExponentialMovingAverage(decay=0.9999,
                                                    num_updates=global_step)
            opt = tf.train.SyncReplicasOptimizer(
                tf.train.AdamOptimizer(lr, epsilon=1e-8),
                worker_replicas,
                total_num_replicas=worker_replicas,
                variable_averages=ema,
                variables_to_average=tf.trainable_variables())

            train_op = slim.learning.create_train_op(
                total_loss=loss,
                optimizer=opt,
                global_step=global_step,
                colocate_gradients_with_ops=True)

            session_config = tf.ConfigProto(allow_soft_placement=True)

            is_chief = (args.task == 0)
            local_init_op = opt.chief_init_op if is_chief else opt.local_step_init_op

            slim.learning.train(
                train_op=train_op,
                logdir=logdir,
                is_chief=is_chief,
                master=args.master,
                number_of_steps=wn.num_iters,
                global_step=global_step,
                log_every_n_steps=250,
                local_init_op=local_init_op,
                save_interval_secs=3600,
                sync_optimizer=opt,
                session_config=session_config,
            )
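
All of these trainers build `hparams` by unpacking a JSON config into an `argparse.Namespace`, so config entries become plain attributes. A minimal sketch with a hypothetical config file name and contents:

import json
from argparse import Namespace

# Hypothetical config; the real JSON carries the full set of WaveNet hparams.
with open('wavenet_mol.json', 'rt') as F:
    configs = json.load(F)   # e.g. {"loss_type": "mol", "num_iters": 200000, ...}
hparams = Namespace(**configs)

print(hparams.loss_type)     # config entries are accessed as attributes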
Example #6
def train(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    tf.logging.set_verbosity(args.log)
    clone_on_cpu = args.gpu_id == ''
    num_clones = len(args.gpu_id.split(','))

    ###
    # get teacher info.
    ###
    teacher_dir = utils.shell_path(args.teacher_dir)
    assert tf.gfile.IsDirectory(teacher_dir)
    json_in_dir = glob.glob(os.path.join(teacher_dir, '*.json'))
    assert len(json_in_dir) == 1
    te_json = json_in_dir[0]
    te_ckpt = tf.train.latest_checkpoint(teacher_dir)
    assert tf.train.checkpoint_exists(te_ckpt)

    with open(te_json, 'rt') as F:
        configs = json.load(F)
    te_hparams = Namespace(**configs)
    setattr(te_hparams, 'use_as_teacher', True)
    teacher = wavenet.Wavenet(te_hparams)

    ###
    # get student info.
    ###
    if args.log_root:
        if args.config is None:
            raise RuntimeError('No config json specified.')
        config_json = args.config
        with open(config_json, 'rt') as F:
            configs = json.load(F)
        st_hparams = Namespace(**configs)
        logdir_name = config_str.get_config_time_str(st_hparams,
                                                     'parallel_wavenet',
                                                     EXP_TAG)
        logdir = os.path.join(args.log_root, logdir_name)
        os.makedirs(logdir, exist_ok=True)
        shutil.copy(config_json, logdir)
    else:
        logdir = args.logdir
        config_json = glob.glob(os.path.join(logdir, '*.json'))[0]
        with open(config_json, 'rt') as F:
            configs = json.load(F)
        st_hparams = Namespace(**configs)

    enhance_log.add_log_file(logdir)
    if not args.log_root:
        tf.logging.info('Continuing a previous run\n\n')
    tf.logging.info('using config from {}'.format(config_json))
    tf.logging.info('Saving to {}'.format(logdir))

    pwn = parallel_wavenet.ParallelWavenet(st_hparams, teacher,
                                           args.train_path)
    pwn_config_str = enhance_log.instance_attr_to_str(pwn)
    teacher_config_str = enhance_log.instance_attr_to_str(teacher)
    tf.logging.info('\n' + pwn_config_str)
    tf.logging.info('\nteacher from {}\n'.format(teacher_dir) +
                    teacher_config_str)

    def _data_dep_init():
        inputs_val = reader.get_init_batch(pwn.train_path,
                                           batch_size=args.total_batch_size,
                                           seq_len=pwn.wave_length)
        mel_data = inputs_val['mel']

        _inputs_dict = {
            'mel': tf.placeholder(dtype=tf.float32, shape=mel_data.shape)
        }

        init_ff_dict = pwn.feed_forward(_inputs_dict, init=True)

        def callback(session):
            tf.logging.info('Calculating initial statistics.')
            init_out = session.run(init_ff_dict,
                                   feed_dict={_inputs_dict['mel']: mel_data})
            new_x = init_out['x']
            mean = init_out['mean_tot']
            scale = init_out['scale_tot']
            _init_logging(new_x, 'new_x')
            _init_logging(mean, 'mean')
            _init_logging(scale, 'scale')
            tf.logging.info('Done calculating initial statistics.')

        return callback

    def _trans_conv_init_from_teacher(te_vars, st_vars):
        """
        Initialize the separate iaf transposed convolution stacks or shared transposed
        convolution stack with the teacher's transposed convolution stack.
        """
        te_trans_conv_var_names = [
            var.name for var in te_vars if pwn.upsample_conv_name in var.name
        ]
        te_trans_conv_vars = [
            var for var in te_vars if var.name in te_trans_conv_var_names
        ]
        st_trans_conv_vars_flow_nested = []
        for te_tcvn in te_trans_conv_var_names:
            st_tcv_for_flows = []
            for var in st_vars:
                if var.name.endswith(te_tcvn):
                    st_tcv_for_flows.append(var)
            st_trans_conv_vars_flow_nested.append(st_tcv_for_flows)

        assert len(te_trans_conv_vars) == len(st_trans_conv_vars_flow_nested)

        assign_ops = []
        for te_tcv, st_tcv_for_flows in zip(te_trans_conv_vars,
                                            st_trans_conv_vars_flow_nested):
            for st_tcv in st_tcv_for_flows:
                assign_ops.append(tf.assign(st_tcv, te_tcv))

        def assign_fn(session):
            tf.logging.info('Loading transposed convolution weights from teacher')
            session.run(assign_ops)
            tf.logging.info(
                'Done loading transposed convolution weights from teacher')

        return assign_fn

    def _model_fn(_inputs_dict):
        ff_dict = pwn.feed_forward(_inputs_dict)
        ff_dict.update(_inputs_dict)
        loss_dict = pwn.calculate_loss(ff_dict)
        loss = loss_dict['loss']
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)

        for loss_key, loss_val in loss_dict.items():
            tf.summary.scalar(loss_key, loss_val)

    with tf.Graph().as_default():
        total_batch_size = args.total_batch_size
        assert total_batch_size % num_clones == 0
        clone_batch_size = int(total_batch_size / num_clones)

        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            num_ps_tasks=0,
            worker_job_name='localhost',
            ps_job_name='localhost')

        with tf.device(deploy_config.inputs_device()):
            inputs_dict = pwn.get_batch(clone_batch_size)
            # Get a mel batch that does not correspond to the wave batch.
            # If the contrastive loss is not used, this input op will never be evaluated.
            inputs_dict['mel_rand'] = pwn.get_batch(clone_batch_size)['mel']

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, _model_fn,
                                            [inputs_dict])
        first_clone_scope = deploy_config.clone_scope(0)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        summaries.update(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        with tf.device(deploy_config.variables_device()):
            global_step = tf.get_variable(
                "global_step", [],
                tf.int32,
                initializer=tf.constant_initializer(0),
                trainable=False)

        ###
        # variables to train
        ###
        st_var_list = [
            var for var in tf.trainable_variables() if 'iaf' in var.name
        ]
        filtered_st_var_list = pwn.filter_update_variables(st_var_list)

        with tf.device(deploy_config.optimizer_device()):
            lr = tf.constant(pwn.learning_rate_schedule[0])
            for key, value in pwn.learning_rate_schedule.items():
                lr = tf.cond(tf.less(global_step, key), lambda: lr,
                             lambda: tf.constant(value))
            summaries.add(tf.summary.scalar("learning_rate", lr))

            optimizer = tf.train.AdamOptimizer(lr, epsilon=1e-8)
            ema = tf.train.ExponentialMovingAverage(decay=0.9999,
                                                    num_updates=global_step)
            loss, clone_grads_vars = model_deploy.optimize_clones(
                clones, optimizer, var_list=filtered_st_var_list)
            update_ops.append(
                optimizer.apply_gradients(clone_grads_vars,
                                          global_step=global_step))
            update_ops.append(ema.apply(filtered_st_var_list))

            summaries.add(tf.summary.scalar("train_loss", loss))

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(loss, name='train_op')

        ###
        # restore teacher and other init ops
        ###
        te_var_list = [
            var for var in tf.trainable_variables() if 'iaf' not in var.name
        ]
        # the teacher is restored from its EMA shadow variables
        te_var_shadow_dict = {
            '{}/ExponentialMovingAverage'.format(tv.name[:-2]): tv
            for tv in te_var_list
        }
        restore_init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
            te_ckpt, te_var_shadow_dict)
        data_dep_init_fn = _data_dep_init()
        share_trans_conv_init_fn = _trans_conv_init_from_teacher(
            te_var_list, st_var_list)

        def group_init_fn(session):
            # the order of the init functions is important, don't change it.
            restore_init_fn(session)
            data_dep_init_fn(session)
            share_trans_conv_init_fn(session)

        session_config = tf.ConfigProto(allow_soft_placement=True)
        session_config.gpu_options.allow_growth = True
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        slim.learning.train(train_tensor,
                            logdir=logdir,
                            number_of_steps=pwn.num_iters,
                            summary_op=summary_op,
                            global_step=global_step,
                            log_every_n_steps=100,
                            save_summaries_secs=600,
                            save_interval_secs=3600,
                            session_config=session_config,
                            init_fn=group_init_fn)
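
`_trans_conv_init_from_teacher` above seeds the student's transposed-convolution weights from the teacher with explicit `tf.assign` ops that run only after the teacher checkpoint has been restored. A stripped-down sketch of the same copy pattern with hypothetical variable names:

import tensorflow as tf

# Hypothetical teacher upsampling weight and one copy per student IAF flow.
te_w = tf.get_variable('upsample_conv/kernel', shape=[3, 3, 1, 1])
st_w_flows = [
    tf.get_variable('iaf_0/upsample_conv/kernel', shape=[3, 3, 1, 1]),
    tf.get_variable('iaf_1/upsample_conv/kernel', shape=[3, 3, 1, 1]),
]

assign_ops = [tf.assign(st_w, te_w) for st_w in st_w_flows]

def assign_fn(session):
    # Run only after the teacher variables have been restored from checkpoint,
    # so te_w already holds the checkpoint values.
    session.run(assign_ops)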
Example #7
def train(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    tf.logging.set_verbosity(args.log)
    clone_on_cpu = args.gpu_id == ''
    num_clones = len(args.gpu_id.split(','))

    if args.config is None:
        raise RuntimeError('No config json specified.')
    with open(args.config, 'rt') as F:
        configs = json.load(F)
    hparams = Namespace(**configs)

    wn = wavenet.Wavenet(hparams, args.train_path)

    def _model_fn(_inputs_dict):
        encode_dict = wn.encode_signal(_inputs_dict)
        _inputs_dict.update(encode_dict)
        ff_dict = wn.feed_forward(_inputs_dict)
        ff_dict.update(encode_dict)
        loss_dict = wn.calculate_loss(ff_dict)
        loss = loss_dict['loss']
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)

    logdir = args.logdir
    tf.logging.info('Saving to {}'.format(logdir))

    os.makedirs(logdir, exist_ok=True)
    shutil.copy(args.config, logdir)

    with tf.Graph().as_default():
        total_batch_size = args.total_batch_size
        assert total_batch_size % num_clones == 0
        clone_batch_size = int(total_batch_size / num_clones)

        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            num_ps_tasks=0,
            worker_job_name='localhost',
            ps_job_name='localhost')

        with tf.device(deploy_config.inputs_device()):
            inputs_dict = wn.get_batch(clone_batch_size)

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, _model_fn,
                                            [inputs_dict])
        first_clone_scope = deploy_config.clone_scope(0)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        summaries.update(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        with tf.device(deploy_config.variables_device()):
            global_step = tf.get_variable(
                "global_step", [],
                tf.int32,
                initializer=tf.constant_initializer(0),
                trainable=False)

        with tf.device(deploy_config.optimizer_device()):
            lr = tf.constant(wn.learning_rate_schedule[0])
            for key, value in wn.learning_rate_schedule.items():
                lr = tf.cond(tf.less(global_step, key), lambda: lr,
                             lambda: tf.constant(value))
            summaries.add(tf.summary.scalar("learning_rate", lr))

            optimizer = tf.train.AdamOptimizer(lr, epsilon=1e-8)
            ema = tf.train.ExponentialMovingAverage(decay=0.9999,
                                                    num_updates=global_step)

            loss, clone_grads_vars = model_deploy.optimize_clones(
                clones, optimizer, var_list=tf.trainable_variables())
            update_ops.append(
                optimizer.apply_gradients(clone_grads_vars,
                                          global_step=global_step))
            update_ops.append(ema.apply(tf.trainable_variables()))

            summaries.add(tf.summary.scalar("train_loss", loss))

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(loss, name='train_op')

        session_config = tf.ConfigProto(allow_soft_placement=True)
        session_config.gpu_options.allow_growth = True
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        slim.learning.train(
            train_tensor,
            logdir=logdir,
            number_of_steps=wn.num_iters,
            summary_op=summary_op,
            global_step=global_step,
            log_every_n_steps=100,
            save_summaries_secs=600,
            save_interval_secs=3600,
            session_config=session_config,
        )
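
Every example builds its learning rate as a chain of `tf.cond` ops over a step-to-rate dict (`learning_rate_schedule`), which yields a piecewise-constant schedule. A small sketch of the same pattern with a hypothetical schedule; the keys are assumed to be in ascending order so the chain behaves as intended.

import tensorflow as tf

# Hypothetical schedule: once global_step passes a key, switch to that rate.
learning_rate_schedule = {0: 1e-3, 100000: 5e-4, 200000: 2.5e-4}

global_step = tf.train.get_or_create_global_step()
lr = tf.constant(learning_rate_schedule[0])
for key, value in learning_rate_schedule.items():
    # Keep the current rate before `key` steps, switch to `value` afterwards.
    lr = tf.cond(tf.less(global_step, key), lambda: lr,
                 lambda: tf.constant(value))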