def train(infer_z, noisy_y, C, img_label):
    """Train CIFAR-10 on Gibbs-sampled latent labels for a number of steps.

    Each step, one latent label per example is sampled from
    ``softmax(logits) * column(T, noisy_label)`` (renormalized) and the
    network is trained on the sampled labels. ``T`` is the input pipeline's
    ground-truth transition when ``FLAGS.groudtruth`` is set; otherwise it is
    the fed placeholder, which receives a pre-computed warm-up matrix for the
    first ``warming_up_step`` steps and afterwards the smoothed, row-normalized
    counting matrix ``C`` (refreshed every ``freq_trans`` steps). After every
    step ``C`` and ``infer_z`` are updated with the sampled labels.

    Args:
      infer_z: dict mapping example index -> current inferred label; updated
        in place.
      noisy_y: mapping example index -> observed (noisy) label.
      C: numpy counting matrix where ``C[z][y]`` counts examples whose inferred
        label is ``z`` and noisy label is ``y``; updated in place.
      img_label: mapping example index -> clean label; only used to report the
        accuracy of ``infer_z`` once per epoch.

    Side effects:
      Reads ``T_<noise_ratio>.pkl`` and writes ``varC_learnt_*.pkl``,
      ``varC_transvar_trace_*.pkl`` and, if ``FLAGS.labeltrace``,
      ``varC_label_trace_*.pkl`` in the working directory.

    Raises:
      ValueError: if no checkpoint is found in ``FLAGS.init_dir``.
    """
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            indices, images, labels, T_tru, T_mask_tru = cifar10.noisy_distorted_inputs(
                return_T_flag=True, noise_ratio=FLAGS.noise_ratio)
            # rank 2 --> rank 1, i.e., (batch_size, 1) --> (batch_size,)
            indices = indices[:, 0]

        # Build a Graph that computes the logits predictions from the
        # inference model.
        is_training = tf.placeholder(tf.bool, shape=(), name='bn_flag')
        logits = cifar10.inference(images, training=is_training)
        preds = tf.nn.softmax(logits)

        # Approximate Gibbs sampling: weight each class prediction by the row
        # of T^T (i.e. the column of T) selected by the noisy label, then
        # renormalize and draw one sample per example.
        T = tf.placeholder(tf.float32,
                           shape=[cifar10.NUM_CLASSES, cifar10.NUM_CLASSES],
                           name='transition')
        if FLAGS.groudtruth:
            unnorm_probs = preds * tf.gather(tf.transpose(T_tru, [1, 0]),
                                             labels)
        else:
            unnorm_probs = preds * tf.gather(tf.transpose(T, [1, 0]), labels)

        probs = unnorm_probs / tf.reduce_sum(
            unnorm_probs, axis=1, keepdims=True)
        sampler = OneHotCategorical(probs=probs)
        # Sampled labels are used as fixed targets; no gradient flows through
        # the sampling step.
        labels_ = tf.stop_gradient(tf.argmax(sampler.sample(), axis=1))

        loss = cifar10.loss(logits, labels_)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        # acc_op is (acc, update_op), so running it yields the cumulative
        # accuracy; run acc_op[0] alone to inspect the current value.
        acc_op = tf.metrics.accuracy(labels, tf.argmax(logits, axis=1))
        tf.summary.scalar('training accuracy', acc_op[0])

        # Restore all trainable variables (final layer included) plus every
        # global variable whose name contains moving_mean / moving_variance.
        variables_to_restore = list(tf.trainable_variables())
        variables_to_restore += [
            g for g in tf.global_variables()
            if 'moving_mean' in g.name or 'moving_variance' in g.name
        ]
        for var in variables_to_restore:
            print(var.name)
        ckpt = tf.train.get_checkpoint_state(FLAGS.init_dir)
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise ValueError('No checkpoint found in %s' % FLAGS.init_dir)
        init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
            ckpt.model_checkpoint_path, variables_to_restore)

        def InitAssignFn(scaffold, sess):
            sess.run(init_assign_op, init_feed_dict)

        scaffold = tf.train.Scaffold(saver=tf.train.Saver(),
                                     init_fn=InitAssignFn)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(
                    tf.get_collection('losses')[0])  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                scaffold=scaffold,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                save_checkpoint_secs=60,
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement,
                    gpu_options=gpu_options)) as mon_sess:
            # Additive smoothing applied to every count before row-normalizing.
            alpha = 1.0

            def _smoothed_transition(counts):
                """Return ``counts + alpha`` with each row normalized to sum 1."""
                smoothed = counts + alpha
                return smoothed / np.sum(smoothed, axis=1, keepdims=True)

            C_init = C.copy()
            trans_init = _smoothed_transition(C)

            ## running setting
            warming_up_step = 20000
            step = 0
            freq_trans = 200  # refresh the transition matrix every n steps

            # Transition fed during warm-up, taken from a pre-computed pickle
            # (data[2]: trans_init or np.eye(cifar10.NUM_CLASSES)).
            with open('T_%.2f.pkl' % FLAGS.noise_ratio, 'rb') as f:
                data = pickle.load(f)
            trans_warming = data[2]

            ## record and run
            exemplars = []
            label_trace_exemplars = []
            infer_z_probs = dict()
            trans_before_after_trace = []
            while not mon_sess.should_stop():
                if step % freq_trans == 0:
                    trans = _smoothed_transition(C)

                feed_T = trans_warming if step < warming_up_step else trans
                res = mon_sess.run([
                    train_op, acc_op, global_step, indices, labels, labels_,
                    probs
                ],
                                   feed_dict={
                                       is_training: True,
                                       T: feed_T
                                   })
                batch_indices, batch_noisy, batch_sampled, batch_probs = res[3:7]

                trans_before = _smoothed_transition(C)
                C_before = C.copy()
                for i in range(batch_indices.shape[0]):
                    ind = batch_indices[i]
                    assert noisy_y[ind] == batch_noisy[i]
                    # Move this example's count from its old inferred label to
                    # the newly sampled one.
                    C[infer_z[ind]][noisy_y[ind]] -= 1
                    assert C[infer_z[ind]][noisy_y[ind]] >= 0
                    infer_z[ind] = batch_sampled[i]
                    infer_z_probs[ind] = batch_probs[i]
                    C[infer_z[ind]][noisy_y[ind]] += 1

                trans_after = _smoothed_transition(C)
                C_after = C.copy()
                # L1 change of the transition matrix over this step, plus the
                # trans_bound quantity derived from per-row count changes
                # (presumably an upper bound on trans_gap -- TODO confirm).
                trans_gap = np.sum(np.absolute(trans_after - trans_before))
                rou = np.sum(C_after - C_before, axis=-1) / np.sum(
                    C_before + alpha, axis=-1)
                rou_ = np.sum(np.absolute(C_after - C_before),
                              axis=-1) / np.sum(C_before + alpha, axis=-1)
                trans_bound = np.sum((np.absolute(rou) + rou_) / (1 + rou))
                trans_before_after_trace.append([step, trans_gap, trans_bound])

                step = res[2]
                if step % 1000 == 0:
                    print('Counting matrix\n', C)
                    print('Counting matrix\n', C_init)
                    print('Transition matrix\n', trans)
                    print('Transition matrix\n', trans_init)

                if step % 5000 == 0:
                    # list() keeps keys/values picklable on Python 3 as well.
                    exemplars.append([
                        list(infer_z.keys()),
                        list(infer_z.values()),
                        C.copy()
                    ])

                if step % FLAGS.max_steps_per_epoch == 0:
                    all_n = len(infer_z)
                    r_n = sum(1 for key in infer_z
                              if infer_z[key] == img_label[key])
                    # float() avoids Python 2 integer division (acc == 0);
                    # an empty infer_z reports 0.0 instead of crashing.
                    acc = float(r_n) / all_n if all_n else 0.0
                    label_trace_exemplars.append(
                        [infer_z.copy(),
                         infer_z_probs.copy(), acc])

            suffix = '_tru' if FLAGS.groudtruth else ''
            with open('varC_learnt_%.2f%s.pkl' % (FLAGS.noise_ratio, suffix),
                      'wb') as w:
                pickle.dump(exemplars, w)

            if FLAGS.labeltrace:
                with open('varC_label_trace_%.2f.pkl' % FLAGS.noise_ratio,
                          'wb') as w:
                    pickle.dump([label_trace_exemplars, img_label], w)

            with open('varC_transvar_trace_%.2f.pkl' % FLAGS.noise_ratio,
                      'wb') as w:
                pickle.dump(trans_before_after_trace, w)
