def train(infer_z, noisy_y, C, img_label):
    """Train CIFAR-10 while re-sampling latent labels and updating counts ``C``.

    For every batch the graph samples one label per example from the
    normalised product ``softmax(logits) * T[:, noisy_label]``, trains the
    classifier on the sampled labels, and the Python loop then moves each
    example's count in ``C`` from its previous inferred label to the newly
    sampled one.  The transition matrix fed to the graph is
    ``(C + alpha)`` row-normalised, refreshed every ``freq_trans`` steps;
    during the first ``warming_up_step`` steps a matrix loaded from
    ``T_<noise_ratio>.pkl`` is fed instead.

    Args:
        infer_z: Mapping from example index to the currently inferred label
            (``.keys()``, ``.values()`` and ``.copy()`` are used).  Updated in
            place.
        noisy_y: Indexable by example index; must agree with the labels
            produced by the input pipeline (checked with ``assert``).
        C: Count matrix indexed as ``C[inferred_label][noisy_label]``;
            presumably a NumPy array (``.copy()`` and scalar broadcasting are
            used).  Updated in place.
        img_label: Indexable by example index; compared against ``infer_z`` to
            compute the accuracy of the inferred labels.  Presumably the clean
            labels -- confirm against the caller.

    Raises:
        ValueError: If ``FLAGS.init_dir`` contains no checkpoint to restore.

    Side effects:
        Reads ``T_<noise_ratio>.pkl`` and writes the ``varC_*`` pickle files
        into the current working directory.
    """
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending
        # up on GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            indices, images, labels, T_tru, T_mask_tru = (
                cifar10.noisy_distorted_inputs(
                    return_T_flag=True, noise_ratio=FLAGS.noise_ratio))
            # rank 2 --> rank 1, i.e., (batch_size, 1) --> (batch_size,)
            indices = indices[:, 0]

        # Build a Graph that computes the logits predictions from the
        # inference model.
        is_training = tf.placeholder(tf.bool, shape=(), name='bn_flag')
        logits = cifar10.inference(images, training=is_training)
        preds = tf.nn.softmax(logits)

        # Approximate Gibbs sampling: weight the classifier posterior by the
        # transition column selected by each example's noisy label.
        T = tf.placeholder(
            tf.float32,
            shape=[cifar10.NUM_CLASSES, cifar10.NUM_CLASSES],
            name='transition')
        if FLAGS.groudtruth:
            unnorm_probs = preds * tf.gather(
                tf.transpose(T_tru, [1, 0]), labels)
        else:
            unnorm_probs = preds * tf.gather(tf.transpose(T, [1, 0]), labels)
        probs = unnorm_probs / tf.reduce_sum(
            unnorm_probs, axis=1, keepdims=True)
        sampler = OneHotCategorical(probs=probs)
        labels_ = tf.stop_gradient(tf.argmax(sampler.sample(), axis=1))
        loss = cifar10.loss(logits, labels_)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        # acc_op is (accuracy, update_op): running acc_op yields the
        # cumulative accuracy; run acc_op[0] alone to read it without
        # updating.
        acc_op = tf.metrics.accuracy(labels, tf.argmax(logits, axis=1))
        tf.summary.scalar('training accuracy', acc_op[0])

        # Build the scaffold that restores the pre-trained variables
        # (all trainable variables plus the batch-norm moving statistics).
        variables_to_restore = []
        variables_to_restore += tf.trainable_variables()
        variables_to_restore += [
            g for g in tf.global_variables()
            if 'moving_mean' in g.name or 'moving_variance' in g.name
        ]
        for var in variables_to_restore:
            print(var.name)

        ckpt = tf.train.get_checkpoint_state(FLAGS.init_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            # get_checkpoint_state returns None when nothing is found; fail
            # with a clear message instead of an AttributeError.
            raise ValueError(
                'No checkpoint found in init_dir: %s' % FLAGS.init_dir)
        init_assign_op, init_feed_dict = (
            tf.contrib.framework.assign_from_checkpoint(
                ckpt.model_checkpoint_path, variables_to_restore))

        def InitAssignFn(scaffold, sess):
            sess.run(init_assign_op, init_feed_dict)

        scaffold = tf.train.Scaffold(
            saver=tf.train.Saver(), init_fn=InitAssignFn)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                # Asks for the first tensor in the 'losses' collection.
                return tf.train.SessionRunArgs(
                    tf.get_collection('losses')[0])

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = (
                        FLAGS.log_frequency * FLAGS.batch_size / duration)
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step,
                                        loss_value, examples_per_sec,
                                        sec_per_batch))

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                scaffold=scaffold,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                save_checkpoint_secs=60,
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement,
                    gpu_options=gpu_options)) as mon_sess:
            # Initial parameters.
            alpha = 1.0
            C_init = C.copy()
            trans_init = (C + alpha) / np.sum(
                C + alpha, axis=1, keepdims=True)

            # Running settings.
            warming_up_step = 20000
            step = 0
            freq_trans = 200

            # Transition matrix used while warming up.  Pickle data is binary,
            # so the file must be opened in 'rb' mode.
            with open('T_%.2f.pkl' % FLAGS.noise_ratio, 'rb') as f:
                data = pickle.load(f)
            trans_warming = data[2]

            # Records collected during the run.
            exemplars = []
            label_trace_exemplars = []
            infer_z_probs = dict()
            trans_before_after_trace = []

            fetches = [
                train_op, acc_op, global_step, indices, labels, labels_, probs
            ]
            while not mon_sess.should_stop():
                if step % freq_trans == 0:
                    # Refresh the fed transition matrix every freq_trans steps.
                    trans = (C + alpha) / np.sum(
                        C + alpha, axis=1, keepdims=True)

                fed_T = trans_warming if step < warming_up_step else trans
                res = mon_sess.run(
                    fetches, feed_dict={is_training: True, T: fed_T})
                new_step = res[2]
                batch_indices = res[3]
                batch_noisy = res[4]
                batch_sampled = res[5]
                batch_probs = res[6]

                trans_before = (C + alpha) / np.sum(
                    C + alpha, axis=1, keepdims=True)
                C_before = C.copy()

                # Move every example of the batch from its old inferred label
                # to the freshly sampled one.
                for i in range(batch_indices.shape[0]):
                    ind = batch_indices[i]
                    assert noisy_y[ind] == batch_noisy[i]
                    C[infer_z[ind]][noisy_y[ind]] -= 1
                    assert C[infer_z[ind]][noisy_y[ind]] >= 0
                    infer_z[ind] = batch_sampled[i]
                    infer_z_probs[ind] = batch_probs[i]
                    C[infer_z[ind]][noisy_y[ind]] += 1

                trans_after = (C + alpha) / np.sum(
                    C + alpha, axis=1, keepdims=True)
                C_after = C.copy()

                # Observed change of the transition matrix and the bound
                # computed from the relative change of the counts.
                trans_gap = np.sum(np.absolute(trans_after - trans_before))
                rou = np.sum(C_after - C_before, axis=-1) / np.sum(
                    C_before + alpha, axis=-1)
                rou_ = np.sum(
                    np.absolute(C_after - C_before), axis=-1) / np.sum(
                        C_before + alpha, axis=-1)
                trans_bound = np.sum((np.absolute(rou) + rou_) / (1 + rou))
                # Recorded against the step value from before this run call.
                trans_before_after_trace.append(
                    [step, trans_gap, trans_bound])

                step = new_step

                if step % 1000 == 0:
                    print('Counting matrix\n', C)
                    print('Counting matrix\n', C_init)
                    print('Transition matrix\n', trans)
                    print('Transition matrix\n', trans_init)

                if step % 5000 == 0:
                    # Materialise keys/values as lists: dict views (Python 3)
                    # cannot be pickled.
                    exemplars.append([
                        list(infer_z.keys()),
                        list(infer_z.values()),
                        C.copy()
                    ])

                if step % FLAGS.max_steps_per_epoch == 0:
                    r_n = 0
                    all_n = 0
                    for key in infer_z.keys():
                        if infer_z[key] == img_label[key]:
                            r_n += 1
                        all_n += 1
                    # float() keeps the ratio fractional under Python 2
                    # integer division.
                    acc = float(r_n) / all_n
                    label_trace_exemplars.append(
                        [infer_z.copy(), infer_z_probs.copy(), acc])

            # Persist the collected records (binary mode for pickle).
            if not FLAGS.groudtruth:
                learnt_path = 'varC_learnt_%.2f.pkl' % FLAGS.noise_ratio
            else:
                learnt_path = 'varC_learnt_%.2f_tru.pkl' % FLAGS.noise_ratio
            with open(learnt_path, 'wb') as w:
                pickle.dump(exemplars, w)

            if FLAGS.labeltrace:
                with open('varC_label_trace_%.2f.pkl' % FLAGS.noise_ratio,
                          'wb') as w:
                    pickle.dump([label_trace_exemplars, img_label], w)

            with open('varC_transvar_trace_%.2f.pkl' % FLAGS.noise_ratio,
                      'wb') as w:
                pickle.dump(trans_before_after_trace, w)
