Example no. 1
import tensorflow as tf  # TF 1.x API (tf.Session, tf.errors.OutOfRangeError)

# get_args, process_config, create_dirs, DataGenerator and the remaining
# helpers are assumed to come from the surrounding project; their module
# paths are not shown in this example


def main():
    # capture the config path from the run arguments
    # then process the json configuration file
    try:
        args = get_args()
        config = process_config(args.config)

        config.old_tfrecords = args.old_tfrecords
        config.normalize_data = False

    except Exception as e:
        print("An error occurred while processing the configuration file")
        print(e)
        exit(1)  # non-zero exit status signals failure

    # create the experiment dirs
    create_dirs(
        [config.summary_dir, config.checkpoint_dir, config.config_file_dir])

    # create tensorflow session
    sess = tf.Session()

    # create your data generator
    #train_data = DataGenerator(config, sess, train=True)

    test_data = DataGenerator(config, sess, train=False)
    next_element = test_data.get_next_batch()
    dir_name = "test"

    while True:
        try:
            features = sess.run(next_element)
            features = convert_dict_to_list_subdicts(features,
                                                     config.test_batch_size)
            features = features[0]
            summaries = create_target_summary_dicts(features)

            dir_path = check_exp_folder_exists_and_create(features, dir_name)
            if dir_path is not None:
                export_summary_images(config=config,
                                      summaries_dict_images=summaries,
                                      dir_path=dir_path)

        except tf.errors.OutOfRangeError:
            print("done exporting")
            break
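
The example assumes the usual script entry point:

if __name__ == "__main__":
    main()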
Example no. 2

import numpy as np
import tensorflow as tf  # TF 1.x API

# in addition to the helpers from Example no. 1, this example also uses
# Logger, cnnmodel, create_full_images_of_object_masks and create_seg_gif
# from the surrounding project; a sigmoid helper is sketched after the listing


def main():
    # create tensorflow session
    sess = tf.Session()

    try:
        args = get_args()
        config = process_config(args.config)

        config.old_tfrecords = args.old_tfrecords
        config.normalize_data = args.normalize_data

    except Exception as e:
        print("An error occurred while processing the configuration file")
        print(e)
        exit(1)  # non-zero exit status signals failure

    # create the experiment dirs
    create_dirs(
        [config.summary_dir, config.checkpoint_dir, config.config_file_dir])

    # create the train and test data generators
    train_data = DataGenerator(config, sess, train=True)
    test_data = DataGenerator(config, sess, train=False)

    logger = Logger(sess, config)

    print("using {} rollout steps".format(config.n_rollouts))

    inp_rgb = tf.placeholder("float", [None, 120, 160, 7])
    control = tf.placeholder("float", [None, 6])
    gt_seg = tf.placeholder("float", [None, 120, 160])

    pred = cnnmodel(inp_rgb, control)

    predictions = tf.reshape(
        pred, [-1, pred.get_shape()[1] * pred.get_shape()[2]])
    labels = tf.reshape(gt_seg,
                        [-1, gt_seg.get_shape()[1] * gt_seg.get_shape()[2]])

    global_step_tensor = tf.Variable(0, trainable=False, name='global_step')

    loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
                                                logits=predictions))
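    # the reshapes above flatten each 120x160 map into 19200 per-pixel logits,
    # so the sigmoid cross-entropy treats segmentation as an independent
    # binary classification per pixel, averaged over the rollout batch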
    optimizer = tf.train.AdamOptimizer(
        learning_rate=config.learning_rate).minimize(
            loss, global_step=global_step_tensor)

    with tf.variable_scope('cur_epoch'):
        cur_epoch_tensor = tf.Variable(0, trainable=False, name='cur_epoch')
        increment_cur_epoch_tensor = tf.assign(cur_epoch_tensor,
                                               cur_epoch_tensor + 1)

    with tf.variable_scope('cur_batch'):
        cur_batch_tensor = tf.Variable(0, trainable=False, name='cur_batch')
        increment_cur_batch_tensor = tf.assign(cur_batch_tensor,
                                               cur_batch_tensor + 1)
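    # both counters are ordinary tf.Variables, so the Saver created below
    # stores them in checkpoints and training resumes at the correct indices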

    next_element_train = train_data.get_next_batch()
    next_element_test = test_data.get_next_batch()

    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())
    sess.run(init)

    saver = tf.train.Saver(max_to_keep=config.max_checkpoints_to_keep)

    latest_checkpoint = tf.train.latest_checkpoint(config.checkpoint_dir)
    if latest_checkpoint:
        print("Loading model checkpoint {} ...\n".format(latest_checkpoint))
        saver.restore(sess, latest_checkpoint)
        print("Model loaded")

    def _process_rollouts(feature, train=True):
        gt_merged_seg_rollout_batch = []
        input_merged_images_rollout_batch = []
        gripper_pos_vel_rollout_batch = []
        for step in range(config.n_rollouts - 1):
            if step < feature["unpadded_experiment_length"]:
                obj_segments = feature["object_segments"][step]
                """ transform (3,120,160,7) into (1,120,160,7) by merging the rgb,depth and seg masks """
                input_merged_images = create_full_images_of_object_masks(
                    obj_segments)

                obj_segments_gt = feature["object_segments"][step + 1]
                gt_merged_seg = create_full_images_of_object_masks(
                    obj_segments_gt)[:, :, 3]

                gripper_pos = feature["gripperpos"][step + 1]
                gripper_vel = feature["grippervel"][step + 1]
                gripper_pos_vel = np.concatenate([gripper_pos, gripper_vel])

                gt_merged_seg_rollout_batch.append(gt_merged_seg)
                input_merged_images_rollout_batch.append(input_merged_images)
                gripper_pos_vel_rollout_batch.append(gripper_pos_vel)

        if train:
            retrn = sess.run(
                [optimizer, loss, pred],
                feed_dict={
                    inp_rgb: input_merged_images_rollout_batch,
                    control: gripper_pos_vel_rollout_batch,
                    gt_seg: gt_merged_seg_rollout_batch
                })
            return retrn[1], retrn[2]

        else:
            retrn = sess.run(
                [loss, pred],
                feed_dict={
                    inp_rgb: input_merged_images_rollout_batch,
                    control: gripper_pos_vel_rollout_batch,
                    gt_seg: gt_merged_seg_rollout_batch
                })
            """ sigmoid cross entropy runs logits through sigmoid but only during train time """
            seg_data = sigmoid(retrn[1])
            seg_data[seg_data >= 0.5] = 1.0
            seg_data[seg_data < 0.5] = 0.0
            return retrn[0], seg_data, gt_merged_seg_rollout_batch

    for cur_epoch in range(cur_epoch_tensor.eval(sess), config.n_epochs + 1,
                           1):
        while True:
            try:
                features = sess.run(next_element_train)
                features = convert_dict_to_list_subdicts(
                    features, config.train_batch_size)
                loss_batch = []
                sess.run(increment_cur_batch_tensor)
                # features already holds one sub-dict per example, so a single
                # pass over it covers the whole train batch
                for feature in features:
                    loss_train, _ = _process_rollouts(feature)
                    loss_batch.append(loss_train)

                cur_batch_it = cur_batch_tensor.eval(sess)
                loss_mean_batch = np.mean(loss_batch)

                print('train loss for batch {0} is: {1:.4f}'.format(
                    cur_batch_it, loss_mean_batch))
                summaries_dict = {config.exp_name + '_loss': loss_mean_batch}
                logger.summarize(cur_batch_it,
                                 summaries_dict=summaries_dict,
                                 summarizer="train")

                if cur_batch_it % config.test_interval == 1:
                    print("Executing test batch")
                    features_idx = 0  # always take first element for testing
                    features = sess.run(next_element_test)
                    features = convert_dict_to_list_subdicts(
                        features, config.test_batch_size)
                    loss_test_batch = []

                    # evaluate the selected test example once
                    loss_test, seg_data, gt_seg_data = _process_rollouts(
                        features[features_idx], train=False)
                    loss_test_batch.append(loss_test)

                    loss_test_mean_batch = np.mean(loss_test_batch)
                    summaries_dict = {
                        config.exp_name + '_test_loss': loss_test_mean_batch
                    }
                    logger.summarize(cur_batch_it,
                                     summaries_dict=summaries_dict,
                                     summarizer="test")

                    print('test loss is: {0:.4f}'.format(loss_test_mean_batch))
                    if seg_data is not None and gt_seg_data is not None:
                        """ create gif here """
                        create_seg_gif(features,
                                       features_idx,
                                       config,
                                       seg_data,
                                       gt_seg_data,
                                       dir_name="tests_during_training",
                                       cur_batch_it=cur_batch_it)

                if cur_batch_it % config.model_save_step_interval == 1:
                    print("Saving model...")
                    saver.save(sess, config.checkpoint_dir, global_step_tensor)
                    print("Model saved")

            except tf.errors.OutOfRangeError:
                break

        sess.run(increment_cur_epoch_tensor)

    return None
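
The sigmoid applied at test time is not defined in the listing; a minimal NumPy sketch (scipy.special.expit is a drop-in equivalent):

def sigmoid(x):
    # logistic function mapping raw logits to probabilities in [0, 1];
    # the piecewise form avoids overflow in np.exp for large |x|
    x = np.asarray(x, dtype=np.float64)
    pos = 1.0 / (1.0 + np.exp(-np.clip(x, 0.0, None)))
    neg = np.exp(np.clip(x, None, 0.0))
    return np.where(x >= 0, pos, neg / (1.0 + neg))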