Example No. 1
def test(raw_dirs, TEST_DATASET):
    log_string(FLAGS.log_dir)
    with tf.Graph().as_default():
        with tf.device('/gpu:0'):
            input_pls = model.placeholder_inputs(scope='inputs_pl',
                                                 FLAGS=FLAGS,
                                                 num_pnts=NUM_INPUT_POINTS)
            is_training_pl = tf.compat.v1.placeholder(tf.bool, shape=())
            print(is_training_pl)
            batch = tf.Variable(0, name='batch')

            print("--- Get model and loss")
            # Get model and loss

            end_points = model.get_model(input_pls,
                                         is_training_pl,
                                         bn=False,
                                         FLAGS=FLAGS)
            loss, end_points = model.get_loss(end_points, FLAGS=FLAGS)
            gpu_options = tf.compat.v1.GPUOptions()  # per_process_gpu_memory_fraction=0.99
            config = tf.compat.v1.ConfigProto(gpu_options=gpu_options)
            config.gpu_options.allow_growth = True
            config.allow_soft_placement = True
            config.log_device_placement = False
            sess = tf.compat.v1.Session(config=config)

            # Init all variables
            init = tf.compat.v1.global_variables_initializer()
            sess.run(init)

            ######### Loading Checkpoint ###############
            # Saver for the overall model, excluding lr/batch bookkeeping variables
            saver = tf.compat.v1.train.Saver([
                v for v in tf.compat.v1.get_collection_ref(
                    tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)
                if ('lr' not in v.name) and ('batch' not in v.name)
            ])

            ###########################################

            ops = {
                'input_pls': input_pls,
                'is_training_pl': is_training_pl,
                'loss': loss,
                'step': batch,
                'saver': saver,
                'end_points': end_points
            }
            sys.stdout.flush()

            TEST_DATASET.start()
            log_string('**** INFERENCE ****')
            sys.stdout.flush()
            test_one_epoch(sess, ops, raw_dirs, TEST_DATASET)
            TEST_DATASET.shutdown()
            print("Done!")
Example No. 2
def create():
    log_string(LOG_DIR)

    batch_data = read_img_get_transmat()

    input_pls = model.placeholder_inputs(BATCH_SIZE, NUM_POINTS,
                                         (IMG_SIZE, IMG_SIZE),
                                         num_sample_pc=NUM_SAMPLE_POINTS,
                                         scope='inputs_pl', FLAGS=FLAGS)
    is_training_pl = tf.placeholder(tf.bool, shape=())
    print(is_training_pl)
    batch = tf.Variable(0, name='batch')

    print("--- Get model and loss")
    # Get model and loss

    end_points = model.get_model(input_pls, NUM_POINTS, is_training_pl,
                                 bn=False, FLAGS=FLAGS)

    loss, end_points = model.get_loss(end_points, sdf_weight=SDF_WEIGHT,
                                      num_sample_points=NUM_SAMPLE_POINTS,
                                      FLAGS=FLAGS)
    # Create a session
    gpu_options = tf.GPUOptions()  # per_process_gpu_memory_fraction=0.99
    config = tf.ConfigProto(gpu_options=gpu_options)
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.log_device_placement = False
    sess = tf.Session(config=config)

    init = tf.global_variables_initializer()
    sess.run(init)

    ######### Loading Checkpoint ###############
    saver = tf.train.Saver([v for v in tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES) if
                            ('lr' not in v.name) and ('batch' not in v.name)])
    ckptstate = tf.train.get_checkpoint_state(PRETRAINED_MODEL_PATH)

    if ckptstate is not None:
        LOAD_MODEL_FILE = os.path.join(PRETRAINED_MODEL_PATH, os.path.basename(ckptstate.model_checkpoint_path))
        try:
            # load_model(sess, PRETRAINED_PN_MODEL_FILE, ['refpc_reconstruction','sdfprediction','vgg_16'], strict=True)
            with NoStdStreams():
                saver.restore(sess, LOAD_MODEL_FILE)
            print("Model loaded in file: %s" % LOAD_MODEL_FILE)
        except Exception:
            print("Failed to load overall model file: %s" % PRETRAINED_MODEL_PATH)

    ###########################################

    ops = {'input_pls': input_pls,
           'is_training_pl': is_training_pl,
           'loss': loss,
           'step': batch,
           'end_points': end_points}

    test_one_epoch(sess, ops, batch_data)
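
The checkpoint-discovery logic above (get_checkpoint_state plus rebuilding the path with os.path.basename, which guards against an absolute path recorded on another machine) can be factored into a small helper. A sketch under that assumption; restore_latest_checkpoint is a hypothetical name, not part of the original code:

import os
import tensorflow as tf

def restore_latest_checkpoint(sess, saver, model_dir):
    """Restore the newest checkpoint under model_dir; return True on success."""
    ckptstate = tf.train.get_checkpoint_state(model_dir)
    if ckptstate is None:
        return False
    # Rebuild the path relative to model_dir in case the checkpoint
    # state file stores an absolute path from another machine.
    ckpt_file = os.path.join(
        model_dir, os.path.basename(ckptstate.model_checkpoint_path))
    saver.restore(sess, ckpt_file)
    return True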
Example No. 3
def train():
    log_string(LOG_DIR)
    with tf.Graph().as_default():
        with tf.device('/gpu:0'):
            input_pls = model.placeholder_inputs(
                BATCH_SIZE,
                NUM_POINTS, (IMG_SIZE, IMG_SIZE),
                num_sample_pc=NUM_SAMPLE_POINTS,
                scope='inputs_pl',
                FLAGS=FLAGS)
            is_training_pl = tf.placeholder(tf.bool, shape=())
            print(is_training_pl)

            # Note the global_step=batch argument to minimize():
            # the optimizer increments the 'batch' variable after each training step.
            batch = tf.Variable(0, name='batch')
            bn_decay = get_bn_decay(batch)
            # tf.summary.scalar('bn_decay', bn_decay)

            print("--- Get model and loss")
            # Get model and loss

            end_points = model.get_model(input_pls,
                                         NUM_POINTS,
                                         is_training_pl,
                                         bn=False,
                                         FLAGS=FLAGS)
            loss, end_points = model.get_loss(
                end_points,
                sdf_weight=SDF_WEIGHT,
                mask_weight=FLAGS.mask_weight,
                num_sample_points=FLAGS.num_sample_points,
                FLAGS=FLAGS)
            # tf.summary.scalar('loss', loss)

            print("--- Get training operator")
            # Get training operator
            learning_rate = get_learning_rate(batch)
            if OPTIMIZER == 'momentum':
                optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                       momentum=MOMENTUM)
            elif OPTIMIZER == 'adam':
                optimizer = tf.train.AdamOptimizer(learning_rate,
                                                   beta1=FLAGS.beta1)

            # Create a session
            gpu_options = tf.GPUOptions()  # per_process_gpu_memory_fraction=0.99
            config = tf.ConfigProto(gpu_options=gpu_options)
            config.gpu_options.allow_growth = True
            config.allow_soft_placement = True
            config.log_device_placement = False
            sess = tf.Session(config=config)

            merged = None
            train_writer = None

            # Update all global variables during training
            update_variables = list(
                tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))

            train_op = optimizer.minimize(loss,
                                          global_step=batch,
                                          var_list=update_variables)

            # Init variables
            init = tf.global_variables_initializer()
            sess.run(init)

            ######### Loading Checkpoint ###############
            # CNN (pretrained on ImageNet)
            if PRETRAINED_CNN_MODEL_FILE != '':
                if not load_model(
                        sess, PRETRAINED_CNN_MODEL_FILE, 'vgg_16',
                        strict=True):
                    return

            if PRETRAINED_PN_MODEL_FILE != '':
                if not load_model(sess,
                                  PRETRAINED_PN_MODEL_FILE,
                                  ['refpc_reconstruction', 'sdfprediction'],
                                  strict=True):
                    return
            # Overall model checkpoint
            saver = tf.train.Saver([
                v for v in tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
                if ('lr' not in v.name) and ('batch' not in v.name)
            ])
            ckptstate = tf.train.get_checkpoint_state(PRETRAINED_MODEL_PATH)

            if ckptstate is not None:
                LOAD_MODEL_FILE = os.path.join(
                    PRETRAINED_MODEL_PATH,
                    os.path.basename(ckptstate.model_checkpoint_path))
                try:
                    load_model(sess,
                               LOAD_MODEL_FILE, [
                                   'sdfprediction/fold1',
                                   'sdfprediction/fold2', 'vgg_16'
                               ],
                               strict=True)
                    # load_model(sess, LOAD_MODEL_FILE, ['sdfprediction','vgg_16'], strict=True)
                    with NoStdStreams():
                        saver.restore(sess, LOAD_MODEL_FILE)
                    print("Model loaded in file: %s" % LOAD_MODEL_FILE)
                except Exception:
                    print("Failed to load overall model file: %s" %
                          PRETRAINED_MODEL_PATH)

            ###########################################

            ops = {
                'input_pls': input_pls,
                'is_training_pl': is_training_pl,
                'loss': loss,
                'train_op': train_op,
                'merged': merged,
                'step': batch,
                'lr': learning_rate,
                'end_points': end_points
            }

            TRAIN_DATASET.start()
            best_acc = 0
            for epoch in range(MAX_EPOCH):
                log_string('**** EPOCH %03d ****' % (epoch))
                sys.stdout.flush()

                avg_accuracy = train_one_epoch(sess, ops, train_writer, saver)

                # Save the variables to disk.
                if avg_accuracy > best_acc:
                    best_acc = avg_accuracy
                    save_path = saver.save(sess,
                                           os.path.join(LOG_DIR, "model.ckpt"))
                    log_string("best Model saved in file: %s" % save_path)
                elif epoch % 10 == 0:
                    save_path = saver.save(
                        sess,
                        os.path.join(LOG_DIR,
                                     "model_epoch_%03d.ckpt" % (epoch)))
                    log_string("Model saved in file: %s" % save_path)

            TRAIN_DATASET.shutdown()
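
The comment about global_step=batch is worth verifying in isolation. A minimal runnable demonstration with a hypothetical toy loss, assuming TF 1.x graph mode: after three calls to train_op, the 'batch' counter reads 3.

import tensorflow as tf

tf.compat.v1.disable_eager_execution()

with tf.Graph().as_default():
    x = tf.Variable(5.0, name='x')
    # trainable=False keeps the counter out of the optimizer's var_list.
    batch = tf.Variable(0, name='batch', trainable=False)
    loss = tf.square(x)
    optimizer = tf.compat.v1.train.AdamOptimizer(0.1)
    # Passing global_step makes minimize() increment 'batch' once per
    # step, which is what drives the lr and bn-decay schedules above.
    train_op = optimizer.minimize(loss, global_step=batch)

    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        for _ in range(3):
            sess.run(train_op)
        print(sess.run(batch))  # -> 3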
Example No. 4
def train():
    log_string(FLAGS.log_dir)
    with tf.Graph().as_default():
        with tf.device('/gpu:0'):
            input_pls = model.placeholder_inputs(scope='inputs_pl',
                                                 FLAGS=FLAGS)
            is_training_pl = tf.compat.v1.placeholder(tf.bool, shape=())
            print(is_training_pl)

            # Note the global_step=batch argument to minimize():
            # the optimizer increments the 'batch' variable after each training step.
            batch = tf.Variable(0, name='batch')
            bn_decay = get_bn_decay(batch)
            # tf.summary.scalar('bn_decay', bn_decay)

            print("--- Get model and loss")
            # Get model and loss

            end_points = model.get_model(input_pls,
                                         is_training_pl,
                                         bn=FLAGS.bn,
                                         bn_decay=bn_decay,
                                         FLAGS=FLAGS)
            loss, end_points = model.get_loss(end_points, FLAGS=FLAGS)
            # tf.summary.scalar('loss', loss)

            print("--- Get training operator")
            # Get training operator
            learning_rate = get_learning_rate(batch)
            if FLAGS.optimizer == 'momentum':
                optimizer = tf.compat.v1.train.MomentumOptimizer(
                    learning_rate, momentum=FLAGS.momentum)
            elif FLAGS.optimizer == 'adam':
                optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate,
                                                             beta1=FLAGS.beta1)

            # Create a session
            gpu_options = tf.compat.v1.GPUOptions()  # per_process_gpu_memory_fraction=0.99
            config = tf.compat.v1.ConfigProto(gpu_options=gpu_options)
            config.gpu_options.allow_growth = True
            config.allow_soft_placement = True
            config.log_device_placement = False
            sess = tf.compat.v1.Session(config=config)

            merged = None
            train_writer = None

            # Update all global variables during training
            update_variables = list(
                tf.compat.v1.get_collection_ref(
                    tf.compat.v1.GraphKeys.GLOBAL_VARIABLES))

            train_op = optimizer.minimize(loss,
                                          global_step=batch,
                                          var_list=update_variables)

            # Init variables
            init = tf.compat.v1.global_variables_initializer()
            sess.run(init)

            ######### Loading Checkpoint ###############
            # CNN(Pretrained from ImageNet)
            if FLAGS.restore_modelcnn != '':
                if not load_model(
                        sess, FLAGS.restore_modelcnn, FLAGS.encoder,
                        strict=True):
                    return
            # Overall model checkpoint
            saver = tf.compat.v1.train.Saver([
                v for v in tf.compat.v1.get_collection_ref(
                    tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)
                if ('lr' not in v.name) and ('batch' not in v.name)
            ])
            ckptstate = tf.train.get_checkpoint_state(FLAGS.restore_model)

            if ckptstate is not None:
                LOAD_MODEL_FILE = os.path.join(
                    FLAGS.restore_model,
                    os.path.basename(ckptstate.all_model_checkpoint_paths[0]))
                load_model_all(saver, sess, LOAD_MODEL_FILE)
                print("Model loaded in file: %s" % LOAD_MODEL_FILE)
                # try:
                #     load_model(sess, LOAD_MODEL_FILE, ['sdfprediction/fold1', 'sdfprediction/fold2', FLAGS.encoder],
                #                strict=True)
                #     # load_model(sess, LOAD_MODEL_FILE, ['sdfprediction','vgg_16'], strict=True)
                #     with NoStdStreams():
                #         saver.restore(sess, LOAD_MODEL_FILE)
                #     print("Model loaded in file: %s" % LOAD_MODEL_FILE)
                # except:
                #     print("Fail to load overall modelfile: %s" % FLAGS.restore_model)

            ###########################################

            ops = {
                'input_pls': input_pls,
                'is_training_pl': is_training_pl,
                'loss': loss,
                'train_op': train_op,
                'merged': merged,
                'step': batch,
                'lr': learning_rate,
                'end_points': end_points
            }

            TRAIN_DATASET.start()
            TEST_DATASET.start()
            best_locnorm_diff, best_dir_diff = 10000, 10000
            for epoch in range(FLAGS.max_epoch):
                log_string('**** EPOCH %03d ****' % (epoch))
                sys.stdout.flush()
                if epoch == 0 and FLAGS.restore_model:
                    test_one_epoch(sess, ops, epoch)
                # test_one_epoch(sess, ops, epoch)
                xyz_avg_diff, _, _ = train_one_epoch(sess, ops, epoch)
                if epoch % 5 == 0 and epoch > 1:
                    locnorm_avg_diff, direction_avg_diff = test_one_epoch(
                        sess, ops, epoch)
                    # Save the variables to disk.
                    if locnorm_avg_diff < best_locnorm_diff:
                        best_locnorm_diff = locnorm_avg_diff
                        save_path = saver.save(
                            sess, os.path.join(FLAGS.log_dir, "model.ckpt"))
                        log_string(
                            "best locnorm_avg_diff Model saved in file: %s" %
                            save_path)
                    # elif direction_avg_diff < best_dir_diff:
                    #     best_dir_diff = direction_avg_diff
                    #     save_path = saver.save(sess, os.path.join(FLAGS.log_dir, "dir_model.ckpt"))
                    #     log_string("best direction Model saved in file: %s" % save_path)
                if epoch % 30 == 0 and epoch > 1:
                    save_path = saver.save(
                        sess,
                        os.path.join(FLAGS.log_dir,
                                     "model_epoch_%03d.ckpt" % (epoch)))
                    log_string("Model saved in file: %s" % save_path)

            TRAIN_DATASET.shutdown()
            TEST_DATASET.shutdown()
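
All four examples share the same session setup. For reference, a standalone version of just that configuration; nothing here is specific to the model code:

import tensorflow as tf

tf.compat.v1.disable_eager_execution()

# Grow GPU memory on demand rather than reserving the whole card,
# and fall back to CPU placement when an op has no GPU kernel.
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
config.allow_soft_placement = True
config.log_device_placement = False  # keep device placement logs quiet
sess = tf.compat.v1.Session(config=config)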