def train(T_est, T_inv_est):
    """Train CIFAR-10 with the forward-corrected loss on noisy labels.

    The loss is ``loss_forward(logits, labels, T_tru)``, where ``T_tru`` is the
    transition matrix returned by the noisy input pipeline.  Model variables
    are initialised from the latest checkpoint in ``FLAGS.init_dir`` (using
    the moving-average variable mapping), excluding ``logits_T`` and
    ``global_step``.

    Args:
        T_est: Estimated transition matrix.  Wrapped in a TF constant but not
            used by the currently selected loss.
        T_inv_est: Estimated inverse transition matrix.  Wrapped in a TF
            constant but not used by the currently selected loss.

    Raises:
        ValueError: If ``FLAGS.init_dir`` contains no checkpoint to restore.
    """
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Only needed by the alternative losses
        # loss_forward(logits, labels, T_est) and
        # loss_backward(logits, labels, T_inv_est).
        T_est = tf.constant(T_est)
        T_inv_est = tf.constant(T_inv_est)

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending
        # up on GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            images, labels, T_tru, T_mask_tru = (
                cifar10.noisy_distorted_inputs(return_T_flag=True))

        # Build a Graph that computes the logits predictions from the
        # inference model.
        dropout = tf.constant(0.75)
        logits = cifar10.inference(images, dropout, dropout_flag=True)

        # Calculate loss (forward correction with the pipeline's matrix).
        loss = loss_forward(logits, labels, T_tru)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op, variable_averages = cifar10.train(
            loss, global_step, return_variable_averages=True)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = (
                        FLAGS.log_frequency * FLAGS.batch_size / duration)
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step,
                                        loss_value, examples_per_sec,
                                        sec_per_batch))

        # Build the scaffold that restores the pre-trained variables.
        ckpt = tf.train.get_checkpoint_state(FLAGS.init_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            # get_checkpoint_state returns None when nothing is found; fail
            # with a clear message instead of an AttributeError.
            raise ValueError(
                'No checkpoint found in init_dir: %s' % FLAGS.init_dir)

        # Build a filtered copy rather than deleting entries while iterating
        # over the dict's keys, which raises RuntimeError on Python 3.
        variables_to_restore = {
            var_name: var
            for var_name, var in
            variable_averages.variables_to_restore().items()
            if 'logits_T' not in var_name and 'global_step' not in var_name
        }
        init_assign_op, init_feed_dict = (
            tf.contrib.framework.assign_from_checkpoint(
                ckpt.model_checkpoint_path, variables_to_restore))

        def InitAssignFn(scaffold, sess):
            sess.run(init_assign_op, init_feed_dict)

        scaffold = tf.train.Scaffold(
            saver=tf.train.Saver(), init_fn=InitAssignFn)

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.45)
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                scaffold=scaffold,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                save_checkpoint_secs=60,
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement,
                    gpu_options=gpu_options)) as mon_sess:
            while not mon_sess.should_stop():
                res = mon_sess.run([train_op, global_step, T_tru, T_mask_tru])
                # Periodically show the pipeline's transition matrix and mask.
                if res[1] % 1000 == 0:
                    print('Disturbing matrix\n', res[2])
                    print('Masked structure\n', res[3])
