def load_deconv_stack(hparams, batch_size=1, mel_length=320, num_mel=80):
    """Build a standalone WaveNet transposed-convolution stack on a mel input.

    A new ``wavenet.Wavenet`` is constructed from ``hparams`` and its
    ``deconv_stack`` is applied to a fresh float32 placeholder of shape
    ``[batch_size, mel_length, num_mel]``.

    Returns:
        The mapping produced by ``deconv_stack``, extended with the key
        ``'mel_in'`` bound to the input placeholder so callers can feed it.
    """
    model = wavenet.Wavenet(hparams)
    mel_shape = [batch_size, mel_length, num_mel]
    mel_input = tf.placeholder(tf.float32, shape=mel_shape)
    outputs = model.deconv_stack({'mel': mel_input})
    outputs['mel_in'] = mel_input
    return outputs
def train(args):
    """Train a WaveNet with TF-Slim, one model clone per listed GPU.

    When ``args.log_root`` is set, a fresh run directory is created under it
    (named by ``config_str.get_config_time_str``) and ``args.config`` is copied
    into it. Otherwise the run continues in ``args.logdir`` and the config is
    read from a ``*.json`` file already stored there.

    Args:
        args: parsed command-line namespace. Fields read here: ``gpu_id``,
            ``log``, ``log_root``, ``config``, ``logdir``, ``train_path`` and
            ``total_batch_size``.

    Raises:
        RuntimeError: if ``log_root`` is set but ``config`` is None, or if
            ``logdir`` holds no ``*.json`` config when continuing a run.
        ValueError: if ``total_batch_size`` is not divisible by the number of
            model clones.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    tf.logging.set_verbosity(args.log)
    clone_on_cpu = args.gpu_id == ''
    # One clone per comma-separated GPU id; an empty id gives a single clone.
    num_clones = len(args.gpu_id.split(','))

    if args.log_root:
        if args.config is None:
            raise RuntimeError('No config json specified.')
        tf.logging.info('using config form {}'.format(args.config))
        with open(args.config, 'rt') as F:
            configs = json.load(F)
        hparams = Namespace(**configs)
        logdir_name = config_str.get_config_time_str(hparams, 'wavenet', EXP_TAG)
        logdir = os.path.join(args.log_root, logdir_name)
        os.makedirs(logdir, exist_ok=True)
        shutil.copy(args.config, logdir)
    else:
        logdir = args.logdir
        # Sorted so the choice is deterministic if several json files exist.
        config_jsons = sorted(glob.glob(os.path.join(logdir, '*.json')))
        if not config_jsons:
            raise RuntimeError(
                'No config json found in logdir {}.'.format(logdir))
        config_json = config_jsons[0]
        tf.logging.info('using config form {}'.format(config_json))
        with open(config_json, 'rt') as F:
            configs = json.load(F)
        hparams = Namespace(**configs)
    tf.logging.info('Saving to {}'.format(logdir))

    wn = wavenet.Wavenet(
        hparams, os.path.abspath(os.path.expanduser(args.train_path)))
    # NOTE(review): this function used to read
    # `add_noise = getattr(hparams, 'add_noise', False)` and never use it.
    # The intent was to add input noise because WaveNet never sees
    # parallel-WaveNet outputs unless clip_quant_scale is used. Presumably
    # `wavenet.Wavenet` reads `add_noise` from `hparams` itself -- confirm.

    def _data_dep_init():
        """Build the data-dependent initialization and return its callback."""
        # slim.learning.train runs init_fn before the queue runners start, so
        # using the queue-fed `inputs_dict` here would deadlock. Instead, load
        # one batch eagerly and feed it through placeholders.
        inputs_val = reader.get_init_batch(
            wn.train_path, batch_size=args.total_batch_size,
            seq_len=wn.wave_length)
        wave_data = inputs_val['wav']
        mel_data = inputs_val['mel']

        _inputs_dict = {
            'wav': tf.placeholder(dtype=tf.float32, shape=wave_data.shape),
            'mel': tf.placeholder(dtype=tf.float32, shape=mel_data.shape)}

        encode_dict = wn.encode_signal(_inputs_dict)
        _inputs_dict.update(encode_dict)

        init_ff_dict = wn.feed_forward(_inputs_dict, init=True)

        def callback(session):
            tf.logging.info('Calculate initial statistics.')
            init_out = session.run(
                init_ff_dict,
                feed_dict={_inputs_dict['wav']: wave_data,
                           _inputs_dict['mel']: mel_data})
            init_out_params = init_out['out_params']
            # Log-scales / log-stds are floored at -7 before exponentiation.
            if wn.loss_type == 'mol':
                _, mean, log_scale = np.split(init_out_params, 3, axis=2)
                scale = np.exp(np.maximum(log_scale, -7.0))
                _init_logging(mean, 'mean')
                _init_logging(scale, 'scale')
            elif wn.loss_type == 'gauss':
                mean, log_std = np.split(init_out_params, 2, axis=2)
                std = np.exp(np.maximum(log_std, -7.0))
                _init_logging(mean, 'mean')
                _init_logging(std, 'std')
            tf.logging.info('Done Calculate initial statistics.')

        return callback

    def _model_fn(_inputs_dict):
        """Per-clone model: encode, run the net, register the loss."""
        encode_dict = wn.encode_signal(_inputs_dict)
        _inputs_dict.update(encode_dict)
        ff_dict = wn.feed_forward(_inputs_dict)
        ff_dict.update(encode_dict)
        loss_dict = wn.calculate_loss(ff_dict)
        loss = loss_dict['loss']
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)

    with tf.Graph().as_default():
        total_batch_size = args.total_batch_size
        if total_batch_size % num_clones != 0:
            raise ValueError(
                'total_batch_size ({}) must be divisible by the number of '
                'clones ({}).'.format(total_batch_size, num_clones))
        clone_batch_size = total_batch_size // num_clones

        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones, clone_on_cpu=clone_on_cpu, num_ps_tasks=0,
            worker_job_name='localhost', ps_job_name='localhost')

        with tf.device(deploy_config.inputs_device()):
            inputs_dict = wn.get_batch(clone_batch_size)

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        clones = model_deploy.create_clones(
            deploy_config, _model_fn, [inputs_dict])
        first_clone_scope = deploy_config.clone_scope(0)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
        summaries.update(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        with tf.device(deploy_config.variables_device()):
            global_step = tf.get_variable(
                "global_step", [], tf.int32,
                initializer=tf.constant_initializer(0), trainable=False)

        with tf.device(deploy_config.optimizer_device()):
            schedule = wn.learning_rate_schedule
            lr = tf.constant(schedule[0])
            # Chain the conds in ascending boundary order so each stage keeps
            # the previous rate until `global_step` reaches its boundary;
            # insertion order would give a wrong schedule for an unsorted
            # config. Default arguments pin this iteration's values.
            for boundary in sorted(schedule):
                lr = tf.cond(tf.less(global_step, boundary),
                             lambda prev=lr: prev,
                             lambda rate=schedule[boundary]: tf.constant(rate))
            summaries.add(tf.summary.scalar("learning_rate", lr))

            optimizer = tf.train.AdamOptimizer(lr, epsilon=1e-8)
            ema = tf.train.ExponentialMovingAverage(
                decay=0.9999, num_updates=global_step)

            loss, clone_grads_vars = model_deploy.optimize_clones(
                clones, optimizer, var_list=tf.trainable_variables())
            update_ops.append(
                optimizer.apply_gradients(clone_grads_vars,
                                          global_step=global_step))
            update_ops.append(ema.apply(tf.trainable_variables()))

            summaries.add(tf.summary.scalar("train_loss", loss))

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(loss, name='train_op')

        session_config = tf.ConfigProto(allow_soft_placement=True)
        session_config.gpu_options.allow_growth = True

        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        data_dep_init_fn = _data_dep_init()

        slim.learning.train(
            train_tensor,
            logdir=logdir,
            number_of_steps=wn.num_iters,
            summary_op=summary_op,
            global_step=global_step,
            log_every_n_steps=100,
            save_summaries_secs=600,
            save_interval_secs=3600,
            session_config=session_config,
            init_fn=data_dep_init_fn)
def train(args):
    """Distil a parallel-WaveNet student from a trained WaveNet teacher.

    The teacher config (exactly one ``*.json``) and its latest checkpoint are
    read from ``args.teacher_dir``. The student config comes from
    ``args.config``. Only variables whose name contains ``'iaf'`` are
    trained. All other trainable variables are restored from the teacher's
    ExponentialMovingAverage shadows.

    Args:
        args: parsed command-line namespace. Fields read here: ``gpu_id``,
            ``log``, ``teacher_dir``, ``config``, ``train_path``,
            ``total_batch_size``, ``log_root`` and ``logdir``.

    Raises:
        RuntimeError: on a missing/invalid teacher directory, teacher config,
            teacher checkpoint, or student config.
        ValueError: if ``total_batch_size`` is not divisible by the number of
            model clones.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    tf.logging.set_verbosity(args.log)
    clone_on_cpu = args.gpu_id == ''
    num_clones = len(args.gpu_id.split(','))

    ###
    # get teacher info.
    ###
    teacher_dir = utils.shell_path(args.teacher_dir)
    if not tf.gfile.IsDirectory(teacher_dir):
        raise RuntimeError(
            'Teacher dir {} is not a directory.'.format(teacher_dir))
    json_in_dir = glob.glob(os.path.join(teacher_dir, '*.json'))
    if len(json_in_dir) != 1:
        raise RuntimeError(
            'Expected exactly one teacher config json in {}, found {}.'.format(
                teacher_dir, len(json_in_dir)))
    te_json = json_in_dir[0]

    te_ckpt = tf.train.latest_checkpoint(teacher_dir)
    # latest_checkpoint returns None when no checkpoint is found.
    if te_ckpt is None or not tf.train.checkpoint_exists(te_ckpt):
        raise RuntimeError(
            'No teacher checkpoint found in {}.'.format(teacher_dir))

    with open(te_json, 'rt') as F:
        configs = json.load(F)
    te_hparams = Namespace(**configs)
    teacher = wavenet.Wavenet(te_hparams)

    ###
    # get student info.
    ###
    if args.config is None:
        raise RuntimeError('No config json specified.')
    with open(args.config, 'rt') as F:
        configs = json.load(F)
    st_hparams = Namespace(**configs)
    pwn = parallel_wavenet.ParallelWavenet(st_hparams, teacher, args.train_path)

    def _data_dep_init():
        """Build the data-dependent weight-norm init and return its callback."""
        # Load one batch eagerly and feed it through a placeholder; init_fn
        # runs before the queue runners are started.
        inputs_val = reader.get_init_batch(
            pwn.train_path, batch_size=args.total_batch_size,
            seq_len=pwn.wave_length)
        mel_data = inputs_val['mel']

        _inputs_dict = {
            'mel': tf.placeholder(dtype=tf.float32, shape=mel_data.shape)
        }

        init_ff_dict = pwn.feed_forward(_inputs_dict, init=True)

        def callback(session):
            tf.logging.info('Running data dependent initialization '
                            'for weight normalization')
            init_out = session.run(
                init_ff_dict, feed_dict={_inputs_dict['mel']: mel_data})
            new_x = init_out['x']
            mean = init_out['mean_tot']
            scale = init_out['scale_tot']
            _init_logging(new_x, 'new_x')
            _init_logging(mean, 'mean')
            _init_logging(scale, 'scale')
            tf.logging.info('Done data dependent initialization '
                            'for weight normalization')

        return callback

    def _model_fn(_inputs_dict):
        """Per-clone model: student forward pass, loss and loss summaries."""
        ff_dict = pwn.feed_forward(_inputs_dict)
        ff_dict.update(_inputs_dict)
        loss_dict = pwn.calculate_loss(ff_dict)
        loss = loss_dict['loss']
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)

        tf.summary.scalar("kl_loss", loss_dict['kl_loss'])
        tf.summary.scalar("H_Ps", loss_dict['H_Ps'])
        tf.summary.scalar("H_Ps_Pt", loss_dict['H_Ps_Pt'])
        if 'power_loss' in loss_dict:
            tf.summary.scalar('power_loss', loss_dict['power_loss'])
        if 'contrastive_loss' in loss_dict:
            tf.summary.scalar('contrastive_loss', loss_dict['contrastive_loss'])

    if args.log_root:
        logdir_name = config_str.get_config_time_str(
            st_hparams, 'parallel_wavenet', EXP_TAG)
        logdir = os.path.join(args.log_root, logdir_name)
    else:
        logdir = args.logdir
    tf.logging.info('Saving to {}'.format(logdir))

    os.makedirs(logdir, exist_ok=True)
    shutil.copy(args.config, logdir)

    with tf.Graph().as_default():
        total_batch_size = args.total_batch_size
        if total_batch_size % num_clones != 0:
            raise ValueError(
                'total_batch_size ({}) must be divisible by the number of '
                'clones ({}).'.format(total_batch_size, num_clones))
        clone_batch_size = total_batch_size // num_clones

        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones, clone_on_cpu=clone_on_cpu, num_ps_tasks=0,
            worker_job_name='localhost', ps_job_name='localhost')

        with tf.device(deploy_config.inputs_device()):
            inputs_dict = pwn.get_batch(clone_batch_size)
            # A mel batch that does not correspond to the wave batch. If the
            # contrastive loss is not used, this input op is never evaluated.
            inputs_dict['mel_rand'] = pwn.get_batch(clone_batch_size)['mel']

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        clones = model_deploy.create_clones(
            deploy_config, _model_fn, [inputs_dict])
        first_clone_scope = deploy_config.clone_scope(0)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
        summaries.update(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        with tf.device(deploy_config.variables_device()):
            global_step = tf.get_variable(
                "global_step", [], tf.int32,
                initializer=tf.constant_initializer(0), trainable=False)

        ###
        # variables to train
        ###
        st_vars = [var for var in tf.trainable_variables() if 'iaf' in var.name]

        with tf.device(deploy_config.optimizer_device()):
            schedule = pwn.learning_rate_schedule
            lr = tf.constant(schedule[0])
            # Chain the conds in ascending boundary order so each stage keeps
            # the previous rate until `global_step` reaches its boundary;
            # insertion order would give a wrong schedule for an unsorted
            # config. Default arguments pin this iteration's values.
            for boundary in sorted(schedule):
                lr = tf.cond(tf.less(global_step, boundary),
                             lambda prev=lr: prev,
                             lambda rate=schedule[boundary]: tf.constant(rate))
            summaries.add(tf.summary.scalar("learning_rate", lr))

            optimizer = tf.train.AdamOptimizer(lr, epsilon=1e-8)
            ema = tf.train.ExponentialMovingAverage(
                decay=0.9999, num_updates=global_step)

            loss, clone_grads_vars = model_deploy.optimize_clones(
                clones, optimizer, var_list=st_vars)
            update_ops.append(
                optimizer.apply_gradients(clone_grads_vars,
                                          global_step=global_step))
            update_ops.append(ema.apply(st_vars))

            summaries.add(tf.summary.scalar("train_loss", loss))

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(loss, name='train_op')

        ###
        # restore teacher
        ###
        te_vars = [
            var for var in tf.trainable_variables() if 'iaf' not in var.name]
        # The teacher is restored from its EMA shadow variables; `[:-2]` strips
        # the ':0' output suffix from the tensor name.
        te_vars = {
            '{}/ExponentialMovingAverage'.format(tv.name[:-2]): tv
            for tv in te_vars
        }
        restore_init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
            te_ckpt, te_vars)

        data_dep_init_fn = _data_dep_init()

        def group_init_fn(session):
            # Teacher weights first, then the student's data-dependent init.
            restore_init_fn(session)
            data_dep_init_fn(session)

        session_config = tf.ConfigProto(allow_soft_placement=True)
        session_config.gpu_options.allow_growth = True

        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        slim.learning.train(
            train_tensor,
            logdir=logdir,
            number_of_steps=pwn.num_iters,
            summary_op=summary_op,
            global_step=global_step,
            log_every_n_steps=100,
            save_summaries_secs=600,
            save_interval_secs=3600,
            session_config=session_config,
            init_fn=group_init_fn)
