Example #1
def train(X,
          Y,
          args,
          increment=False,
          n_classes=10,
          exp=0,
          split=0,
          n_epochs=1):
    x_shape = [args.batch_size] + list(X.shape[1:])

    save_file = args.save_path + "exp_{}/model_split{}.ckpt".format(exp, split)

    with tf.Graph().as_default():
        dset = InputGenerator([None] + list(X.shape[1:]),
                              n_classes,
                              z_size=args.z_size,
                              batch_size=args.batch_size,
                              n_epochs=n_epochs)
        aae = AAE("train",
                  batch_size=args.batch_size,
                  n_epochs=n_epochs,
                  n_classes=n_classes,
                  z_size=args.z_size,
                  input_shape=x_shape)

        iterador = dset.create_train_generator()
        (x_input, y_input), (z_real, y_real) = iterador.get_next()

        # Model structure
        z_hat, y_hat = aae.encoder(x_input)
        x_recon = aae.decoder(z_hat, y_hat)

        dz_real = aae.discriminator_z(z_real)
        dz_fake = aae.discriminator_z(z_hat)
        dy_real = aae.discriminator_y(y_input)
        dy_fake = aae.discriminator_y(y_hat)

        _, y_tilde = aae.encoder(x_input, supervised=True)

        # Metrics
        acc, acc_op = tf.metrics.mean_per_class_accuracy(
            tf.argmax(y_input, -1), tf.argmax(y_tilde, -1), n_classes)
        mse, mse_op = tf.metrics.mean_squared_error(x_input, x_recon)

        # Losses
        ae_loss = tf.losses.log_loss(x_input, x_recon)
        clf_loss = tf.losses.softmax_cross_entropy(y_input, y_tilde)

        dz_real_loss = tf.losses.sigmoid_cross_entropy(tf.ones_like(dz_real),
                                                       dz_real)
        dz_fake_loss = tf.losses.sigmoid_cross_entropy(tf.zeros_like(dz_fake),
                                                       dz_fake)
        dz_loss = dz_real_loss + dz_fake_loss

        dy_real_loss = tf.losses.sigmoid_cross_entropy(tf.ones_like(dy_real),
                                                       dy_real)
        dy_fake_loss = tf.losses.sigmoid_cross_entropy(tf.zeros_like(dy_fake),
                                                       dy_fake)
        dy_loss = dy_real_loss + dy_fake_loss

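        # Generator-side losses: train the encoder so that its z_hat and y_hat
        # are scored as "real" by the two discriminators.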
        gz_loss = tf.losses.sigmoid_cross_entropy(tf.ones_like(dz_fake),
                                                  dz_fake)
        gy_loss = tf.losses.sigmoid_cross_entropy(tf.ones_like(dy_fake),
                                                  dy_fake)
        gen_loss = gz_loss + gy_loss

        # Training ops
        all_vars = tf.trainable_variables()
        dz_vars = [var for var in all_vars if "Discriminator_Z" in var.name]
        dy_vars = [var for var in all_vars if "Discriminator_Y" in var.name]
        enc_vars = [var for var in all_vars if "Encoder" in var.name]
        ae_vars = enc_vars + [var for var in all_vars if "Decoder" in var.name]

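        # Incremental training: restore every variable except the label-sized
        # ("Y_") and discriminator variables from the previous checkpoint; those
        # keep their fresh initialization (see the session setup below).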
        if increment:
            increment_vars = [
                var for var in tf.global_variables() if "Y_" not in var.name
            ]
            increment_vars = [
                var for var in increment_vars
                if "Discriminator" not in var.name
            ]
            init_vars = [
                var for var in tf.global_variables() if "Y_" in var.name
            ]
            init_vars += [
                var for var in tf.global_variables()
                if "Discriminator" in var.name
            ]
        else:
            increment_vars = None
            init_vars = None

        step_tensor = tf.Variable(0, trainable=False, name="Step")
        learning_rate = 0.001
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        with tf.control_dependencies(update_ops):
            ae_opt = tf.train.AdamOptimizer(learning_rate).minimize(
                ae_loss, var_list=ae_vars, global_step=step_tensor)
            dz_opt = tf.train.AdamOptimizer(learning_rate).minimize(
                dz_loss, var_list=dz_vars)
            dy_opt = tf.train.AdamOptimizer(learning_rate).minimize(
                dy_loss, var_list=dy_vars)
            gen_opt = tf.train.AdamOptimizer(learning_rate).minimize(
                gen_loss, var_list=enc_vars)
            clf_opt = tf.train.AdamOptimizer(learning_rate).minimize(
                clf_loss, var_list=enc_vars)
            train_ops = tf.group([ae_opt, dz_opt, dy_opt, gen_opt, clf_opt])

        if increment:
            saver = tf.train.Saver(increment_vars)
        ckpt_saver = tf.train.Saver()

        with tf.Session() as sess:
            if increment:
                sess.run(tf.global_variables_initializer())
                saver.restore(
                    sess,
                    tf.train.latest_checkpoint(args.save_path +
                                               "exp_{}/".format(exp)))
                # sess.run(tf.variables_initializer(init_vars))
            else:
                sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())

            # Load the datasets
            sess.run(iterador.initializer,
                     feed_dict={
                         dset.x_input: X,
                         dset.y_input: Y
                     })
            n_steps = (len(X) // args.batch_size) * n_epochs
            # Training loop:
            with tqdm(desc="Train", total=n_steps, unit="Steps",
                      miniters=10) as pbar:
                try:
                    while True:
                        _, step, accuracy, msqer, _, _ = sess.run(
                            [train_ops, step_tensor, acc, mse, acc_op, mse_op])
                        pbar.update()
                        if step % 10 == 0:
                            pbar.set_postfix(Accuracy=accuracy,
                                             MSE=msqer,
                                             refresh=False)
                except tf.errors.OutOfRangeError:
                    pass
            ckpt_saver.save(sess, save_path=save_file)
        print("Done!")
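The snippet above assumes TensorFlow 1.x (`import tensorflow as tf`), `from tqdm import tqdm`, and the project's own `InputGenerator` and `AAE` classes, none of which are shown. A minimal invocation sketch, assuming `args` only needs the `batch_size`, `z_size`, and `save_path` attributes the function actually reads; all names, shapes, and values below are placeholders:

import numpy as np
from types import SimpleNamespace

# Placeholder data: images scaled to [0, 1] and one-hot labels.
X = np.random.rand(1000, 28, 28, 1).astype(np.float32)
Y = np.eye(10)[np.random.randint(0, 10, size=1000)].astype(np.float32)

# Only these attributes are read by train(); the "exp_0/" sub-directory under
# save_path is assumed to exist before the checkpoint is written.
args = SimpleNamespace(batch_size=64, z_size=8, save_path="./checkpoints/")

train(X, Y, args, n_classes=10, exp=0, split=0, n_epochs=1)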
Example #2
def train(X,
          Y,
          args,
          increment=False,
          n_classes=10,
          exp=0,
          split=0,
          n_epochs=1):
    x_shape = [args.batch_size] + list(X.shape[1:])
    y_shape = [args.batch_size] + list(Y.shape[1:])
    save_file = args.save_path + "exp_{}/model_split{}.ckpt".format(exp, split)
    monitor_path = args.monitor + "exp_{}/split_{}".format(exp, split)

    with tf.Graph().as_default():

        dset = InputGenerator([None] + list(X.shape[1:]),
                              n_classes=n_classes,
                              z_size=args.z_size,
                              batch_size=args.batch_size,
                              n_epochs=n_epochs)
        aae = AAE("train",
                  batch_size=args.batch_size,
                  n_epochs=n_epochs,
                  n_classes=n_classes,
                  z_size=args.z_size,
                  input_shape=x_shape)

        iterador = dset.create_train_generator()
        (x_input, y_input), (z_real, y_real) = iterador.get_next()
        # Model structure
        z_hat, y_hat = aae.encoder(x_input)

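        # Monitoring-only entropy-style statistic of the sigmoid-squashed latent
        # code; it feeds the "Metrics/Entropy" summary below and is not part of
        # any training loss.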
        pz = tf.sigmoid(z_hat + 1e-8)
        entropia = tf.reduce_mean(-tf.reduce_sum(pz * tf.log(pz), 1))
        #entropia = tf.reduce_mean(0.5*tf.norm(z_hat - z_real, ord=1, axis=1))
        x_recon = aae.decoder(z_hat, y_hat)

        dz_real = aae.discriminator_z(z_real)
        dz_fake = aae.discriminator_z(z_hat)
        dy_real = aae.discriminator_y(y_real)
        dy_fake = aae.discriminator_y(y_hat)

        _, y_tilde = aae.encoder(x_input, supervised=True)

        # Metrics
        acc, acc_op = tf.metrics.mean_per_class_accuracy(
            tf.argmax(y_input, -1), tf.argmax(y_tilde, -1), n_classes)
        mse, mse_op = tf.metrics.mean_squared_error(x_input, x_recon)

        # Losses
        ae_loss = tf.losses.log_loss(x_input, x_recon)
        # tf.reduce_mean(aae.binary_crossentropy(x_input, x_recon))
        clf_loss = tf.losses.softmax_cross_entropy(y_input, y_tilde)
        # tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_input, logits=y_tilde))

        dz_real_loss = tf.losses.sigmoid_cross_entropy(tf.ones_like(dz_real),
                                                       dz_real)
        # tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(dz_real), logits=dz_real))
        dz_fake_loss = tf.losses.sigmoid_cross_entropy(tf.zeros_like(dz_fake),
                                                       dz_fake)
        # tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(dz_fake), logits=dz_fake))
        dz_loss = dz_real_loss + dz_fake_loss

        dy_real_loss = tf.losses.sigmoid_cross_entropy(tf.ones_like(dy_real),
                                                       dy_real)
        # tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(dy_real), logits=dy_real))
        dy_fake_loss = tf.losses.sigmoid_cross_entropy(tf.zeros_like(dy_fake),
                                                       dy_fake)
        # tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(dy_fake), logits=dy_fake))
        dy_loss = dy_real_loss + dy_fake_loss

        gz_loss = tf.losses.sigmoid_cross_entropy(tf.ones_like(dz_fake),
                                                  dz_fake)
        # tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(dz_fake), logits=dz_fake))
        gy_loss = tf.losses.sigmoid_cross_entropy(tf.ones_like(dy_fake),
                                                  dy_fake)
        # tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(dy_fake), logits=dy_fake))
        gen_loss = gz_loss + gy_loss

        # Training ops
        all_vars = tf.trainable_variables()
        dz_vars = [var for var in all_vars if "Discriminator_Z" in var.name]
        dy_vars = [var for var in all_vars if "Discriminator_Y" in var.name]
        enc_vars = [var for var in all_vars if "Encoder" in var.name]

        if increment:
            increment_vars = [
                var for var in tf.global_variables() if "Y_" not in var.name
            ]
            init_vars = [
                var for var in tf.global_variables() if "Y_" in var.name
            ]
        else:
            increment_vars = None
            init_vars = None

        step_tensor = tf.Variable(0, trainable=False, name="Step")
        #learning_rate = tf.train.polynomial_decay(0.005, global_step=step_tensor, decay_steps=20000, end_learning_rate=0.000001, power=2, name="Learning_rate")
        learning_rate = 0.001
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        with tf.control_dependencies(update_ops):
            ae_opt = tf.train.AdamOptimizer(learning_rate).minimize(
                ae_loss, global_step=step_tensor)
            dz_opt = tf.train.AdamOptimizer(learning_rate).minimize(
                dz_loss, var_list=dz_vars)
            dy_opt = tf.train.AdamOptimizer(learning_rate).minimize(
                dy_loss, var_list=dy_vars)
            gen_opt = tf.train.AdamOptimizer(learning_rate).minimize(
                gen_loss, var_list=enc_vars)
            clf_opt = tf.train.AdamOptimizer(learning_rate).minimize(
                clf_loss, var_list=enc_vars)
            train_ops = tf.group([ae_opt, dz_opt, dy_opt, gen_opt, clf_opt])

        # summaries
        tf.summary.scalar("Losses/AE_loss", ae_loss)
        tf.summary.scalar("Losses/Dis_Y_loss", dy_loss)
        tf.summary.scalar("Losses/Dis_Z_loss", dz_loss)
        tf.summary.scalar("Losses/Gen_loss", gen_loss)
        tf.summary.scalar("Losses/Clf_loss", clf_loss)
        tf.summary.scalar("Metrics/Accuracy", acc)
        tf.summary.scalar("Metrics/MSE", mse)
        tf.summary.scalar("Metrics/Entropy", entropia)
        #tf.summary.scalar("Metrics/LearningRate",learning_rate)
        tf.summary.histogram("Z_pred", z_hat)
        tf.summary.histogram("Z_real", z_real)
        tf.summary.image("X_Real", x_input, 10)
        tf.summary.image("X_Recon", x_recon, 10)
        summary_op = tf.summary.merge_all()

        if increment:
            saver = tf.train.Saver(increment_vars)
        ckpt_saver = tf.train.Saver()

        with tf.Session() as sess:
            if increment:
                sess.run(tf.global_variables_initializer())
                saver.restore(
                    sess,
                    tf.train.latest_checkpoint(args.save_path +
                                               "exp_{}/".format(exp)))
                # sess.run(tf.variables_initializer(init_vars))
            else:
                sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            summary_writer = tf.summary.FileWriter(monitor_path)
            # Load the datasets
            sess.run(iterador.initializer,
                     feed_dict={
                         dset.x_input: X,
                         dset.y_input: Y
                     })
            n_steps = (len(X) // args.batch_size) * n_epochs
            # Training loop:
            with tqdm(desc="Train", total=n_steps, unit="Steps",
                      miniters=10) as pbar:
                try:
                    while True:
                        _, step, accuracy, msqer, _, _, summary = sess.run([
                            train_ops, step_tensor, acc, mse, acc_op, mse_op,
                            summary_op
                        ])
                        summary_writer.add_summary(summary, step)
                        pbar.update()
                        if step % 10 == 0:
                            pbar.set_postfix(Accuracy=accuracy,
                                             MSE=msqer,
                                             refresh=False)

                except tf.errors.OutOfRangeError:
                    pass
            ckpt_saver.save(sess, save_path=save_file)
        print("Done!")
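Example #2 additionally writes TensorBoard summaries, so `args` must also expose a `monitor` directory. A minimal sketch of the extra setup, again with placeholder names and values; the resulting event files can be inspected with `tensorboard --logdir ./logs/`:

from types import SimpleNamespace

# Same fields as before plus the summary root; the FileWriter writes the
# events under "<monitor>/exp_{exp}/split_{split}".
args = SimpleNamespace(batch_size=64,
                       z_size=8,
                       save_path="./checkpoints/",
                       monitor="./logs/")

train(X, Y, args, increment=False, n_classes=10, exp=0, split=0, n_epochs=1)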