def train(T_fixed, T_init):
    """Train CIFAR-10 with a label-adaption layer appended to the classifier.

    The classifier's softmax output is multiplied by an adaption matrix before
    the loss is computed.  For the first ``warming_up_step`` steps the fixed
    matrix ``T_fixed`` is used; afterwards the learnable matrix
    ``softmax(logits_T)`` (initialised from ``log(T_init + 1e-8)``) takes
    over.  The function records the adaption matrix every 5000 steps and, after
    warm-up, the L1 distance between the current learnable matrix and the
    baseline captured at the end of warm-up.

    Args:
        T_fixed: Fixed adaption matrix, used while warming up.  Presumably
            ``NUM_CLASSES x NUM_CLASSES`` -- confirm against the caller.
        T_init: Initial value for the learnable adaption matrix; must support
            ``T_init + 1e-8`` and match the ``logits_T`` shape
            ``[NUM_CLASSES, NUM_CLASSES]``.

    Raises:
        ValueError: If ``FLAGS.init_dir`` contains no checkpoint to restore.

    Side effects:
        Writes ``varT_learnt_<noise_ratio>.pkl`` and
        ``varT_transvar_trace_<noise_ratio>.pkl`` to the working directory.
    """
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending
        # up on GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            indices, images, labels, T_tru, T_mask_tru = (
                cifar10.noisy_distorted_inputs(
                    return_T_flag=True, noise_ratio=FLAGS.noise_ratio))
            indices = indices[:, 0]

        # Build a Graph that computes the logits predictions from the
        # inference model.
        is_training = tf.placeholder(tf.bool, shape=(), name='bn_flag')
        logits = cifar10.inference(images, training=is_training)
        preds = tf.nn.softmax(logits)

        # Fixed adaption layer.
        fixed_adaption_layer = tf.cast(tf.constant(T_fixed), tf.float32)

        # Learnable adaption layer.
        logits_T = tf.get_variable(
            'logits_T',
            shape=[cifar10.NUM_CLASSES, cifar10.NUM_CLASSES],
            initializer=tf.constant_initializer(np.log(T_init + 1e-8)))
        adaption_layer = tf.nn.softmax(logits_T)

        # Label adaption: is_use selects the fixed matrix (True) or the
        # learnable one (False).
        is_use = tf.placeholder(tf.bool, shape=(), name='warming_up_flag')
        adaption = tf.cond(is_use, lambda: fixed_adaption_layer,
                           lambda: adaption_layer)
        preds_aug = tf.clip_by_value(
            tf.matmul(preds, adaption), 1e-8, 1.0 - 1e-8)
        logits_aug = tf.log(preds_aug)

        # Calculate loss.
        loss = cifar10.loss(logits_aug, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        # acc_op is (accuracy, update_op): running acc_op yields the
        # cumulative accuracy; run acc_op[0] alone to read it without
        # updating.
        acc_op = tf.metrics.accuracy(labels, tf.argmax(logits, axis=1))
        tf.summary.scalar('training accuracy', acc_op[0])

        # Build the scaffold that restores the pre-trained variables: every
        # trainable variable except logits_T, plus batch-norm moving stats.
        variables_to_restore = []
        variables_to_restore += [
            var for var in tf.trainable_variables()
            if 'logits_T' not in var.name
        ]
        variables_to_restore += [
            g for g in tf.global_variables()
            if 'moving_mean' in g.name or 'moving_variance' in g.name
        ]
        for var in variables_to_restore:
            print(var.name)

        ckpt = tf.train.get_checkpoint_state(FLAGS.init_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            # get_checkpoint_state returns None when nothing is found; fail
            # with a clear message instead of an AttributeError.
            raise ValueError(
                'No checkpoint found in init_dir: %s' % FLAGS.init_dir)
        init_assign_op, init_feed_dict = (
            tf.contrib.framework.assign_from_checkpoint(
                ckpt.model_checkpoint_path, variables_to_restore))

        def InitAssignFn(scaffold, sess):
            sess.run(init_assign_op, init_feed_dict)

        scaffold = tf.train.Scaffold(
            saver=tf.train.Saver(), init_fn=InitAssignFn)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                # Asks for the first tensor in the 'losses' collection.
                return tf.train.SessionRunArgs(
                    tf.get_collection('losses')[0])

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = (
                        FLAGS.log_frequency * FLAGS.batch_size / duration)
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step,
                                        loss_value, examples_per_sec,
                                        sec_per_batch))

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                scaffold=scaffold,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                save_checkpoint_secs=60,
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement,
                    gpu_options=gpu_options)) as mon_sess:
            warming_up_step = 32000
            step = 0
            varT_rec = []
            varT_trans_trace = []
            # Baseline matrix for the drift trace; normally captured when the
            # global step reaches warming_up_step.
            trans_before = None

            while not mon_sess.should_stop():
                warming_up = step < warming_up_step
                # Fetch whichever adaption matrix is active for this step.
                active_layer = (
                    fixed_adaption_layer if warming_up else adaption_layer)
                res = mon_sess.run(
                    [
                        train_op, acc_op, global_step, active_layer, T_tru,
                        T_mask_tru
                    ],
                    feed_dict={is_training: True, is_use: warming_up})

                step = res[2]
                if step % 5000 == 0:
                    varT_rec.append(res[3])

                if step == warming_up_step:
                    trans_before = res[3].copy()
                if step > warming_up_step:
                    if trans_before is None:
                        # The run resumed past warm-up, so the baseline was
                        # never fetched.  At the end of warm-up res[3] is the
                        # value of fixed_adaption_layer, i.e. T_fixed cast to
                        # float32, so rebuild it from T_fixed instead of
                        # hitting a NameError.
                        trans_before = np.asarray(T_fixed, dtype=np.float32)
                    trans_after = res[3].copy()
                    trans_gap = np.sum(
                        np.absolute(trans_before - trans_after))
                    varT_trans_trace.append([step, trans_gap])

            # Pickle data is binary: open the output files in 'wb' mode.
            with open('varT_learnt_%.2f.pkl' % FLAGS.noise_ratio, 'wb') as w:
                pickle.dump(varT_rec, w)
            with open('varT_transvar_trace_%.2f.pkl' % FLAGS.noise_ratio,
                      'wb') as w:
                pickle.dump(varT_trans_trace, w)