def train(T_est, T_inv_est):
    """Train CIFAR-10 with a forward-corrected loss for a number of steps.

    Variables are initialised from the checkpoint in ``FLAGS.init_dir`` using
    the mapping from ``variable_averages.variables_to_restore()``, minus any
    entry whose name contains ``logits_T`` or ``global_step``. The active loss
    is ``loss_forward`` against the input pipeline's ``T_tru``.

    Args:
      T_est: estimated transition matrix. Wrapped in a graph constant, but only
        referenced by the commented-out ``loss_forward(..., T_est)`` variant.
      T_inv_est: matrix referenced by the commented-out ``loss_backward``
        variant; likewise unused by the active loss.

    Raises:
      ValueError: if no checkpoint is found in ``FLAGS.init_dir``.
    """
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        T_est = tf.constant(T_est)
        T_inv_est = tf.constant(T_inv_est)

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            images, labels, T_tru, T_mask_tru = cifar10.noisy_distorted_inputs(
                return_T_flag=True)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        dropout = tf.constant(0.75)
        logits = cifar10.inference(images, dropout, dropout_flag=True)

        # Calculate loss. Alternatives kept for experimentation:
        #   loss = loss_forward(logits, labels, T_est)
        #   loss = loss_backward(logits, labels, T_inv_est)
        loss = loss_forward(logits, labels, T_tru)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op, variable_averages = cifar10.train(
            loss, global_step, return_variable_averages=True)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        #### build scaffold for MonitoredTrainingSession to restore the variables you wish
        ckpt = tf.train.get_checkpoint_state(FLAGS.init_dir)
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise ValueError('No checkpoint found in %s' % FLAGS.init_dir)
        # Build a filtered copy instead of deleting while iterating: removing
        # keys from a dict during iteration over its keys view raises
        # RuntimeError on Python 3.
        variables_to_restore = {
            var_name: var
            for var_name, var in variable_averages.variables_to_restore().items()
            if 'logits_T' not in var_name and 'global_step' not in var_name
        }

        init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
            ckpt.model_checkpoint_path, variables_to_restore)

        def InitAssignFn(scaffold, sess):
            sess.run(init_assign_op, init_feed_dict)

        scaffold = tf.train.Scaffold(saver=tf.train.Saver(),
                                     init_fn=InitAssignFn)

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.45)
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                scaffold=scaffold,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                save_checkpoint_secs=60,
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement,
                    gpu_options=gpu_options)) as mon_sess:
            while not mon_sess.should_stop():
                res = mon_sess.run([train_op, global_step, T_tru, T_mask_tru])
                if res[1] % 1000 == 0:
                    print('Disturbing matrix\n', res[2])
                    print('Masked structure\n', res[3])
