Example #1

# Imports assumed for this excerpt (based on the Analytics Zoo TFPark LeNet
# example; exact module paths vary across Analytics Zoo/BigDL versions).
import sys

import numpy as np
import tensorflow as tf

from zoo import init_nncontext
from zoo.tfpark import TFDataset, TFOptimizer
from bigdl.dataset import mnist
from bigdl.dataset.transformer import normalizer
from bigdl.optim.optimizer import (Adam, MaxEpoch, Top1Accuracy,
                                   TrainSummary, ValidationSummary)

slim = tf.contrib.slim
from nets import lenet  # LeNet definition from the TF-Slim models repository
def main(max_epoch, data_num):
    sc = init_nncontext()

    # get the data, pre-process it, and create a TFDataset
    def get_data_rdd(dataset):
        (images_data,
         labels_data) = mnist.read_data_sets("/tmp/mnist", dataset)
        image_rdd = sc.parallelize(images_data[:data_num])
        labels_rdd = sc.parallelize(labels_data[:data_num])
        rdd = image_rdd.zip(labels_rdd) \
            .map(lambda rec_tuple: [normalizer(rec_tuple[0], mnist.TRAIN_MEAN, mnist.TRAIN_STD),
                                    np.array(rec_tuple[1])])
        return rdd

    training_rdd = get_data_rdd("train")
    testing_rdd = get_data_rdd("test")
    dataset = TFDataset.from_rdd(training_rdd,
                                 names=["features", "labels"],
                                 shapes=[[28, 28, 1], []],
                                 types=[tf.float32, tf.int32],
                                 batch_size=280,
                                 val_rdd=testing_rdd)
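    # note: in Analytics Zoo, batch_size is the total batch size across the
    # cluster and should be a multiple of the total core count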

    # construct the model from the TFDataset; dataset.tensors exposes the
    # batched input tensors that feed the graph during training
    images, labels = dataset.tensors

    with slim.arg_scope(lenet.lenet_arg_scope()):
        logits, end_points = lenet.lenet(images,
                                         num_classes=10,
                                         is_training=True)

    loss = tf.reduce_mean(
        tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=labels))

    # create an optimizer
    optimizer = TFOptimizer(loss,
                            Adam(1e-3),
                            val_outputs=[logits],
                            val_labels=[labels],
                            val_method=Top1Accuracy())
    optimizer.set_train_summary(TrainSummary("/tmp/az_lenet", "lenet"))
    optimizer.set_val_summary(ValidationSummary("/tmp/az_lenet", "lenet"))
    # kick off training
    optimizer.optimize(end_trigger=MaxEpoch(max_epoch))

    saver = tf.train.Saver()
    saver.save(optimizer.sess, "/tmp/lenet/model")
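
The excerpt omits the script's entry point. A minimal driver, assuming
command-line handling along the lines of the original Analytics Zoo example
(the defaults below are illustrative), might look like this:

if __name__ == '__main__':
    # hypothetical entry point: optional CLI overrides for the epoch count
    # and the number of MNIST samples to use
    max_epoch = int(sys.argv[1]) if len(sys.argv) > 1 else 5
    data_num = int(sys.argv[2]) if len(sys.argv) > 2 else 60000
    main(max_epoch, data_num)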
Example #2

(This train method is excerpted from a GAN estimator class; tf, np, TFModel,
TFOptimizer, and GanOptimMethod are assumed to be imported at module level,
alongside the class definition.)
    def train(self, dataset, end_trigger):

        with tf.Graph().as_default() as g:

            generator_inputs = dataset.tensors[0]
            real_data = dataset.tensors[1]

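            # a step counter, incremented once per optimization step, tracks
            # which training phase (discriminator or generator) is active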
            counter = tf.Variable(0, dtype=tf.int32)

            period = self._discriminator_steps + self._generator_steps

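            # the first discriminator_steps iterations of every period belong
            # to the discriminator phase; the remaining generator_steps
            # iterations belong to the generator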
            is_discriminator_phase = tf.less(tf.mod(counter, period),
                                             self._discriminator_steps)

            with tf.variable_scope("generator"):
                gen_data = self._call_fn_maybe_with_counter(
                    self._generator_fn, counter, generator_inputs)

            with tf.variable_scope("discriminator"):
                fake_d_outputs = self._call_fn_maybe_with_counter(
                    self._discriminator_fn, counter, gen_data,
                    generator_inputs)

            with tf.variable_scope("discriminator", reuse=True):
                real_d_outputs = self._call_fn_maybe_with_counter(
                    self._discriminator_fn, counter, real_data,
                    generator_inputs)

            with tf.name_scope("generator_loss"):
                generator_loss = self._call_fn_maybe_with_counter(
                    self._generator_loss_fn, counter, fake_d_outputs)

            with tf.name_scope("discriminator_loss"):
                discriminator_loss = self._call_fn_maybe_with_counter(
                    self._discriminator_loss_fn, counter, real_d_outputs,
                    fake_d_outputs)

            generator_variables = tf.trainable_variables("generator")
            generator_grads = tf.gradients(generator_loss, generator_variables)
            discriminator_variables = tf.trainable_variables("discriminator")
            discriminator_grads = tf.gradients(discriminator_loss,
                                               discriminator_variables)

            variables = generator_variables + discriminator_variables

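            # mask out the generator gradients during the discriminator phase
            # (and vice versa below), so one fused gradient list can update
            # both sub-networks through a single optimizer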
            def true_fn():
                return [tf.zeros_like(grad) for grad in generator_grads]

            def false_fn():
                return generator_grads

            g_grads = tf.cond(is_discriminator_phase,
                              true_fn=true_fn,
                              false_fn=false_fn)
            d_grads = tf.cond(
                is_discriminator_phase, lambda: discriminator_grads,
                lambda: [tf.zeros_like(grad) for grad in discriminator_grads])
            loss = tf.cond(is_discriminator_phase, lambda: discriminator_loss,
                           lambda: generator_loss)

            grads = g_grads + d_grads

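            # advance the phase counter only after this step's gradients
            # have been computed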
            with tf.control_dependencies(grads):
                increase_counter = tf.assign_add(counter, 1)

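            # total number of generator parameters; this is passed to
            # GanOptimMethod below, presumably so it can split the fused
            # gradient between the two sub-networks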
            g_param_size = sum([np.product(g.shape) for g in g_grads])
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                tf_model = TFModel.create_for_unfreeze(
                    loss,
                    sess,
                    inputs=dataset._original_tensors,
                    grads=grads,
                    variables=variables,
                    graph=g,
                    tensors_with_value=None,
                    session_config=None,
                    metrics=None,
                    updates=[increase_counter],
                    model_dir=self.checkpoint_path)

                optimizer = TFOptimizer(tf_model,
                                        GanOptimMethod(
                                            self._discriminator_optim_method,
                                            self._generator_optim_method,
                                            g_param_size.value,
                                            self._discriminator_steps,
                                            self._generator_steps),
                                        sess=sess,
                                        dataset=dataset,
                                        model_dir=self.checkpoint_path)
                optimizer.optimize(end_trigger)
                steps = sess.run(counter)
                saver = tf.train.Saver()
                saver.save(optimizer.sess,
                           self.checkpoint_path,
                           global_step=steps)
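
The core mechanism above is the counter-gated tf.cond: during discriminator
steps the generator's gradients are replaced with zeros (and vice versa), so
a single fused gradient list can drive both sub-networks through one
optimizer. Below is a minimal standalone sketch of that mechanism in plain
TensorFlow 1.x, with toy one-variable "networks" (all names illustrative):

import tensorflow as tf

D_STEPS, G_STEPS = 2, 1                      # illustrative phase lengths
counter = tf.Variable(0, dtype=tf.int32)
is_d_phase = tf.less(tf.mod(counter, D_STEPS + G_STEPS), D_STEPS)

# toy stand-ins for the generator and discriminator losses
g_var = tf.Variable(3.0)
d_var = tf.Variable(3.0)
raw_g_grads = tf.gradients(tf.square(g_var), [g_var])
raw_d_grads = tf.gradients(tf.square(d_var), [d_var])

# zero out whichever side is not being trained at this step
g_grads = tf.cond(is_d_phase,
                  lambda: [tf.zeros_like(g) for g in raw_g_grads],
                  lambda: raw_g_grads)
d_grads = tf.cond(is_d_phase,
                  lambda: raw_d_grads,
                  lambda: [tf.zeros_like(g) for g in raw_d_grads])

# advance the phase counter once this step's gradients exist
with tf.control_dependencies(g_grads + d_grads):
    step = tf.assign_add(counter, 1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(6):
        g, d, _ = sess.run([g_grads, d_grads, step])
        # the first two of every three steps zero the generator gradient
        print("g_grad=%s d_grad=%s" % (g[0], d[0]))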