def train():
    """Train the CIFAR-10 classifier on the noisy input pipeline.

    Builds the graph (noisy inputs, inference with a dropout placeholder,
    ``cifar10.loss`` and ``cifar10.train``), then runs training steps inside a
    ``MonitoredTrainingSession`` until it signals a stop.  Every time the
    global step is a multiple of 1000 the transition matrix and mask returned
    by the input pipeline are printed.
    """
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Keep the input pipeline on CPU:0 so its ops do not end up on the GPU
        # and slow training down.
        with tf.device('/cpu:0'):
            images, labels, T, T_mask = cifar10.noisy_distorted_inputs(
                return_T_flag=True)

        # Inference model; the dropout value is supplied at run time.
        dropout = tf.placeholder(tf.float32, name='dropout_rate')
        logits = cifar10.inference(images, dropout, dropout_flag=True)

        # Loss and the op that applies one optimisation step.
        loss = cifar10.loss(logits, labels)
        train_op = cifar10.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                # Asks for loss value.
                return tf.train.SessionRunArgs(loss)

            def after_run(self, run_context, run_values):
                # Only report every FLAGS.log_frequency steps.
                if self._step % FLAGS.log_frequency != 0:
                    return

                now = time.time()
                duration = now - self._start_time
                self._start_time = now

                loss_value = run_values.results
                examples_per_sec = (
                    FLAGS.log_frequency * FLAGS.batch_size / duration)
                sec_per_batch = float(duration / FLAGS.log_frequency)

                report = (
                    '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)') % (datetime.now(), self._step, loss_value,
                                     examples_per_sec, sec_per_batch)
                print(report)

        session_config = tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.45))
        session_hooks = [
            tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
            tf.train.NanTensorHook(loss),
            _LoggerHook(),
        ]
        fetches = [train_op, global_step, T, T_mask]

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=session_hooks,
                save_checkpoint_secs=60,
                config=session_config) as mon_sess:
            while not mon_sess.should_stop():
                _, step_value, noise_matrix, noise_mask = mon_sess.run(
                    fetches, feed_dict={dropout: 0.75})
                if step_value % 1000 == 0:
                    print('Disturbing matrix\n', noise_matrix)
                    print('Masked structure\n', noise_mask)