def train(args):
    """Distil a parallel-WaveNet student with synchronous replica training.

    Mirrors the multi-clone distillation trainer but uses
    ``tf.train.SyncReplicasOptimizer`` across ``args.worker_replicas`` workers
    with ``args.ps_tasks`` parameter servers. Only variables whose name
    contains ``'iaf'`` are trained. The remaining variables are restored from
    the teacher's ExponentialMovingAverage shadows.

    Args:
        args: parsed command-line namespace. Fields read here: ``gpu_id``,
            ``log``, ``teacher_dir``, ``config``, ``train_path``, ``logdir``,
            ``total_batch_size``, ``worker_replicas``, ``ps_tasks``, ``task``
            and ``master``.

    Raises:
        RuntimeError: on a missing/invalid teacher directory, teacher config,
            teacher checkpoint, or student config.
        ValueError: if ``total_batch_size`` is not divisible by
            ``worker_replicas``.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    tf.logging.set_verbosity(args.log)

    ###
    # get teacher info.
    ###
    teacher_dir = utils.shell_path(args.teacher_dir)
    if not tf.gfile.IsDirectory(teacher_dir):
        raise RuntimeError(
            'Teacher dir {} is not a directory.'.format(teacher_dir))
    json_in_dir = glob.glob(os.path.join(teacher_dir, '*.json'))
    if len(json_in_dir) != 1:
        raise RuntimeError(
            'Expected exactly one teacher config json in {}, found {}.'.format(
                teacher_dir, len(json_in_dir)))
    te_json = json_in_dir[0]

    te_ckpt = tf.train.latest_checkpoint(teacher_dir)
    # latest_checkpoint returns None when no checkpoint is found.
    if te_ckpt is None or not tf.train.checkpoint_exists(te_ckpt):
        raise RuntimeError(
            'No teacher checkpoint found in {}.'.format(teacher_dir))

    with open(te_json, 'rt') as F:
        configs = json.load(F)
    te_hparams = Namespace(**configs)
    teacher = wavenet.Wavenet(te_hparams)

    ###
    # get student info.
    ###
    if args.config is None:
        raise RuntimeError('No config json specified.')
    with open(args.config, 'rt') as F:
        configs = json.load(F)
    st_hparams = Namespace(**configs)
    pwn = parallel_wavenet.ParallelWavenet(st_hparams, teacher, args.train_path)

    logdir = args.logdir
    tf.logging.info('Saving to {}'.format(logdir))

    os.makedirs(logdir, exist_ok=True)
    shutil.copy(args.config, logdir)

    with tf.Graph().as_default():
        total_batch_size = args.total_batch_size
        worker_replicas = args.worker_replicas
        if total_batch_size % worker_replicas != 0:
            raise ValueError(
                'total_batch_size ({}) must be divisible by worker_replicas '
                '({}).'.format(total_batch_size, worker_replicas))
        worker_batch_size = total_batch_size // worker_replicas

        # Run the Reader on the CPU
        cpu_device = "/job:localhost/replica:0/task:0/cpu:0"
        if args.ps_tasks:
            cpu_device = "/job:worker/cpu:0"

        with tf.device(cpu_device):
            inputs_dict = pwn.get_batch(worker_batch_size)

        with tf.device(
                tf.train.replica_device_setter(ps_tasks=args.ps_tasks,
                                               merge_devices=True)):
            global_step = tf.get_variable(
                "global_step", [], tf.int32,
                initializer=tf.constant_initializer(0), trainable=False)

            schedule = pwn.learning_rate_schedule
            lr = tf.constant(schedule[0])
            # Chain the conds in ascending boundary order so each stage keeps
            # the previous rate until `global_step` reaches its boundary;
            # insertion order would give a wrong schedule for an unsorted
            # config. Default arguments pin this iteration's values.
            for boundary in sorted(schedule):
                lr = tf.cond(tf.less(global_step, boundary),
                             lambda prev=lr: prev,
                             lambda rate=schedule[boundary]: tf.constant(rate))
            tf.summary.scalar("learning_rate", lr)

            # build the model graph
            ff_dict = pwn.feed_forward(inputs_dict, use_log_scale=False)
            ff_dict.update(inputs_dict)
            loss_dict = pwn.calculate_loss(ff_dict)
            loss = loss_dict['loss']
            tf.summary.scalar("train_loss", loss)
            tf.summary.scalar("kl_loss", loss_dict['kl_loss'])
            tf.summary.scalar("H_Ps", loss_dict['H_Ps'])
            tf.summary.scalar("H_Ps_Pt", loss_dict['H_Ps_Pt'])
            if 'power_loss' in loss_dict:
                tf.summary.scalar('power_loss', loss_dict['power_loss'])

            ###
            # restore teacher
            ###
            # Collected before the optimizer/EMA are built, so their slot and
            # shadow variables are not part of the restore set.
            # NOTE(review): slim matches `exclude` entries as scope-name
            # prefixes, unlike the substring test used for `st_vars` below;
            # this assumes student scopes start with 'iaf' -- confirm.
            te_vars = slim.get_variables_to_restore(
                exclude=['global_step', 'iaf'])
            # The teacher is restored from its EMA shadow variables; `[:-2]`
            # strips the ':0' output suffix from the tensor name.
            te_vars = {
                '{}/ExponentialMovingAverage'.format(tv.name[:-2]): tv
                for tv in te_vars
            }
            restore_init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
                te_ckpt, te_vars)

            ###
            # variables to train
            ###
            st_vars = [
                var for var in tf.trainable_variables() if 'iaf' in var.name]

            ema = tf.train.ExponentialMovingAverage(
                decay=0.9999, num_updates=global_step)
            opt = tf.train.SyncReplicasOptimizer(
                tf.train.AdamOptimizer(lr, epsilon=1e-8),
                worker_replicas,
                total_num_replicas=worker_replicas,
                variable_averages=ema,
                variables_to_average=st_vars)

            train_op = slim.learning.create_train_op(
                total_loss=loss,
                variables_to_train=st_vars,
                optimizer=opt,
                global_step=global_step,
                colocate_gradients_with_ops=True)

            session_config = tf.ConfigProto(allow_soft_placement=True)

            is_chief = (args.task == 0)
            local_init_op = (opt.chief_init_op if is_chief
                             else opt.local_step_init_op)

            slim.learning.train(
                train_op=train_op,
                logdir=logdir,
                is_chief=is_chief,
                master=args.master,
                number_of_steps=pwn.num_iters,
                global_step=global_step,
                log_every_n_steps=250,
                local_init_op=local_init_op,
                save_interval_secs=3600,
                sync_optimizer=opt,
                session_config=session_config,
                init_fn=restore_init_fn)
def train(args):
    """Train a WaveNet with synchronous replica training.

    Uses ``tf.train.SyncReplicasOptimizer`` across ``args.worker_replicas``
    workers and ``args.ps_tasks`` parameter servers. All trainable variables
    are optimized and tracked by an ExponentialMovingAverage.

    Args:
        args: parsed command-line namespace. Fields read here: ``gpu_id``,
            ``log``, ``config``, ``train_path``, ``logdir``,
            ``total_batch_size``, ``worker_replicas``, ``ps_tasks``, ``task``
            and ``master``.

    Raises:
        RuntimeError: if ``config`` is None.
        ValueError: if ``total_batch_size`` is not divisible by
            ``worker_replicas``.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    tf.logging.set_verbosity(args.log)
    if args.config is None:
        raise RuntimeError('No config json specified.')
    with open(args.config, 'rt') as F:
        configs = json.load(F)
    hparams = Namespace(**configs)
    wn = wavenet.Wavenet(hparams, args.train_path)

    logdir = args.logdir
    tf.logging.info('Saving to {}'.format(logdir))

    os.makedirs(logdir, exist_ok=True)
    shutil.copy(args.config, logdir)

    with tf.Graph().as_default():
        total_batch_size = args.total_batch_size
        worker_replicas = args.worker_replicas
        if total_batch_size % worker_replicas != 0:
            raise ValueError(
                'total_batch_size ({}) must be divisible by worker_replicas '
                '({}).'.format(total_batch_size, worker_replicas))
        worker_batch_size = total_batch_size // worker_replicas

        # Run the Reader on the CPU
        cpu_device = "/job:localhost/replica:0/task:0/cpu:0"
        if args.ps_tasks:
            cpu_device = "/job:worker/cpu:0"

        with tf.device(cpu_device):
            inputs_dict = wn.get_batch(worker_batch_size)

        with tf.device(
                tf.train.replica_device_setter(ps_tasks=args.ps_tasks,
                                               merge_devices=True)):
            global_step = tf.get_variable(
                "global_step", [], tf.int32,
                initializer=tf.constant_initializer(0), trainable=False)

            schedule = wn.learning_rate_schedule
            lr = tf.constant(schedule[0])
            # Chain the conds in ascending boundary order so each stage keeps
            # the previous rate until `global_step` reaches its boundary;
            # insertion order would give a wrong schedule for an unsorted
            # config. Default arguments pin this iteration's values.
            for boundary in sorted(schedule):
                lr = tf.cond(tf.less(global_step, boundary),
                             lambda prev=lr: prev,
                             lambda rate=schedule[boundary]: tf.constant(rate))
            tf.summary.scalar("learning_rate", lr)

            # build the model graph
            # NOTE(review): unlike the multi-clone WaveNet trainers, this path
            # does not call `wn.encode_signal` before `feed_forward`; confirm
            # this matches the Wavenet version in use.
            ff_dict = wn.feed_forward(inputs_dict)
            loss_dict = wn.calculate_loss(ff_dict)
            loss = loss_dict['loss']
            tf.summary.scalar("train_loss", loss)

            ema = tf.train.ExponentialMovingAverage(
                decay=0.9999, num_updates=global_step)
            opt = tf.train.SyncReplicasOptimizer(
                tf.train.AdamOptimizer(lr, epsilon=1e-8),
                worker_replicas,
                total_num_replicas=worker_replicas,
                variable_averages=ema,
                variables_to_average=tf.trainable_variables())

            train_op = slim.learning.create_train_op(
                total_loss=loss,
                optimizer=opt,
                global_step=global_step,
                colocate_gradients_with_ops=True)

            session_config = tf.ConfigProto(allow_soft_placement=True)

            is_chief = (args.task == 0)
            local_init_op = (opt.chief_init_op if is_chief
                             else opt.local_step_init_op)

            slim.learning.train(
                train_op=train_op,
                logdir=logdir,
                is_chief=is_chief,
                master=args.master,
                number_of_steps=wn.num_iters,
                global_step=global_step,
                log_every_n_steps=250,
                local_init_op=local_init_op,
                save_interval_secs=3600,
                sync_optimizer=opt,
                session_config=session_config,
            )