def train(T_fixed, T_init):
    """Train CIFAR-10 through a label-adaption (transition) layer.

    Softmax predictions are multiplied by an adaption matrix before the loss
    against the noisy labels. While the local step is below
    ``warming_up_step`` the fixed matrix ``T_fixed`` is used; afterwards the
    learnable ``softmax(logits_T)``, with ``logits_T`` initialised to
    ``log(T_init + 1e-8)``.

    Args:
      T_fixed: matrix used as the warm-up adaption layer (presumably
        NUM_CLASSES x NUM_CLASSES -- TODO confirm).
      T_init: NUM_CLASSES x NUM_CLASSES array used to initialise ``logits_T``.

    Side effects:
      Writes ``varT_learnt_*.pkl`` (adaption matrix every 5000 steps) and
      ``varT_transvar_trace_*.pkl`` (L1 gap to the reference matrix after
      warm-up) in the working directory.

    Raises:
      ValueError: if no checkpoint is found in ``FLAGS.init_dir``.
    """
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            indices, images, labels, T_tru, T_mask_tru = cifar10.noisy_distorted_inputs(
                return_T_flag=True, noise_ratio=FLAGS.noise_ratio)
            indices = indices[:, 0]

        # Build a Graph that computes the logits predictions from the
        # inference model.
        is_training = tf.placeholder(tf.bool, shape=(), name='bn_flag')
        logits = cifar10.inference(images, training=is_training)
        preds = tf.nn.softmax(logits)

        # fixed adaption layer
        fixed_adaption_layer = tf.cast(tf.constant(T_fixed), tf.float32)

        # learnable adaption layer (row-wise softmax of logits_T)
        logits_T = tf.get_variable(
            'logits_T',
            shape=[cifar10.NUM_CLASSES, cifar10.NUM_CLASSES],
            initializer=tf.constant_initializer(np.log(T_init + 1e-8)))
        adaption_layer = tf.nn.softmax(logits_T)

        # label adaption: is_use=True selects the fixed layer (warm-up).
        is_use = tf.placeholder(tf.bool, shape=(), name='warming_up_flag')
        adaption = tf.cond(is_use, lambda: fixed_adaption_layer,
                           lambda: adaption_layer)
        # Clip before log so the loss never sees log(0).
        preds_aug = tf.clip_by_value(tf.matmul(preds, adaption), 1e-8,
                                     1.0 - 1e-8)
        logits_aug = tf.log(preds_aug)

        # Calculate loss.
        loss = cifar10.loss(logits_aug, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        # acc_op is (acc, update_op), so running it yields the cumulative
        # accuracy; run acc_op[0] alone to inspect the current value.
        acc_op = tf.metrics.accuracy(labels, tf.argmax(logits, axis=1))
        tf.summary.scalar('training accuracy', acc_op[0])

        # Restore every trainable variable except the adaption layer, plus
        # global variables whose names contain moving_mean / moving_variance.
        variables_to_restore = [
            var for var in tf.trainable_variables()
            if 'logits_T' not in var.name
        ]
        variables_to_restore += [
            g for g in tf.global_variables()
            if 'moving_mean' in g.name or 'moving_variance' in g.name
        ]
        for var in variables_to_restore:
            print(var.name)
        ckpt = tf.train.get_checkpoint_state(FLAGS.init_dir)
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise ValueError('No checkpoint found in %s' % FLAGS.init_dir)
        init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
            ckpt.model_checkpoint_path, variables_to_restore)

        def InitAssignFn(scaffold, sess):
            sess.run(init_assign_op, init_feed_dict)

        scaffold = tf.train.Scaffold(saver=tf.train.Saver(),
                                     init_fn=InitAssignFn)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(
                    tf.get_collection('losses')[0])  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                scaffold=scaffold,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                save_checkpoint_secs=60,
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement,
                    gpu_options=gpu_options)) as mon_sess:
            warming_up_step = 32000
            step = 0
            varT_rec = []
            varT_trans_trace = []
            # Reference matrix for the post-warm-up gap. Starts as None and is
            # set on the first step >= warming_up_step, so a run resumed from a
            # checkpoint already past warm-up (step never equal to
            # warming_up_step) no longer hits a NameError.
            # NOTE(review): which branch produced res[3] on that step depends on
            # whether global_step is read before or after train_op increments
            # it -- confirm whether this is meant to be the fixed or the
            # learned layer.
            trans_before = None
            while not mon_sess.should_stop():
                warming_up = bool(step < warming_up_step)
                current_adaption = (fixed_adaption_layer
                                    if warming_up else adaption_layer)
                res = mon_sess.run([
                    train_op, acc_op, global_step, current_adaption, T_tru,
                    T_mask_tru
                ],
                                   feed_dict={
                                       is_training: True,
                                       is_use: warming_up
                                   })
                step = res[2]

                if step % 5000 == 0:
                    varT_rec.append(res[3])

                if trans_before is None and step >= warming_up_step:
                    trans_before = res[3].copy()
                elif step > warming_up_step:
                    trans_after = res[3].copy()
                    trans_gap = np.sum(np.absolute(trans_before - trans_after))
                    varT_trans_trace.append([step, trans_gap])

        # Pickle requires binary mode on Python 3; harmless on Python 2.
        with open('varT_learnt_%.2f.pkl' % FLAGS.noise_ratio, 'wb') as w:
            pickle.dump(varT_rec, w)

        with open('varT_transvar_trace_%.2f.pkl' % FLAGS.noise_ratio,
                  'wb') as w:
            pickle.dump(varT_trans_trace, w)