def train():
    """Train CIFAR-10 with cross-entropy mixed with a prediction-entropy term.

    The optimised loss is ``alpha * cross_entropy + (1 - alpha) * entropy +
    l2``, where ``entropy`` is the mean entropy of the clipped softmax
    predictions, ``l2`` is the weight decay over all trainable variables whose
    name does not contain ``batch_normalization``, and ``alpha`` is a
    placeholder fed with 0.5 at every step.  Every time the global step is a
    multiple of 1000 the transition matrix and mask returned by the input
    pipeline are printed.
    """
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Keep the input pipeline on CPU:0 so its ops do not end up on the GPU
        # and slow training down.
        with tf.device('/cpu:0'):
            indices, images, labels, T, T_mask = (
                cifar10.noisy_distorted_inputs(
                    noise_ratio=FLAGS.noise_ratio, return_T_flag=True))

        # Inference model with a run-time batch-norm training flag.
        is_training = tf.placeholder(tf.bool, shape=(), name="bn_flag")
        logits = cifar10.inference(images, training=is_training)

        # Average cross-entropy across the batch.
        per_example_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.cast(labels, tf.int64),
            logits=logits,
            name='cross_entropy_per_example')
        loss1 = tf.reduce_mean(per_example_xent, name='cross_entropy')
        tf.add_to_collection('losses', loss1)

        # Perceptual loss: mean entropy of the (clipped) predictions.
        preds = tf.nn.softmax(logits)
        clipped_preds = tf.clip_by_value(preds, 1e-8, 1 - 1e-8)
        per_example_entropy = -tf.reduce_sum(
            clipped_preds * tf.log(clipped_preds), axis=-1)
        loss2 = tf.reduce_mean(
            per_example_entropy, name='perceptual_certainty_soft')
        tf.add_to_collection('losses', loss2)

        # L2 weight decay, skipping batch-normalisation variables.
        decay_terms = []
        for v in tf.trainable_variables():
            if 'batch_normalization' in v.name:
                continue
            decay_terms.append(
                cifar10.WEIGHT_DECAY * tf.nn.l2_loss(tf.cast(v, tf.float32)))
        l2_loss = tf.add_n(decay_terms, name='l2_loss')
        tf.add_to_collection('losses', l2_loss)

        # Weighted loss.
        alpha = tf.placeholder(tf.float32, shape=(), name='perceptual_weight')
        # NOTE(review): this tensor is built but never fetched; the hook below
        # reads tf.get_collection('losses')[0] instead.
        _LoggerHook_loss = alpha * loss1 + (1 - alpha) * loss2
        loss = alpha * loss1 + (1 - alpha) * loss2 + l2_loss

        # Op that applies one optimisation step.
        train_op = cifar10.train(loss, global_step)

        # acc_op is (accuracy, update_op): running acc_op yields the
        # cumulative accuracy; run acc_op[0] alone to read it without
        # updating.
        acc_op = tf.metrics.accuracy(labels, tf.argmax(logits, axis=1))
        tf.summary.scalar('training accuracy', acc_op[0])

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                # Asks for the first tensor in the 'losses' collection.
                return tf.train.SessionRunArgs(
                    tf.get_collection('losses')[0])

            def after_run(self, run_context, run_values):
                # Only report every FLAGS.log_frequency steps.
                if self._step % FLAGS.log_frequency != 0:
                    return

                now = time.time()
                duration = now - self._start_time
                self._start_time = now

                loss_value = run_values.results
                examples_per_sec = (
                    FLAGS.log_frequency * FLAGS.batch_size / duration)
                sec_per_batch = float(duration / FLAGS.log_frequency)

                report = (
                    '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)') % (datetime.now(), self._step, loss_value,
                                     examples_per_sec, sec_per_batch)
                print(report)

        session_config = tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.8))
        session_hooks = [
            tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
            tf.train.NanTensorHook(loss),
            _LoggerHook(),
        ]
        fetches = [train_op, acc_op, global_step, T, T_mask]

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=session_hooks,
                save_checkpoint_secs=60,
                config=session_config) as mon_sess:
            while not mon_sess.should_stop():
                _, _, step_value, noise_matrix, noise_mask = mon_sess.run(
                    fetches, feed_dict={is_training: True, alpha: 0.5})
                if step_value % 1000 == 0:
                    print('Disturbing matrix\n', noise_matrix)
                    print('Masked structure\n', noise_mask)