def train(args):
    """Distil a parallel-WaveNet student from a WaveNet teacher (multi-clone).

    The teacher config (exactly one ``*.json``) and latest checkpoint come
    from ``args.teacher_dir``; the teacher is flagged ``use_as_teacher``.
    With ``args.log_root`` a new run directory is created and ``args.config``
    is copied into it. Otherwise the run continues in ``args.logdir`` using
    the ``*.json`` stored there. Only ``'iaf'`` variables (further filtered
    by ``pwn.filter_update_variables``) are trained.

    Args:
        args: parsed command-line namespace. Fields read here: ``gpu_id``,
            ``log``, ``teacher_dir``, ``log_root``, ``config``, ``logdir``,
            ``train_path`` and ``total_batch_size``.

    Raises:
        RuntimeError: on a missing/invalid teacher directory, teacher config,
            teacher checkpoint, or student config.
        ValueError: if ``total_batch_size`` is not divisible by the number of
            model clones.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    tf.logging.set_verbosity(args.log)
    clone_on_cpu = args.gpu_id == ''
    num_clones = len(args.gpu_id.split(','))

    ###
    # get teacher info.
    ###
    teacher_dir = utils.shell_path(args.teacher_dir)
    if not tf.gfile.IsDirectory(teacher_dir):
        raise RuntimeError(
            'Teacher dir {} is not a directory.'.format(teacher_dir))
    json_in_dir = glob.glob(os.path.join(teacher_dir, '*.json'))
    if len(json_in_dir) != 1:
        raise RuntimeError(
            'Expected exactly one teacher config json in {}, found {}.'.format(
                teacher_dir, len(json_in_dir)))
    te_json = json_in_dir[0]

    te_ckpt = tf.train.latest_checkpoint(teacher_dir)
    # latest_checkpoint returns None when no checkpoint is found.
    if te_ckpt is None or not tf.train.checkpoint_exists(te_ckpt):
        raise RuntimeError(
            'No teacher checkpoint found in {}.'.format(teacher_dir))

    with open(te_json, 'rt') as F:
        configs = json.load(F)
    te_hparams = Namespace(**configs)
    te_hparams.use_as_teacher = True
    teacher = wavenet.Wavenet(te_hparams)

    ###
    # get student info.
    ###
    if args.log_root:
        if args.config is None:
            raise RuntimeError('No config json specified.')
        config_json = args.config
        with open(config_json, 'rt') as F:
            configs = json.load(F)
        st_hparams = Namespace(**configs)
        logdir_name = config_str.get_config_time_str(
            st_hparams, 'parallel_wavenet', EXP_TAG)
        logdir = os.path.join(args.log_root, logdir_name)
        os.makedirs(logdir, exist_ok=True)
        shutil.copy(config_json, logdir)
    else:
        logdir = args.logdir
        # Sorted so the choice is deterministic if several json files exist.
        config_jsons = sorted(glob.glob(os.path.join(logdir, '*.json')))
        if not config_jsons:
            raise RuntimeError(
                'No config json found in logdir {}.'.format(logdir))
        config_json = config_jsons[0]
        with open(config_json, 'rt') as F:
            configs = json.load(F)
        st_hparams = Namespace(**configs)

    enhance_log.add_log_file(logdir)
    if not args.log_root:
        tf.logging.info('Continue running\n\n')
    tf.logging.info('using config form {}'.format(config_json))
    tf.logging.info('Saving to {}'.format(logdir))

    pwn = parallel_wavenet.ParallelWavenet(st_hparams, teacher, args.train_path)

    pwn_config_str = enhance_log.instance_attr_to_str(pwn)
    teacher_config_str = enhance_log.instance_attr_to_str(teacher)
    tf.logging.info('\n' + pwn_config_str)
    tf.logging.info(
        '\nteacher form {}\n'.format(teacher_dir) + teacher_config_str)

    def _data_dep_init():
        """Build the data-dependent init and return its callback."""
        # Load one batch eagerly and feed it through a placeholder; init_fn
        # runs before the queue runners are started.
        inputs_val = reader.get_init_batch(
            pwn.train_path, batch_size=args.total_batch_size,
            seq_len=pwn.wave_length)
        mel_data = inputs_val['mel']

        _inputs_dict = {
            'mel': tf.placeholder(dtype=tf.float32, shape=mel_data.shape)
        }

        init_ff_dict = pwn.feed_forward(_inputs_dict, init=True)

        def callback(session):
            tf.logging.info('Calculate initial statistics.')
            init_out = session.run(
                init_ff_dict, feed_dict={_inputs_dict['mel']: mel_data})
            new_x = init_out['x']
            mean = init_out['mean_tot']
            scale = init_out['scale_tot']
            _init_logging(new_x, 'new_x')
            _init_logging(mean, 'mean')
            _init_logging(scale, 'scale')
            tf.logging.info('Done Calculate initial statistics.')

        return callback

    def _trans_conv_init_from_teacher(te_vars, st_vars):
        """
        Initialize the separate iaf transposed convolution stacks or shared
        transposed convolution stack with the teacher's transposed convolution
        stack.

        Each teacher variable whose name contains ``pwn.upsample_conv_name``
        is assigned to every student variable whose name ends with that
        teacher variable's name. A teacher variable with no student
        counterpart is reported with a warning instead of being silently
        skipped.

        Returns:
            A callback ``fn(session)`` that runs all the assignments.
        """
        assign_ops = []
        for te_tcv in te_vars:
            if pwn.upsample_conv_name not in te_tcv.name:
                continue
            matches = [st_tcv for st_tcv in st_vars
                       if st_tcv.name.endswith(te_tcv.name)]
            if not matches:
                tf.logging.warning(
                    'No student variable matches teacher transposed '
                    'convolution variable {}'.format(te_tcv.name))
            assign_ops.extend(tf.assign(st_tcv, te_tcv) for st_tcv in matches)

        def assign_fn(session):
            tf.logging.info('Load transposed convolution weights form teacher')
            session.run(assign_ops)
            tf.logging.info(
                'Done load transposed convolution weights form teacher')

        return assign_fn

    def _model_fn(_inputs_dict):
        """Per-clone model: student forward pass, loss and loss summaries."""
        ff_dict = pwn.feed_forward(_inputs_dict)
        ff_dict.update(_inputs_dict)
        loss_dict = pwn.calculate_loss(ff_dict)
        loss = loss_dict['loss']
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)

        for loss_key, loss_val in loss_dict.items():
            tf.summary.scalar(loss_key, loss_val)

    with tf.Graph().as_default():
        total_batch_size = args.total_batch_size
        if total_batch_size % num_clones != 0:
            raise ValueError(
                'total_batch_size ({}) must be divisible by the number of '
                'clones ({}).'.format(total_batch_size, num_clones))
        clone_batch_size = total_batch_size // num_clones

        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones, clone_on_cpu=clone_on_cpu, num_ps_tasks=0,
            worker_job_name='localhost', ps_job_name='localhost')

        with tf.device(deploy_config.inputs_device()):
            inputs_dict = pwn.get_batch(clone_batch_size)
            # A mel batch that does not correspond to the wave batch. If the
            # contrastive loss is not used, this input op is never evaluated.
            inputs_dict['mel_rand'] = pwn.get_batch(clone_batch_size)['mel']

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        clones = model_deploy.create_clones(
            deploy_config, _model_fn, [inputs_dict])
        first_clone_scope = deploy_config.clone_scope(0)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
        summaries.update(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        with tf.device(deploy_config.variables_device()):
            global_step = tf.get_variable(
                "global_step", [], tf.int32,
                initializer=tf.constant_initializer(0), trainable=False)

        ###
        # variables to train
        ###
        st_var_list = [
            var for var in tf.trainable_variables() if 'iaf' in var.name]
        filtered_st_var_list = pwn.filter_update_variables(st_var_list)

        with tf.device(deploy_config.optimizer_device()):
            schedule = pwn.learning_rate_schedule
            lr = tf.constant(schedule[0])
            # Chain the conds in ascending boundary order so each stage keeps
            # the previous rate until `global_step` reaches its boundary;
            # insertion order would give a wrong schedule for an unsorted
            # config. Default arguments pin this iteration's values.
            for boundary in sorted(schedule):
                lr = tf.cond(tf.less(global_step, boundary),
                             lambda prev=lr: prev,
                             lambda rate=schedule[boundary]: tf.constant(rate))
            summaries.add(tf.summary.scalar("learning_rate", lr))

            optimizer = tf.train.AdamOptimizer(lr, epsilon=1e-8)
            ema = tf.train.ExponentialMovingAverage(
                decay=0.9999, num_updates=global_step)

            loss, clone_grads_vars = model_deploy.optimize_clones(
                clones, optimizer, var_list=filtered_st_var_list)
            update_ops.append(
                optimizer.apply_gradients(clone_grads_vars,
                                          global_step=global_step))
            update_ops.append(ema.apply(filtered_st_var_list))

            summaries.add(tf.summary.scalar("train_loss", loss))

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(loss, name='train_op')

        ###
        # restore teacher and other init ops
        ###
        te_var_list = [
            var for var in tf.trainable_variables() if 'iaf' not in var.name]
        # The teacher is restored from its EMA shadow variables; `[:-2]` strips
        # the ':0' output suffix from the tensor name.
        te_var_shadow_dict = {
            '{}/ExponentialMovingAverage'.format(tv.name[:-2]): tv
            for tv in te_var_list
        }
        restore_init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
            te_ckpt, te_var_shadow_dict)
        data_dep_init_fn = _data_dep_init()
        share_trans_conv_init_fn = _trans_conv_init_from_teacher(
            te_var_list, st_var_list)

        def group_init_fn(session):
            # the order of the init functions is important, don't change it.
            restore_init_fn(session)
            data_dep_init_fn(session)
            share_trans_conv_init_fn(session)

        session_config = tf.ConfigProto(allow_soft_placement=True)
        session_config.gpu_options.allow_growth = True

        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        slim.learning.train(
            train_tensor,
            logdir=logdir,
            number_of_steps=pwn.num_iters,
            summary_op=summary_op,
            global_step=global_step,
            log_every_n_steps=100,
            save_summaries_secs=600,
            save_interval_secs=3600,
            session_config=session_config,
            init_fn=group_init_fn)
def train(args):
    """Train a WaveNet with TF-Slim, one model clone per listed GPU.

    The config is read from ``args.config`` and copied into ``args.logdir``.
    All trainable variables are optimized and tracked by an
    ExponentialMovingAverage.

    Args:
        args: parsed command-line namespace. Fields read here: ``gpu_id``,
            ``log``, ``config``, ``train_path``, ``logdir`` and
            ``total_batch_size``.

    Raises:
        RuntimeError: if ``config`` is None.
        ValueError: if ``total_batch_size`` is not divisible by the number of
            model clones.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    tf.logging.set_verbosity(args.log)
    clone_on_cpu = args.gpu_id == ''
    # One clone per comma-separated GPU id; an empty id gives a single clone.
    num_clones = len(args.gpu_id.split(','))

    if args.config is None:
        raise RuntimeError('No config json specified.')
    with open(args.config, 'rt') as F:
        configs = json.load(F)
    hparams = Namespace(**configs)
    wn = wavenet.Wavenet(hparams, args.train_path)

    def _model_fn(_inputs_dict):
        """Per-clone model: encode, run the net, register the loss."""
        encode_dict = wn.encode_signal(_inputs_dict)
        _inputs_dict.update(encode_dict)
        ff_dict = wn.feed_forward(_inputs_dict)
        ff_dict.update(encode_dict)
        loss_dict = wn.calculate_loss(ff_dict)
        loss = loss_dict['loss']
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)

    logdir = args.logdir
    tf.logging.info('Saving to {}'.format(logdir))

    os.makedirs(logdir, exist_ok=True)
    shutil.copy(args.config, logdir)

    with tf.Graph().as_default():
        total_batch_size = args.total_batch_size
        if total_batch_size % num_clones != 0:
            raise ValueError(
                'total_batch_size ({}) must be divisible by the number of '
                'clones ({}).'.format(total_batch_size, num_clones))
        clone_batch_size = total_batch_size // num_clones

        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones, clone_on_cpu=clone_on_cpu, num_ps_tasks=0,
            worker_job_name='localhost', ps_job_name='localhost')

        with tf.device(deploy_config.inputs_device()):
            inputs_dict = wn.get_batch(clone_batch_size)

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        clones = model_deploy.create_clones(
            deploy_config, _model_fn, [inputs_dict])
        first_clone_scope = deploy_config.clone_scope(0)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
        summaries.update(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        with tf.device(deploy_config.variables_device()):
            global_step = tf.get_variable(
                "global_step", [], tf.int32,
                initializer=tf.constant_initializer(0), trainable=False)

        with tf.device(deploy_config.optimizer_device()):
            schedule = wn.learning_rate_schedule
            lr = tf.constant(schedule[0])
            # Chain the conds in ascending boundary order so each stage keeps
            # the previous rate until `global_step` reaches its boundary;
            # insertion order would give a wrong schedule for an unsorted
            # config. Default arguments pin this iteration's values.
            for boundary in sorted(schedule):
                lr = tf.cond(tf.less(global_step, boundary),
                             lambda prev=lr: prev,
                             lambda rate=schedule[boundary]: tf.constant(rate))
            summaries.add(tf.summary.scalar("learning_rate", lr))

            optimizer = tf.train.AdamOptimizer(lr, epsilon=1e-8)
            ema = tf.train.ExponentialMovingAverage(
                decay=0.9999, num_updates=global_step)

            loss, clone_grads_vars = model_deploy.optimize_clones(
                clones, optimizer, var_list=tf.trainable_variables())
            update_ops.append(
                optimizer.apply_gradients(clone_grads_vars,
                                          global_step=global_step))
            update_ops.append(ema.apply(tf.trainable_variables()))

            summaries.add(tf.summary.scalar("train_loss", loss))

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(loss, name='train_op')

        session_config = tf.ConfigProto(allow_soft_placement=True)
        session_config.gpu_options.allow_growth = True

        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        slim.learning.train(
            train_tensor,
            logdir=logdir,
            number_of_steps=wn.num_iters,
            summary_op=summary_op,
            global_step=global_step,
            log_every_n_steps=100,
            save_summaries_secs=600,
            save_interval_secs=3600,
            session_config=session_config,
        )