# Example #4
0
def train():
    """Train CIFAR-10 with dropout until FLAGS.max_steps is reached.

    Builds the noisy input pipeline, a dropout-enabled inference model and
    its loss, then runs the training op inside a MonitoredTrainingSession
    (checkpointing to FLAGS.train_dir every 60 s) with the dropout
    placeholder fed 0.75. Every 1000 global steps the ``T`` and ``T_mask``
    tensors returned by the input pipeline are printed.
    """
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Keep the input pipeline on the CPU so its ops do not land on the
        # GPU, which would slow training down.
        with tf.device('/cpu:0'):
            images, labels, T, T_mask = cifar10.noisy_distorted_inputs(
                return_T_flag=True)

        # Model, loss and optimizer step; dropout is supplied at run time.
        dropout = tf.placeholder(tf.float32, name='dropout_rate')
        logits = cifar10.inference(images, dropout, dropout_flag=True)
        loss = cifar10.loss(logits, labels)
        train_op = cifar10.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Prints the loss and throughput every FLAGS.log_frequency runs."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                # Fetch the loss together with every run.
                return tf.train.SessionRunArgs(loss)

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency != 0:
                    return
                now = time.time()
                elapsed = now - self._start_time
                self._start_time = now

                examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / elapsed
                sec_per_batch = float(elapsed / FLAGS.log_frequency)
                template = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                            'sec/batch)')
                print(template % (datetime.now(), self._step,
                                  run_values.results, examples_per_sec,
                                  sec_per_batch))

        session_config = tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.45))
        session_hooks = [
            tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
            tf.train.NanTensorHook(loss),
            _LoggerHook(),
        ]
        fetches = [train_op, global_step, T, T_mask]

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=session_hooks,
                save_checkpoint_secs=60,
                config=session_config) as mon_sess:
            while not mon_sess.should_stop():
                _, current_step, disturbing, masked = mon_sess.run(
                    fetches, feed_dict={dropout: 0.75})
                if current_step % 1000:
                    continue
                print('Disturbing matrix\n', disturbing)
                print('Masked structure\n', masked)