def train():
    """Train the batch-norm CIFAR-10 classifier directly on the noisy labels.

    Builds the graph (noisy inputs, inference with a batch-norm training
    flag, ``cifar10.loss`` and ``cifar10.train``), tracks a streaming training
    accuracy, and runs training steps inside a ``MonitoredTrainingSession``
    until it signals a stop.  Every time the global step is a multiple of 1000
    the transition matrix and mask returned by the input pipeline are printed.
    """
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Keep the input pipeline on CPU:0 so its ops do not end up on the GPU
        # and slow training down.
        with tf.device('/cpu:0'):
            indices, images, labels, T, T_mask = (
                cifar10.noisy_distorted_inputs(
                    noise_ratio=FLAGS.noise_ratio, return_T_flag=True))

        # Inference model with a run-time batch-norm training flag.
        is_training = tf.placeholder(tf.bool, shape=(), name="bn_flag")
        logits = cifar10.inference(images, training=is_training)

        # Loss and the op that applies one optimisation step.
        loss = cifar10.loss(logits, labels)
        train_op = cifar10.train(loss, global_step)

        # acc_op is (accuracy, update_op): running acc_op yields the
        # cumulative accuracy; run acc_op[0] alone to read it without
        # updating.
        acc_op = tf.metrics.accuracy(labels, tf.argmax(logits, axis=1))
        tf.summary.scalar('training accuracy', acc_op[0])

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                # Asks for the first tensor in the 'losses' collection.
                return tf.train.SessionRunArgs(
                    tf.get_collection('losses')[0])

            def after_run(self, run_context, run_values):
                # Only report every FLAGS.log_frequency steps.
                if self._step % FLAGS.log_frequency != 0:
                    return

                now = time.time()
                duration = now - self._start_time
                self._start_time = now

                loss_value = run_values.results
                examples_per_sec = (
                    FLAGS.log_frequency * FLAGS.batch_size / duration)
                sec_per_batch = float(duration / FLAGS.log_frequency)

                report = (
                    '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)') % (datetime.now(), self._step, loss_value,
                                     examples_per_sec, sec_per_batch)
                print(report)

        session_config = tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.8))
        session_hooks = [
            tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
            tf.train.NanTensorHook(loss),
            _LoggerHook(),
        ]
        fetches = [train_op, acc_op, global_step, T, T_mask]

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=session_hooks,
                save_checkpoint_secs=60,
                config=session_config) as mon_sess:
            while not mon_sess.should_stop():
                _, _, step_value, noise_matrix, noise_mask = mon_sess.run(
                    fetches, feed_dict={is_training: True})
                if step_value % 1000 == 0:
                    print('Disturbing matrix\n', noise_matrix)
                    print('Masked structure\n', noise_mask)
def train(T_est, T_inv_est):
    """Train the batch-norm CIFAR-10 classifier with the forward-corrected loss.

    The loss is ``loss_forward(logits, labels, T)`` where ``T`` is the
    pipeline's transition matrix when ``FLAGS.groudtruth`` is set and the
    supplied estimate ``T_est`` otherwise.  All trainable variables and the
    batch-norm moving statistics are initialised from the latest checkpoint
    in ``FLAGS.init_dir``.

    Args:
        T_est: Estimated transition matrix; cast to a float32 constant.
        T_inv_est: Estimated inverse transition matrix; cast to a float32
            constant but currently unused by the loss below.

    Raises:
        ValueError: If ``FLAGS.init_dir`` contains no checkpoint to restore.
    """
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending
        # up on GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            indices, images, labels, T_tru, T_mask_tru = (
                cifar10.noisy_distorted_inputs(
                    return_T_flag=True, noise_ratio=FLAGS.noise_ratio))

        # Build a Graph that computes the logits predictions from the
        # inference model.
        is_training = tf.placeholder(tf.bool, shape=(), name='bn_flag')
        logits = cifar10.inference(images, training=is_training)

        T_est = tf.cast(tf.constant(T_est), tf.float32)
        T_inv_est = tf.cast(tf.constant(T_inv_est), tf.float32)

        # Calculate loss.
        if FLAGS.groudtruth:
            loss = loss_forward(logits, labels, T_tru)
        else:
            loss = loss_forward(logits, labels, T_est)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        # acc_op is (accuracy, update_op): running acc_op yields the
        # cumulative accuracy; run acc_op[0] alone to read it without
        # updating.
        acc_op = tf.metrics.accuracy(labels, tf.argmax(logits, axis=1))
        tf.summary.scalar('training accuracy', acc_op[0])

        # Build the scaffold that restores the pre-trained variables
        # (all trainable variables plus the batch-norm moving statistics).
        variables_to_restore = []
        variables_to_restore += tf.trainable_variables()
        variables_to_restore += [
            g for g in tf.global_variables()
            if 'moving_mean' in g.name or 'moving_variance' in g.name
        ]
        for var in variables_to_restore:
            print(var.name)

        ckpt = tf.train.get_checkpoint_state(FLAGS.init_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            # get_checkpoint_state returns None when nothing is found; fail
            # with a clear message instead of an AttributeError.
            raise ValueError(
                'No checkpoint found in init_dir: %s' % FLAGS.init_dir)
        init_assign_op, init_feed_dict = (
            tf.contrib.framework.assign_from_checkpoint(
                ckpt.model_checkpoint_path, variables_to_restore))

        def InitAssignFn(scaffold, sess):
            sess.run(init_assign_op, init_feed_dict)

        scaffold = tf.train.Scaffold(
            saver=tf.train.Saver(), init_fn=InitAssignFn)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                # Asks for the first tensor in the 'losses' collection.
                return tf.train.SessionRunArgs(
                    tf.get_collection('losses')[0])

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = (
                        FLAGS.log_frequency * FLAGS.batch_size / duration)
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step,
                                        loss_value, examples_per_sec,
                                        sec_per_batch))

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                scaffold=scaffold,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                save_checkpoint_secs=60,
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement,
                    gpu_options=gpu_options)) as mon_sess:
            while not mon_sess.should_stop():
                res = mon_sess.run(
                    [train_op, acc_op, global_step, T_tru, T_mask_tru],
                    feed_dict={is_training: True})
                # Periodically show the pipeline's transition matrix and mask.
                if res[2] % 1000 == 0:
                    print('Disturbing matrix\n', res[3])
                    print('Masked structure\n', res[4])
