Example #1
def train():
    with tf.Graph().as_default():
        with tf.device('/gpu:0'):
            latent_gan = PCL2PCLGAN(para_config_gan, para_config_ae)
            #print_trainable_vars()
            (G_loss, G_tofool_loss, reconstr_loss, D_loss, D_fake_loss,
             D_real_loss, fake_clean_reconstr, eval_loss) = latent_gan.model()
            G_optimizer, D_optimizer = latent_gan.optimize(G_loss, D_loss)

            # metrics for tensorboard visualization
            with tf.name_scope('metrics'):
                G_loss_mean_op, G_loss_mean_update_op = tf.metrics.mean(G_loss)
                G_tofool_loss_mean_op, G_tofool_loss_mean_update_op = tf.metrics.mean(
                    G_tofool_loss)
                reconstr_loss_mean_op, reconstr_loss_mean_update_op = tf.metrics.mean(
                    reconstr_loss)

                D_loss_mean_op, D_loss_mean_update_op = tf.metrics.mean(D_loss)
                D_fake_loss_mean_op, D_fake_loss_mean_update_op = tf.metrics.mean(
                    D_fake_loss)
                D_real_loss_mean_op, D_real_loss_mean_update_op = tf.metrics.mean(
                    D_real_loss)

                eval_loss_mean_op, eval_loss_mean_update_op = tf.metrics.mean(
                    eval_loss)

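            # tf.metrics.* registers its accumulators as *local* variables
            # under the 'metrics' scope; re-initializing just those variables
            # resets the running means between logging intervals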
            reset_metrics = tf.variables_initializer([
                var for var in tf.local_variables()
                if var.name.split('/')[0] == 'metrics'
            ])

            tf.summary.scalar('loss/G/G_loss',
                              G_loss_mean_op,
                              collections=['train'])
            tf.summary.scalar('loss/G/G_tofool_loss',
                              G_tofool_loss_mean_op,
                              collections=['train'])
            tf.summary.scalar('loss/G/G_reconstr_loss',
                              reconstr_loss_mean_op,
                              collections=['train'])
            tf.summary.scalar('loss/D/D_loss',
                              D_loss_mean_op,
                              collections=['train'])
            tf.summary.scalar('loss/D/D_fake_loss',
                              D_fake_loss_mean_op,
                              collections=['train'])
            tf.summary.scalar('loss/D/D_real_loss',
                              D_real_loss_mean_op,
                              collections=['train'])

            tf.summary.scalar('loss/%s_loss' % (para_config_gan['eval_loss']),
                              eval_loss_mean_op,
                              collections=['test'])

            summary_op = tf.summary.merge_all('train')
            summary_eval_op = tf.summary.merge_all('test')
            train_writer = tf.summary.FileWriter(
                os.path.join(LOG_DIR, 'summary', 'train'))
            test_writer = tf.summary.FileWriter(
                os.path.join(LOG_DIR, 'summary', 'test'))
            saver = tf.train.Saver(max_to_keep=None)

        # log the network structure
        log_string('Net layers:')
        log_string(str(latent_gan))

        # Create a session
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        config.log_device_placement = False

        with tf.Session(config=config) as sess:
            # Init variables
            init = tf.global_variables_initializer()
            sess.run(init)

            if para_config_gan['recover_ckpt'] is not None:
                print('Continue training from %s' %
                      (para_config_gan['recover_ckpt']))
                saver.restore(sess, para_config_gan['recover_ckpt'])
            else:
                # NOTE: load pre-trained AE weights
                # noisy AE, only pre-trained encoder is used
                noisy_ckpt_vars = tf.contrib.framework.list_variables(
                    para_config_gan['noisy_ae_ckpt'])
                # print('Noisy AE pre-trained variables:')
                # for vname, _ in noisy_ckpt_vars:
                #     print(vname)
                restore_dict = get_restore_dict(
                    noisy_ckpt_vars, latent_gan.noisy_encoder.all_variables)
                noisy_saver = tf.train.Saver(restore_dict)
                noisy_saver.restore(sess, para_config_gan['noisy_ae_ckpt'])
                # clean AE, both pre-trained encoder and decoder are used
                clean_ckpt_vars = tf.contrib.framework.list_variables(
                    para_config_gan['clean_ae_ckpt'])
                # print('Clean AE pre-trained variables:')
                # for vname, _ in clean_ckpt_vars:
                #     print(vname)
                restore_dict = get_restore_dict(
                    clean_ckpt_vars, latent_gan.clean_encoder.all_variables)
                clean_saver = tf.train.Saver(restore_dict)
                clean_saver.restore(sess, para_config_gan['clean_ae_ckpt'])
                restore_dict = get_restore_dict(
                    clean_ckpt_vars, latent_gan.clean_decoder.all_variables)
                clean_saver = tf.train.Saver(restore_dict)
                clean_saver.restore(sess, para_config_gan['clean_ae_ckpt'])
                print('Loading pre-trained noisy/clean AE done.')
                # END of weights loading

            for i in range(para_config_gan['epoch']):
                sess.run(reset_metrics)

                while (NOISY_TRAIN_DATASET.has_next_batch()
                       and CLEAN_TRAIN_DATASET.has_next_batch()):
                    noise_cur = NOISY_TRAIN_DATASET.next_batch()
                    clean_cur = CLEAN_TRAIN_DATASET.next_batch()

                    feed_dict = {
                        latent_gan.input_noisy_cloud: noise_cur,
                        latent_gan.input_clean_cloud: clean_cur,
                        latent_gan.is_training: True,
                    }
                    # train D for k times
                    for _ in range(para_config_gan['k']):
                        sess.run([
                            D_optimizer, D_fake_loss_mean_update_op,
                            D_real_loss_mean_update_op, D_loss_mean_update_op
                        ],
                                 feed_dict=feed_dict)

                    # train G
                    for _ in range(para_config_gan['kk']):
                        sess.run([
                            G_optimizer, G_tofool_loss_mean_update_op,
                            reconstr_loss_mean_update_op, G_loss_mean_update_op
                        ],
                                 feed_dict=feed_dict)

                NOISY_TRAIN_DATASET.reset()
                CLEAN_TRAIN_DATASET.reset()

                if i % para_config_gan['output_interval'] == 0:
                    (G_loss_mean_val, G_tofool_loss_mean_val,
                     reconstr_loss_mean_val,
                     D_loss_mean_val, D_fake_loss_mean_val, D_real_loss_mean_val,
                     fake_clean_reconstr_val, summary) = sess.run(
                         [G_loss_mean_op, G_tofool_loss_mean_op,
                          reconstr_loss_mean_op,
                          D_loss_mean_op, D_fake_loss_mean_op, D_real_loss_mean_op,
                          fake_clean_reconstr, summary_op],
                         feed_dict=feed_dict)

                    # save currently generated
                    if i % para_config_gan['save_interval'] == 0:
                        pc_util.write_ply_batch(
                            fake_clean_reconstr_val,
                            os.path.join(LOG_DIR, 'fake_cleans',
                                         'reconstr_%d' % (i)))
                        pc_util.write_ply_batch(
                            noise_cur,
                            os.path.join(LOG_DIR, 'fake_cleans',
                                         'input_noisy_%d' % (i)))
                        pc_util.write_ply_batch(
                            clean_cur,
                            os.path.join(LOG_DIR, 'fake_cleans',
                                         'input_clean_%d' % (i)))

                    # terminal prints
                    log_string(
                        '%s training %d snapshot: ' %
                        (datetime.now().strftime('%Y-%m-%d-%H-%M-%S'), i))
                    log_string(
                        '        G loss: {:.6f} = (g){:.6f}, (r){:.6f}'.format(
                            G_loss_mean_val, G_tofool_loss_mean_val,
                            reconstr_loss_mean_val))
                    log_string(
                        '        D loss: {:.6f} = (f){:.6f}, (r){:.6f}'.format(
                            D_loss_mean_val, D_fake_loss_mean_val,
                            D_real_loss_mean_val))

                    # tensorboard output
                    train_writer.add_summary(summary, i)

                if i % para_config_gan['save_interval'] == 0:
                    # test and evaluate on test set
                    NOISY_TEST_DATASET.reset()
                    while NOISY_TEST_DATASET.has_next_batch():
                        noise_cur = NOISY_TEST_DATASET.next_batch()
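                        # no separate ground truth on this test split;
                        # the noisy input doubles as gt for eval_loss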
                        clean_cur = noise_cur
                        feed_dict = {
                            latent_gan.input_noisy_cloud: noise_cur,
                            latent_gan.gt: clean_cur,
                            latent_gan.is_training: False,
                        }
                        fake_clean_reconstr_val, _ = sess.run(
                            [fake_clean_reconstr, eval_loss_mean_update_op],
                            feed_dict=feed_dict)

                    NOISY_TEST_DATASET.reset()

                    eval_loss_mean_val, summary_eval = sess.run(
                        [eval_loss_mean_op, summary_eval_op],
                        feed_dict=feed_dict)

                    test_writer.add_summary(summary_eval, i)
                    log_string('Eval loss (%s) on test set: %f' %
                               (para_config_gan['eval_loss'],
                                eval_loss_mean_val))

                    # save model
                    save_path = saver.save(
                        sess,
                        os.path.join(LOG_DIR, 'ckpts', 'model_%d.ckpt' % (i)))
                    log_string("Model saved in file: %s" % save_path)
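The helper get_restore_dict used above is defined elsewhere in the repo. A minimal sketch of what such a helper can look like, assuming the checkpoint names match the in-graph names once the extra outer scope added inside the GAN graph is stripped (this mapping rule is an assumption for illustration, not the repo's actual logic):

def get_restore_dict(ckpt_vars, graph_vars):
    # ckpt_vars: (name, shape) pairs from tf.contrib.framework.list_variables()
    # graph_vars: variables of the rebuilt encoder/decoder
    ckpt_names = {name for name, _ in ckpt_vars}
    restore_dict = {}
    for var in graph_vars:
        name = var.name.split(':')[0]             # drop the ':0' tensor suffix
        stripped = '/'.join(name.split('/')[1:])  # drop the outermost scope
        if stripped in ckpt_names:
            restore_dict[stripped] = var
    return restore_dict

The resulting {checkpoint_name: variable} dict is exactly what tf.train.Saver accepts, which is why the code above can build one Saver per pre-trained sub-network.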
Example #2
def train():
    with tf.Graph().as_default():
        with tf.device('/gpu:0'):
            ae = autoencoder.AutoEncoder(paras=para_config)
            print_trainable_vars()

            reconstr_loss, reconstr, _ = ae.model()
            optimizer = ae.make_optimizer(reconstr_loss)

            # metrics for tensorboard visualization
            with tf.name_scope('metrics'):
                reconstr_loss_mean, reconstr_loss_mean_update = tf.metrics.mean(
                    reconstr_loss)
            reset_metrics = tf.variables_initializer([
                var for var in tf.local_variables()
                if var.name.split('/')[0] == 'metrics'
            ])

            tf.summary.scalar('loss/train',
                              reconstr_loss_mean,
                              collections=['train'])
            tf.summary.scalar('loss/test',
                              reconstr_loss_mean,
                              collections=['test'])

            summary_op = tf.summary.merge_all('train')
            summary_test_op = tf.summary.merge_all('test')
            train_writer = tf.summary.FileWriter(
                os.path.join(LOG_DIR, 'summary', 'train'))
            test_writer = tf.summary.FileWriter(
                os.path.join(LOG_DIR, 'summary', 'test'))
            saver = tf.train.Saver(max_to_keep=None)

        # log the network structure
        log_string('Net layers:')
        log_string(str(ae))

        # Create a session
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        config.log_device_placement = False
        sess = tf.Session(config=config)

        # Init variables
        init = tf.global_variables_initializer()
        sess.run(init)
        sess.run(reset_metrics)

        total_batch_idx = 0
        for ep_idx in range(para_config['epoch']):
            log_string('-----------Epoch %d:-------------' % ep_idx)

            # train one epoch
            while TRAIN_DATASET.has_next_batch():
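                # note: resetting here on every batch makes the running mean
                # below a single-batch value rather than an epoch average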
                sess.run(reset_metrics)

                if para_config['ae_type'] == 'c2c':
                    input_batch = TRAIN_DATASET.next_batch()
                elif para_config['ae_type'] == 'n2n':
                    input_batch = TRAIN_DATASET.next_batch_noise_added(
                        noise_mu=para_config['noise_mu'],
                        noise_sigma=para_config['noise_sigma'])
                elif para_config['ae_type'] == 'np2np':
                    input_batch = TRAIN_DATASET.next_batch_noise_partial_by_percentage(
                        noise_mu=para_config['noise_mu'],
                        noise_sigma=para_config['noise_sigma'],
                        p_min=para_config['p_min'],
                        p_max=para_config['p_max'],
                        partial_portion=para_config['partial_portion'])
                else:
                    log_string('Unknown ae type: %s' %
                               (para_config['ae_type']))
                    exit(-1)

                if para_config['data_aug'] is not None:
                    input_batch = TRAIN_DATASET.aug_data_batch(
                        input_batch,
                        scale_low=para_config['data_aug']['scale_low'],
                        scale_high=para_config['data_aug']['scale_high'],
                        rot=para_config['data_aug']['rot'],
                        snap2ground=para_config['data_aug']['snap2ground'],
                        trans=para_config['data_aug']['trans'])

                sess.run([optimizer, reconstr_loss_mean_update],
                         feed_dict={
                             ae.input_pl: input_batch,
                             ae.is_training: True
                         })

                if TRAIN_DATASET.batch_idx % para_config['output_interval'] == 0:
                    reconstr_loss_mean_val, summary = sess.run(
                        [reconstr_loss_mean, summary_op])
                    sess.run(reset_metrics)
                    log_string(
                        '-----------batch %d statistics snapshot:-------------'
                        % TRAIN_DATASET.batch_idx)
                    log_string('  Reconstruction loss   : {:.6f}'.format(
                        reconstr_loss_mean_val))

                    train_writer.add_summary(summary, total_batch_idx)
                    train_writer.flush()

                total_batch_idx += 1

            # after each epoch, reset
            TRAIN_DATASET.reset()

            # test and save
            if ep_idx % para_config['save_interval'] == 0:
                # test on whole test dataset
                sess.run(reset_metrics)
                TEST_DATASET.reset()
                while TEST_DATASET.has_next_batch():

                    if para_config['ae_type'] == 'c2c':
                        input_batch_test = TEST_DATASET.next_batch()
                    elif para_config['ae_type'] == 'n2n':
                        input_batch_test = TEST_DATASET.next_batch_noise_added(
                            noise_mu=para_config['noise_mu'],
                            noise_sigma=para_config['noise_sigma'])
                    elif para_config['ae_type'] == 'np2np':
                        input_batch_test = TEST_DATASET.next_batch_noise_partial_by_percentage(
                            noise_mu=para_config['noise_mu'],
                            noise_sigma=para_config['noise_sigma'],
                            p_min=para_config['p_min'],
                            p_max=para_config['p_max'],
                            partial_portion=para_config['partial_portion'])
                    else:
                        log_string('Unknown ae type: %s' %
                                   (para_config['ae_type']))
                        exit(-1)

                    if para_config['data_aug'] is not None:
                        input_batch_test = TEST_DATASET.aug_data_batch(
                            input_batch_test,
                            scale_low=para_config['data_aug']['scale_low'],
                            scale_high=para_config['data_aug']['scale_high'],
                            rot=para_config['data_aug']['rot'],
                            snap2ground=para_config['data_aug']['snap2ground'],
                            trans=para_config['data_aug']['trans'])

                    reconstr_val_test, _ = sess.run(
                        [reconstr, reconstr_loss_mean_update],
                        feed_dict={
                            ae.input_pl: input_batch_test,
                            ae.is_training: False
                        })

                log_string('--------- on test split: --------')
                reconstr_loss_mean_val, summary_test = sess.run(
                    [reconstr_loss_mean, summary_test_op])
                log_string('Mean Reconstruction loss: {:.6f}'.format(
                    reconstr_loss_mean_val))
                sess.run(reset_metrics)  # reset metrics

                # tensorboard
                test_writer.add_summary(summary_test, ep_idx)
                test_writer.flush()

                # write out one batch of reconstructions for manual inspection
                # (disabled by default; flip the flag to True to enable)
                if False:
                    pc_util.write_ply_batch(
                        np.asarray(reconstr_val_test),
                        os.path.join(LOG_DIR, 'pcloud',
                                     'reconstr_%d' % (ep_idx)))
                    pc_util.write_ply_batch(
                        np.asarray(input_batch_test),
                        os.path.join(LOG_DIR, 'pcloud', 'input_%d' % (ep_idx)))

                # save model
                save_path = saver.save(
                    sess,
                    os.path.join(LOG_DIR, 'ckpts', 'model_%d.ckpt' % (ep_idx)))
                log_string("Model saved in file: %s" % save_path)
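Both examples rely on the same TF 1.x streaming-metric idiom: tf.metrics.mean() returns a (mean, update_op) pair and registers its accumulators as local variables under the enclosing name scope, so re-initializing just those locals resets the running average. A self-contained toy demo of the pattern (the placeholder and values are illustrative, not from the repo):

import tensorflow as tf

loss = tf.placeholder(tf.float32, shape=())
with tf.name_scope('metrics'):
    loss_mean, loss_mean_update = tf.metrics.mean(loss)

# collect only the accumulator variables created under the 'metrics' scope
reset_metrics = tf.variables_initializer([
    v for v in tf.local_variables() if v.name.split('/')[0] == 'metrics'
])

with tf.Session() as sess:
    sess.run(reset_metrics)                       # zero total and count
    for value in [1.0, 2.0, 3.0]:
        sess.run(loss_mean_update, feed_dict={loss: value})
    print(sess.run(loss_mean))                    # 2.0
    sess.run(reset_metrics)                       # running average starts over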