def train():
    """Train CIFAR-10 for a number of steps.

    Trains on the noisy input pipeline with the weighted objective
    ``alpha * cross_entropy + (1 - alpha) * mean_prediction_entropy + l2_loss``
    where ``alpha`` is a placeholder fed 0.5 in the training loop. Every 1000
    global steps the ``T`` and ``T_mask`` tensors returned by the input
    pipeline are printed.
    """
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            #indices, images, labels = cifar10.distorted_inputs()
            indices, images, labels, T, T_mask = cifar10.noisy_distorted_inputs(
                noise_ratio=FLAGS.noise_ratio, return_T_flag=True)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        is_training = tf.placeholder(tf.bool, shape=(), name="bn_flag")
        logits = cifar10.inference(images, training=is_training)

        # Calculate loss.
        # loss = cifar10.loss(logits, labels)
        # Calculate the average cross entropy loss across the batch.
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.cast(labels, tf.int64),
            logits=logits,
            name='cross_entropy_per_example')
        loss1 = tf.reduce_mean(cross_entropy, name='cross_entropy')
        tf.add_to_collection('losses', loss1)

        # perceptual loss: batch-mean entropy -sum(p * log p) of the softmax
        # predictions (clipped so log never sees 0).
        preds = tf.nn.softmax(logits)
        preds = tf.clip_by_value(preds, 1e-8, 1 - 1e-8)
        loss2 = tf.reduce_mean(-tf.reduce_sum(preds * tf.log(preds), axis=-1),
                               name='perceptual_certainty_soft')
        #loss2 = tf.reduce_mean(-tf.reduce_sum(tf.stop_gradient(tf.to_float(tf.one_hot(tf.argmax(preds,axis=-1),depth=cifar10.NUM_CLASSES,axis=-1)))*tf.log(preds),axis=-1), name='perceptual_certainty_hard')
        tf.add_to_collection('losses', loss2)

        # l2 loss: weight decay over every trainable variable whose name does
        # not contain 'batch_normalization'.
        l2_loss = tf.add_n([
            cifar10.WEIGHT_DECAY * tf.nn.l2_loss(tf.cast(v, tf.float32))
            for v in tf.trainable_variables()
            if 'batch_normalization' not in v.name
        ],
                           name='l2_loss')
        tf.add_to_collection('losses', l2_loss)

        # weighted loss
        alpha = tf.placeholder(tf.float32, shape=(), name='perceptual_weight')
        # NOTE(review): _LoggerHook_loss is never used. The logger hook below
        # logs tf.get_collection('losses')[0] (loss1, unless cifar10.inference
        # added to that collection first) rather than this weighted value --
        # confirm which one is meant to be logged. Removing this line would
        # also shift auto-generated op names, so leave it until decided.
        _LoggerHook_loss = alpha * loss1 + (1 - alpha) * loss2
        loss = alpha * loss1 + (1 - alpha) * loss2 + l2_loss

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        # Calculate prediction
        # acc_op contains acc and update_op. So it is the cumulative accuracy when sess runs acc_op
        # if you only want to inspect acc of each batch, just sess run acc_op[0]
        acc_op = tf.metrics.accuracy(labels, tf.argmax(logits, axis=1))
        tf.summary.scalar('training accuracy', acc_op[0])

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(
                    tf.get_collection('losses')[0])  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                save_checkpoint_secs=60,
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement,
                    gpu_options=gpu_options)) as mon_sess:
            while not mon_sess.should_stop():
                #mon_sess.run([train_op,acc_op,global_step],feed_dict={is_training:True, alpha: 0.5})
                # alpha = 0.5 weights cross entropy and prediction entropy equally.
                res = mon_sess.run([train_op, acc_op, global_step, T, T_mask],
                                   feed_dict={
                                       is_training: True,
                                       alpha: 0.5
                                   })
                if res[2] % 1000 == 0:
                    print('Disturbing matrix\n', res[3])
                    print('Masked structure\n', res[4])