def train(T_est):
    """Train a classifier jointly with a generator of transition matrices.

    A generator maps Gaussian noise to a batch of row-softmaxed
    ``NUM_CLASSES x NUM_CLASSES`` matrices ``S``; a discriminator scores the
    soft-thresholded structure of ``S`` against the pipeline's mask
    ``T_mask``; the classifier ("reconstructor") is trained on the noisy
    labels through ``S``.  The generator is first pre-trained for 10000 steps
    to match ``T_est``; then each iteration runs one generator step, five
    discriminator steps and one classifier step.

    Args:
        T_est: Pre-estimated transition matrix, converted to a float32
            constant and tiled over the batch.  Presumably
            ``NUM_CLASSES x NUM_CLASSES`` -- confirm against the caller.

    Raises:
        ValueError: If ``FLAGS.init_dir`` contains no checkpoint to restore.
    """
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending
        # up on GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            images, labels, T_tru, T_mask = cifar10.noisy_distorted_inputs(
                return_T_flag=True)

        T_est = tf.constant(T_est, dtype=tf.float32)

        #### Prior and ground truth, tiled to one matrix per example.
        T_est = tf.tile(tf.expand_dims(T_est, 0), [FLAGS.batch_size, 1, 1])
        T_tru = tf.tile(tf.expand_dims(T_tru, 0), [FLAGS.batch_size, 1, 1])
        T_mask = tf.tile(tf.expand_dims(T_mask, 0), [FLAGS.batch_size, 1, 1])

        #### generator
        with tf.variable_scope('generator'):
            normal = Normal(tf.zeros([1, 10]), tf.ones([1, 10]))
            epsilon = tf.to_float(normal.sample(FLAGS.batch_size))
            net = slim.stack(epsilon, slim.fully_connected, [50, 50])
            net = slim.fully_connected(
                net,
                cifar10.NUM_CLASSES * cifar10.NUM_CLASSES,
                activation_fn=None)
            net = tf.reshape(
                net, [-1, cifar10.NUM_CLASSES, cifar10.NUM_CLASSES])
            S = tf.nn.softmax(net)
            # Input to the discriminator: steep sigmoid centred at 0.05.
            S_mask = tf.sigmoid((S - 0.05) / 0.005)

        #### discriminator
        def discriminator(matrices):
            """Score a batch of matrices; variables are shared across calls."""
            with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE):
                flat = slim.flatten(matrices)
                net = slim.fully_connected(
                    flat, 20, activation_fn=tf.nn.sigmoid)
                net = slim.fully_connected(net, 1, activation_fn=None)
            return net

        D_t = discriminator(T_mask)
        D_s = discriminator(S_mask)

        #### reconstructor
        dropout = tf.constant(0.75)
        logits = cifar10.inference(images, dropout, dropout_flag=True)
        preds = tf.nn.softmax(logits)
        # Per-example product of the prediction row vector with its sampled S.
        preds_aug = tf.reshape(
            tf.matmul(tf.reshape(preds, [FLAGS.batch_size, 1, -1]), S),
            [FLAGS.batch_size, -1])
        logits_aug = tf.log(tf.clip_by_value(preds_aug, 1e-8, 1.0 - 1e-8))

        #### loss
        # R loss
        R_loss = cifar10.loss(logits_aug, labels)
        tf.summary.scalar('reconstructor loss', R_loss)
        # D loss
        D_loss = -tf.reduce_mean(D_t) + tf.reduce_mean(D_s)
        tf.summary.scalar('discriminator loss', D_loss)
        # G loss
        G_loss = R_loss - tf.reduce_mean(D_s)
        tf.summary.scalar('generator loss', G_loss)
        # Initialisation of G: cross-entropy between T_est and S.
        S_logits = tf.log(tf.clip_by_value(S, 1e-8, 1.0 - 1e-8))
        Initial_G_loss = -tf.reduce_mean(
            tf.reduce_sum(tf.reduce_sum(T_est * S_logits, axis=2), axis=1))

        # Split the trainable variables by owner (classifier gets the rest).
        var_C = []
        var_D = []
        var_G = []
        for item in tf.trainable_variables():
            if "generator" in item.name:
                var_G.append(item)
            elif "discriminator" in item.name:
                var_D.append(item)
            else:
                var_C.append(item)

        #### optimizer
        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        R_train_op, variable_averages, lr = cifar10.train(
            R_loss,
            global_step,
            var_C,
            return_variable_averages=True,
            return_lr=True)
        # lr_DG is a constant whose value is overridden through feed_dict in
        # the training loop.
        lr_DG = tf.constant(1e-5)
        D_train_op = tf.train.RMSPropOptimizer(learning_rate=lr_DG).minimize(
            D_loss, var_list=var_D)
        G_train_op = tf.train.RMSPropOptimizer(learning_rate=lr_DG).minimize(
            G_loss, var_list=var_G + var_C)

        #### optimizer for the initialization of the generator
        Initial_G_train_op = tf.train.RMSPropOptimizer(
            learning_rate=1e-4).minimize(Initial_G_loss, var_list=var_G)

        #### weight clamping for WGAN
        # NOTE(review): clip_D is built but never run by the loop below, and
        # its bounds (-0.01, 0.005) are asymmetric -- confirm intent.
        clip_D = [
            var.assign(tf.clip_by_value(var, -0.01, 0.005)) for var in var_D
        ]

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""

            def begin(self):
                self._my_print_flag = False
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                # Asks for loss value.
                return tf.train.SessionRunArgs(R_loss)

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = (
                        FLAGS.log_frequency * FLAGS.batch_size / duration)
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    # Printing is gated by a flag toggled from outside.
                    if self._my_print_flag:
                        format_str = (
                            '%s: step %d, loss = %.2f (%.1f examples/sec; '
                            '%.3f sec/batch)')
                        print(format_str % (datetime.now(), self._step,
                                            loss_value, examples_per_sec,
                                            sec_per_batch))

        # Build the scaffold that restores the pre-trained variables.
        ckpt = tf.train.get_checkpoint_state(FLAGS.init_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            # get_checkpoint_state returns None when nothing is found; fail
            # with a clear message instead of an AttributeError.
            raise ValueError(
                'No checkpoint found in init_dir: %s' % FLAGS.init_dir)

        # Build a filtered copy rather than deleting entries while iterating
        # over the dict's keys, which raises RuntimeError on Python 3.
        excluded_tags = ('generator', 'discriminator', 'RMSProp',
                         'global_step')
        variables_to_restore = {
            var_name: var
            for var_name, var in
            variable_averages.variables_to_restore().items()
            if not any(tag in var_name for tag in excluded_tags)
        }
        print(variables_to_restore)
        init_assign_op, init_feed_dict = (
            tf.contrib.framework.assign_from_checkpoint(
                ckpt.model_checkpoint_path, variables_to_restore))

        def InitAssignFn(scaffold, sess):
            sess.run(init_assign_op, init_feed_dict)

        scaffold = tf.train.Scaffold(
            saver=tf.train.Saver(), init_fn=InitAssignFn)

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.45)
        loggerHook = _LoggerHook()
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                scaffold=scaffold,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(R_loss),
                    loggerHook
                ],
                save_checkpoint_secs=60,
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement,
                    gpu_options=gpu_options)) as mon_sess:
            #### pretrain the generator
            loggerHook._my_print_flag = False
            res = None
            for i in range(10000):
                res = mon_sess.run(
                    [Initial_G_train_op, Initial_G_loss, T_est, S, lr, lr_DG])
                if i % 1000 == 0:
                    print('Step: %d\tGenerator loss: %.3f' % (i, res[1]))
                    print('Pre-estimation', res[2][0])
                    print('Initialization', res[3][0])

            #### iteratively train G and <D,R>
            loggerHook._my_print_flag = False
            step = 0
            step_control = 0
            lr_, lr_DG_ = res[-2], res[-1]
            while not mon_sess.should_stop():
                # Keep the G/D learning rate below the classifier's to avoid
                # over-tuning the transition matrix as the latter decays.
                if lr_DG_ >= lr_:
                    lr_DG_ = lr_DG_ / 10.0

                # Adversarial game: one G step, then five D steps.
                if step >= step_control:
                    res = mon_sess.run(
                        [G_train_op, G_loss, T_est, S, T_tru, S_mask, T_mask],
                        feed_dict={lr_DG: lr_DG_})
                    g_loss = res[1]
                    for _ in range(5):
                        _, d_loss = mon_sess.run(
                            [D_train_op, D_loss], feed_dict={lr_DG: lr_DG_})

                # Train the classifier.
                _, r_loss, g_step, lr_, lr_DG_ = mon_sess.run(
                    [R_train_op, R_loss, global_step, lr, lr_DG],
                    feed_dict={lr_DG: lr_DG_})

                if step >= step_control:
                    print(
                        'Step: %d\tR_loss: %.3f\tD_loss: %.3f\tG_loss: %.3f'
                        % (g_step, r_loss, d_loss, g_loss))
                    if (g_step % 2000 == 0) or (
                            g_step == FLAGS.max_steps - 1):
                        print('Pre-estimation', res[2][0])
                        print('Generated sample', res[3][0])
                        print('True transition', res[4][0])
                        print('Generated structure', res[5][0])
                        print('True structure', res[6][0])
                else:
                    print('Step: %d\tR_loss: %.3f' % (g_step, r_loss))
                step = g_step