# Example #6
0
def train():
    """Train CIFAR-10 on the noisy input pipeline until FLAGS.max_steps.

    Builds the model with a batch-norm training flag (always fed True), its
    loss, the training op and a streaming accuracy metric, then trains in a
    MonitoredTrainingSession that checkpoints to FLAGS.train_dir every 60 s.
    Every 1000 global steps the ``T`` and ``T_mask`` tensors returned by the
    input pipeline are printed.
    """
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Pin the input pipeline to the CPU so its ops are not placed on the
        # GPU, which would slow training down.
        with tf.device('/cpu:0'):
            _indices, images, labels, T, T_mask = cifar10.noisy_distorted_inputs(
                noise_ratio=FLAGS.noise_ratio, return_T_flag=True)

        # Model, loss and optimizer step.
        is_training = tf.placeholder(tf.bool, shape=(), name="bn_flag")
        logits = cifar10.inference(images, training=is_training)
        loss = cifar10.loss(logits, labels)
        train_op = cifar10.train(loss, global_step)

        # Streaming accuracy: acc_op is (value, update_op); running the pair
        # updates and returns the cumulative accuracy.
        acc_op = tf.metrics.accuracy(labels, tf.argmax(logits, axis=1))
        tf.summary.scalar('training accuracy', acc_op[0])

        class _LoggerHook(tf.train.SessionRunHook):
            """Prints a loss value and throughput every FLAGS.log_frequency runs."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                # Fetch the first entry of the 'losses' collection each run.
                return tf.train.SessionRunArgs(tf.get_collection('losses')[0])

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency != 0:
                    return
                now = time.time()
                elapsed = now - self._start_time
                self._start_time = now

                examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / elapsed
                sec_per_batch = float(elapsed / FLAGS.log_frequency)
                template = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                            'sec/batch)')
                print(template % (datetime.now(), self._step,
                                  run_values.results, examples_per_sec,
                                  sec_per_batch))

        session_config = tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.8))
        session_hooks = [
            tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
            tf.train.NanTensorHook(loss),
            _LoggerHook(),
        ]
        fetches = [train_op, acc_op, global_step, T, T_mask]

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=session_hooks,
                save_checkpoint_secs=60,
                config=session_config) as mon_sess:
            while not mon_sess.should_stop():
                _, _, current_step, disturbing, masked = mon_sess.run(
                    fetches, feed_dict={is_training: True})
                if current_step % 1000:
                    continue
                print('Disturbing matrix\n', disturbing)
                print('Masked structure\n', masked)
def train(T_est, T_inv_est):
    """Train CIFAR-10 with forward loss correction, warm-started from a checkpoint.

    The classifier is built on the noisy CIFAR-10 input pipeline
    (``cifar10.noisy_distorted_inputs``) and trained with ``loss_forward``
    using either the ground-truth transition matrix returned by the pipeline
    (when ``FLAGS.groudtruth`` is set) or the supplied estimate ``T_est``.
    Variables are initialised from the latest checkpoint in
    ``FLAGS.init_dir``; training checkpoints go to ``FLAGS.train_dir``.

    Args:
        T_est: estimated transition matrix, converted with ``tf.constant`` and
            cast to float32 (presumably NUM_CLASSES x NUM_CLASSES -- confirm).
        T_inv_est: estimated inverse transition matrix. Accepted for interface
            compatibility; this function does not use it.

    Raises:
        ValueError: if no checkpoint is found in ``FLAGS.init_dir``.
    """
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            _indices, images, labels, T_tru, T_mask_tru = cifar10.noisy_distorted_inputs(
                return_T_flag=True, noise_ratio=FLAGS.noise_ratio)

        # Build a Graph that computes the logits predictions from the
        # inference model. `is_training` is fed True in the run loop below.
        is_training = tf.placeholder(tf.bool, shape=(), name='bn_flag')
        logits = cifar10.inference(images, training=is_training)

        # Calculate loss: forward correction with either the ground-truth
        # transition matrix from the input pipeline or the supplied estimate.
        T_est = tf.cast(tf.constant(T_est), tf.float32)
        if FLAGS.groudtruth:
            loss = loss_forward(logits, labels, T_tru)
        else:
            loss = loss_forward(logits, labels, T_est)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        # acc_op is (accuracy, update_op). Running the pair updates and returns
        # the cumulative accuracy; running acc_op[0] alone reads the
        # accumulated value without updating it.
        acc_op = tf.metrics.accuracy(labels, tf.argmax(logits, axis=1))
        tf.summary.scalar('training_accuracy', acc_op[0])

        #### Build a scaffold so MonitoredTrainingSession initialises the
        #### model from the checkpoint in FLAGS.init_dir.
        # Restore all trainable variables (final layer included) plus the
        # moving_mean / moving_variance variables. To skip the final layer,
        # drop trainable variables whose name contains 'dense'.
        variables_to_restore = list(tf.trainable_variables())
        variables_to_restore += [
            g for g in tf.global_variables()
            if 'moving_mean' in g.name or 'moving_variance' in g.name
        ]
        for var in variables_to_restore:
            print(var.name)

        ckpt = tf.train.get_checkpoint_state(FLAGS.init_dir)
        if ckpt is None or not ckpt.model_checkpoint_path:
            # Fail with a clear message instead of an AttributeError on None.
            raise ValueError(
                'No checkpoint found in FLAGS.init_dir=%r' % (FLAGS.init_dir,))
        init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
            ckpt.model_checkpoint_path, variables_to_restore)

        def _init_assign_fn(scaffold, sess):
            sess.run(init_assign_op, init_feed_dict)

        scaffold = tf.train.Scaffold(saver=tf.train.Saver(),
                                     init_fn=_init_assign_fn)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and throughput every FLAGS.log_frequency steps."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                # Ask for the first tensor in the 'losses' collection.
                return tf.train.SessionRunArgs(tf.get_collection('losses')[0])

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency != 0:
                    return
                current_time = time.time()
                duration = current_time - self._start_time
                self._start_time = current_time

                loss_value = run_values.results
                examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                sec_per_batch = float(duration / FLAGS.log_frequency)

                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                              'sec/batch)')
                print(format_str % (datetime.now(), self._step, loss_value,
                                    examples_per_sec, sec_per_batch))

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                scaffold=scaffold,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                save_checkpoint_secs=60,
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement,
                    gpu_options=gpu_options)) as mon_sess:
            while not mon_sess.should_stop():
                res = mon_sess.run(
                    [train_op, acc_op, global_step, T_tru, T_mask_tru],
                    feed_dict={is_training: True})
                # Every 1000 global steps, dump the pipeline's transition
                # matrix and its mask.
                if res[2] % 1000 == 0:
                    print('Disturbing matrix\n', res[3])
                    print('Masked structure\n', res[4])
def train(T_est):
    """Learn the classifier and the label-noise transition matrix adversarially.

    Three networks are built on the noisy CIFAR-10 input pipeline:

    * a generator that maps Gaussian noise to NUM_CLASSES x NUM_CLASSES
      matrices ``S`` whose rows are softmax-normalised;
    * a WGAN critic (``discriminator``) that compares a soft 0/1 mask of ``S``
      against the structure mask ``T_mask`` returned by the input pipeline;
    * a reconstructor (the CIFAR-10 classifier) whose softmax output is
      multiplied by ``S`` and trained against the noisy labels.

    First the generator is pre-trained for 10000 steps to match ``T_est``.
    After that, generator, critic (5 clipped updates per generator step) and
    classifier are updated in turn until ``FLAGS.max_steps``. The classifier
    is warm-started from the latest checkpoint in ``FLAGS.init_dir``.

    Args:
        T_est: prior estimate of the transition matrix. It is converted with
            ``tf.constant(..., dtype=tf.float32)`` and tiled over the batch
            (presumably NUM_CLASSES x NUM_CLASSES -- confirm).

    Raises:
        ValueError: if no checkpoint is found in ``FLAGS.init_dir``.
    """
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            images, labels, T_tru, T_mask = cifar10.noisy_distorted_inputs(
                return_T_flag=True)

        T_est = tf.constant(T_est, dtype=tf.float32)

        #### Prior and ground truth, tiled to one copy per example so they
        #### line up with the per-example generator samples.
        T_est = tf.tile(tf.expand_dims(T_est, 0), [FLAGS.batch_size, 1, 1])
        T_tru = tf.tile(tf.expand_dims(T_tru, 0), [FLAGS.batch_size, 1, 1])
        T_mask = tf.tile(tf.expand_dims(T_mask, 0), [FLAGS.batch_size, 1, 1])

        #### generator: Gaussian noise -> 2-layer MLP -> matrix of logits
        with tf.variable_scope('generator'):
            normal = Normal(tf.zeros([1, 10]), tf.ones([1, 10]))
            epsilon = tf.to_float(normal.sample(FLAGS.batch_size))
            net = slim.stack(epsilon, slim.fully_connected, [50, 50])
            net = slim.fully_connected(
                net, cifar10.NUM_CLASSES * cifar10.NUM_CLASSES,
                activation_fn=None)
            net = tf.reshape(net,
                             [-1, cifar10.NUM_CLASSES, cifar10.NUM_CLASSES])
        # Softmax over the last axis, so every row of S sums to one.
        S = tf.nn.softmax(net)

        # Input to the critic: a steep sigmoid centred at 0.05 turns S into a
        # soft 0/1 mask of its non-negligible entries.
        S_mask = tf.sigmoid((S - 0.05) / 0.005)

        #### discriminator (WGAN critic); AUTO_REUSE shares weights across calls
        def discriminator(x):
            with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE):
                x = slim.flatten(x)
                net = slim.fully_connected(x, 20, activation_fn=tf.nn.sigmoid)
                net = slim.fully_connected(net, 1, activation_fn=None)
            return net

        D_t = discriminator(T_mask)
        D_s = discriminator(S_mask)

        #### reconstructor: classifier softmax pushed through S
        dropout = tf.constant(0.75)
        logits = cifar10.inference(images, dropout, dropout_flag=True)
        preds = tf.nn.softmax(logits)
        preds_aug = tf.reshape(
            tf.matmul(tf.reshape(preds, [FLAGS.batch_size, 1, -1]), S),
            [FLAGS.batch_size, -1])
        # Clip before log so log(0) cannot produce -inf.
        logits_aug = tf.log(tf.clip_by_value(preds_aug, 1e-8, 1.0 - 1e-8))

        #### loss
        # R loss
        R_loss = cifar10.loss(logits_aug, labels)
        tf.summary.scalar('reconstructor_loss', R_loss)

        # D loss (WGAN critic: high on T_mask, low on S_mask)
        D_loss = -tf.reduce_mean(D_t) + tf.reduce_mean(D_s)
        tf.summary.scalar('discriminator_loss', D_loss)

        # G loss: reconstruction term plus fooling the critic
        G_loss = R_loss - tf.reduce_mean(D_s)
        tf.summary.scalar('generator_loss', G_loss)

        # initialization of G: cross-entropy between T_est and S
        S_logits = tf.log(tf.clip_by_value(S, 1e-8, 1.0 - 1e-8))
        Initial_G_loss = -tf.reduce_mean(
            tf.reduce_sum(tf.reduce_sum(T_est * S_logits, axis=2), axis=1))

        # variable lists, split by scope name
        var_C = []
        var_D = []
        var_G = []
        for item in tf.trainable_variables():
            if "generator" in item.name:
                var_G.append(item)
            elif "discriminator" in item.name:
                var_D.append(item)
            else:
                var_C.append(item)

        #### optimizer
        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        R_train_op, variable_averages, lr = cifar10.train(
            R_loss,
            global_step,
            var_C,
            return_variable_averages=True,
            return_lr=True)
        # A constant, but fed at run time (see the training loop) so it can be
        # reduced relative to the classifier's learning rate.
        lr_DG = tf.constant(1e-5)

        #### critic update followed by WGAN weight clipping
        D_update_op = tf.train.RMSPropOptimizer(learning_rate=lr_DG).minimize(
            D_loss, var_list=var_D)
        # Symmetric clipping to [-clip_value, clip_value] as in WGAN. The clip
        # must run after the RMSProp update, hence the control dependency;
        # read_value() creates a fresh read that respects that ordering.
        clip_value = 0.01
        with tf.control_dependencies([D_update_op]):
            clip_D = [
                var.assign(
                    tf.clip_by_value(var.read_value(), -clip_value, clip_value))
                for var in var_D
            ]
        D_train_op = tf.group(*clip_D)

        G_train_op = tf.train.RMSPropOptimizer(learning_rate=lr_DG).minimize(
            G_loss, var_list=var_G + var_C)

        #### optimizer for the initialization of the generator
        Initial_G_train_op = tf.train.RMSPropOptimizer(
            learning_rate=1e-4).minimize(Initial_G_loss, var_list=var_G)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs R_loss and throughput when _my_print_flag is set."""

            def begin(self):
                self._my_print_flag = False
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(R_loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency != 0:
                    return
                current_time = time.time()
                duration = current_time - self._start_time
                self._start_time = current_time

                loss_value = run_values.results
                examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                sec_per_batch = float(duration / FLAGS.log_frequency)

                if self._my_print_flag:
                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str %
                          (datetime.now(), self._step, loss_value,
                           examples_per_sec, sec_per_batch))

        #### build a scaffold so MonitoredTrainingSession restores the classifier
        ckpt = tf.train.get_checkpoint_state(FLAGS.init_dir)
        if ckpt is None or not ckpt.model_checkpoint_path:
            # Fail with a clear message instead of an AttributeError on None.
            raise ValueError(
                'No checkpoint found in FLAGS.init_dir=%r' % (FLAGS.init_dir,))
        # Only classifier weights come from the checkpoint; G/D variables,
        # optimizer slots and the global step start fresh. Build a new dict
        # rather than deleting while iterating (RuntimeError on Python 3).
        excluded = ('generator', 'discriminator', 'RMSProp', 'global_step')
        variables_to_restore = {
            var_name: var
            for var_name, var in variable_averages.variables_to_restore().items()
            if not any(key in var_name for key in excluded)
        }
        print(variables_to_restore)

        init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
            ckpt.model_checkpoint_path, variables_to_restore)

        def _init_assign_fn(scaffold, sess):
            sess.run(init_assign_op, init_feed_dict)

        scaffold = tf.train.Scaffold(saver=tf.train.Saver(),
                                     init_fn=_init_assign_fn)

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.45)
        loggerHook = _LoggerHook()
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                scaffold=scaffold,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(R_loss), loggerHook
                ],
                save_checkpoint_secs=60,
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement,
                    gpu_options=gpu_options)) as mon_sess:

            #### pretrain the generator towards T_est
            # The hook's own logging stays off; progress is printed below.
            loggerHook._my_print_flag = False
            res = None
            for i in range(10000):
                res = mon_sess.run(
                    [Initial_G_train_op, Initial_G_loss, T_est, S, lr, lr_DG])
                if i % 1000 == 0:
                    print('Step: %d\tGenerator loss: %.3f' % (i, res[1]))
                    print('Pre-estimation', res[2][0])
                    print('Initialization', res[3][0])

            #### iteratively train G and <D,R>
            step = 0
            step_control = 0  # adversarial updates run while step >= step_control
            lr_, lr_DG_ = res[-2], res[-1]
            while not mon_sess.should_stop():
                # Whenever the G/D learning rate has caught up with the
                # classifier's (decaying) learning rate, shrink it by 10x to
                # avoid over-tuning the transition matrix.
                if lr_DG_ >= lr_:
                    lr_DG_ = lr_DG_ / 10.0
                # do the adversarial game
                if step >= step_control:
                    res = mon_sess.run(
                        [G_train_op, G_loss, T_est, S, T_tru, S_mask, T_mask],
                        feed_dict={lr_DG: lr_DG_})
                    g_loss = res[1]

                    # 5 critic updates per generator step; each one is clipped
                    # as part of D_train_op.
                    for _ in range(5):
                        _, d_loss = mon_sess.run([D_train_op, D_loss],
                                                 feed_dict={lr_DG: lr_DG_})

                # train the classifier
                _, r_loss, g_step, lr_, lr_DG_ = mon_sess.run(
                    [R_train_op, R_loss, global_step, lr, lr_DG],
                    feed_dict={lr_DG: lr_DG_})

                if step >= step_control:
                    print(
                        'Step: %d\tR_loss: %.3f\tD_loss: %.3f\tG_loss: %.3f' %
                        (g_step, r_loss, d_loss, g_loss))

                    if (g_step % 2000 == 0) or (g_step == FLAGS.max_steps - 1):
                        print('Pre-estimation', res[2][0])
                        print('Generated sample', res[3][0])
                        print('True transition', res[4][0])
                        print('Generated structure', res[5][0])
                        print('True structure', res[6][0])
                else:
                    print('Step: %d\tR_loss: %.3f' % (g_step, r_loss))

                step = g_step