def post_assign_to_tensorboard(orig_assign,
                               prepped_assigns,
                               network_heads,
                               feed_dict,
                               itr,
                               sess,
                               writer,
                               blob,
                               valid=0):

    gt_visuals = []
    map_visuals = []
    # post one scalar summary per assign and store the fetched maps
    for i in range(len(prepped_assigns)):
        assign = orig_assign[i]
        _, _, _, scalar_summary_op, images_summary_op, images_placeholders, _ = prepped_assigns[i]
        fetch_list = [scalar_summary_op[valid]]

        # fetch sub_predictions: the last len(ds_factors) feature maps, in reverse order
        head_maps = network_heads[assign["stamp_func"][0]][assign["stamp_args"]["loss"]]
        fetch_list.extend(reversed(head_maps[-len(assign["ds_factors"]):]))

        summary = sess.run(fetch_list, feed_dict=feed_dict)
        writer.add_summary(summary[0], float(itr))

        gt_visual = get_gt_visuals(blob,
                                   assign,
                                   i,
                                   pred_boxes=None,
                                   show=False)
        map_visual = get_map_visuals(summary[1:], assign, show=False)
        gt_visuals.append(gt_visual)
        map_visuals.append(map_visual)

    # stitch one large image out of all assigns
    stitched_img = get_stitched_tensorboard_image(orig_assign, gt_visuals,
                                                  map_visuals, blob, itr)
    stitched_img = np.expand_dims(stitched_img, 0)
    # feed the stitched image directly; get_images_feed_dict is obsolete here
    # (images_placeholders comes from the last prepped assign in the loop above)
    images_feed_dict = {images_placeholders[valid]: stitched_img}

    # save images to tensorboard
    summary = sess.run([images_summary_op[valid]], feed_dict=images_feed_dict)
    writer.add_summary(summary[0], float(itr))

    return None
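A minimal sketch of how this helper might be called from a training loop. The surrounding names (assigns, prepped_assigns, network_heads, feed_dict, blob) are assumed to exist in the caller, following the signature above; gating on args.tensorboard_interval mirrors the later examples.

# hypothetical call site, mirroring the signature above
if itr % args.tensorboard_interval == 0:
    post_assign_to_tensorboard(orig_assign=assigns,
                               prepped_assigns=prepped_assigns,
                               network_heads=network_heads,
                               feed_dict=feed_dict,
                               itr=itr,
                               sess=sess,
                               writer=writer,
                               blob=blob,
                               valid=0)  # index 0 = training summaries, 1 = validation (per Example #5)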
Example #2
def post_assign_to_tensorboard(assign, assign_nr, scalar_summary_op,
                               network_heads, feed_dict, itr, sess, writer,
                               blob, images_placeholders, images_summary_op):
    fetch_list = [scalar_summary_op]
    # fetch sub_predictions: the last len(ds_factors) feature maps, in reverse order
    head_maps = network_heads[assign["stamp_func"][0]][assign["stamp_args"]["loss"]]
    fetch_list.extend(reversed(head_maps[-len(assign["ds_factors"]):]))

    summary = sess.run(fetch_list, feed_dict=feed_dict)
    writer.add_summary(summary[0], float(itr))

    # use predicted feature maps
    # TODO predict boxes

    # debug logits (disabled): histogram of per-channel logit means
    # if itr == 1:
    #     hist_ph = tf.placeholder(tf.uint8, shape=[1, summary[1].shape[3]])
    #     logits_sum = tf.summary.histogram('logits_means', hist_ph)
    # h_sum = sess.run([logits_sum], feed_dict={hist_ph: np.mean(summary[1], (1, 2))})
    # writer.add_summary(h_sum[0], float(itr))

    gt_visuals = get_gt_visuals(blob,
                                assign,
                                assign_nr,
                                pred_boxes=None,
                                show=False)
    map_visuals = get_map_visuals(summary[1:], assign, show=False)
    images_feed_dict = get_images_feed_dict(assign, blob, gt_visuals,
                                            map_visuals, images_placeholders)
    # save images to tensorboard
    summary = sess.run([images_summary_op], feed_dict=images_feed_dict)
    writer.add_summary(summary[0], float(itr))
    return None
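The reversed-slice rewrite used in the fetch blocks above is equivalent to the original nr_feature_maps - (x + 1) index arithmetic. A quick self-contained check (the list contents are arbitrary placeholders):

# equivalence check for the sub_prediction fetch order
head_maps = ["map0", "map1", "map2", "map3", "map4"]
n = 3  # stands in for len(assign["ds_factors"])
nr_feature_maps = len(head_maps)

by_index = [head_maps[nr_feature_maps - (x + 1)] for x in range(n)]
by_slice = list(reversed(head_maps[-n:]))
assert by_index == by_slice  # both yield ['map4', 'map3', 'map2']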
Example #3
def execute_assign(args, input, saver, sess, checkpoint_dir, checkpoint_name,
                   data_layer, writer, network_heads, do_itr, assign,
                   prepped_assign, iteration, training_help):
    loss, optim, gt_placeholders, scalar_summary_op, images_summary_op, images_placeholders, mask_placeholders = prepped_assign

    if args.prefetch == "True":
        data_layer = PrefetchWrapper(data_layer.forward, args.prefetch_len,
                                     args, [assign], training_help)

    print("training on:" + str(assign))
    print("for " + str(do_itr) + " iterations")
    for itr in range(iteration, (iteration + do_itr)):
        # load batch - only use batches with content
        batch_not_loaded = True
        while batch_not_loaded:
            blob = data_layer.forward(args, [assign], training_help)
            if (int(gt_placeholders[0].shape[-1]) != blob["assign0"]["gt_map0"].shape[-1]
                    or len(blob["gt_boxes"].shape) != 3):
                print("skipping queue element")
            else:
                batch_not_loaded = False

        if blob["helper"] is not None:
            input_data = np.concatenate([blob["data"], blob["helper"]], -1)
            feed_dict = {input: input_data}
        else:
            if len(args.training_help) == 1:
                feed_dict = {input: blob["data"]}
            else:
                # pad input with zeros
                input_data = np.concatenate([blob["data"], blob["data"] * 0],
                                            -1)
                feed_dict = {input: input_data}

        for i in range(len(gt_placeholders)):
            # only one assign; gt maps and masks are fed in reverse map order
            level = len(gt_placeholders) - i - 1
            feed_dict[gt_placeholders[i]] = blob["assign0"]["gt_map" + str(level)]
            feed_dict[mask_placeholders[i]] = blob["assign0"]["mask" + str(level)]

        # train step
        _, loss_fetch = sess.run([optim, loss], feed_dict=feed_dict)

        if itr % args.print_interval == 0 or itr == 1:
            print("loss at itr: " + str(itr))
            print(loss_fetch)

        if itr % args.tensorboard_interval == 0 or itr == 1:
            fetch_list = [scalar_summary_op]
            # fetch sub_predictions: the last len(ds_factors) feature maps, in reverse order
            head_maps = network_heads[assign["stamp_func"][0]][assign["stamp_args"]["loss"]]
            fetch_list.extend(reversed(head_maps[-len(assign["ds_factors"]):]))

            summary = sess.run(fetch_list, feed_dict=feed_dict)
            writer.add_summary(summary[0], float(itr))

            # use predicted feature maps
            # TODO predict boxes

            # debug logits (disabled): histogram of per-channel logit means
            # if itr == 1:
            #     hist_ph = tf.placeholder(tf.uint8, shape=[1, summary[1].shape[3]])
            #     logits_sum = tf.summary.histogram('logits_means', hist_ph)
            # h_sum = sess.run([logits_sum], feed_dict={hist_ph: np.mean(summary[1], (1, 2))})
            # writer.add_summary(h_sum[0], float(itr))

            gt_visuals = get_gt_visuals(blob,
                                        assign,
                                        0,
                                        pred_boxes=None,
                                        show=False)
            map_visuals = get_map_visuals(summary[1:], assign, show=False)
            images_feed_dict = get_images_feed_dict(assign, blob, gt_visuals,
                                                    map_visuals,
                                                    images_placeholders)
            # save images to tensorboard
            summary = sess.run([images_summary_op], feed_dict=images_feed_dict)
            writer.add_summary(summary[0], float(itr))

        if itr % args.save_interval == 0:
            print("saving weights")
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            saver.save(sess, os.path.join(checkpoint_dir, checkpoint_name))

    iteration = (iteration + do_itr)
    if args.prefetch == "True":
        data_layer.kill()

    return iteration
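A sketch of how a driver might chain execute_assign over several assigns, carrying the iteration counter across calls. The names assign_list and prepped_assigns are assumptions standing in for whatever the surrounding setup code produces; everything else follows the signature above.

# hypothetical training driver: one execute_assign call per assign
iteration = 0
for assign, prepped_assign in zip(assign_list, prepped_assigns):
    iteration = execute_assign(args, input, saver, sess, checkpoint_dir,
                               checkpoint_name, data_layer, writer,
                               network_heads, do_itr, assign, prepped_assign,
                               iteration, training_help)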
Example #4
def execute_assign(args, input, saver, sess, checkpoint_dir, checkpoint_name,
                   data_layer, writer, network_heads, do_itr, assign,
                   prepped_assign, iteration, training_help):
    loss, optim, gt_placeholders, scalar_summary_op, images_summary_op, images_placeholders, mask_placeholders = prepped_assign

    if args.prefetch == "True":
        data_layer = PrefetchWrapper(data_layer.forward, args.prefetch_len,
                                     args, [assign], training_help)

    print("training on:" + str(assign))
    print("for " + str(do_itr) + " iterations")
    for itr in range(iteration, (iteration + do_itr)):
        # load batch - only use batches with content
        batch_not_loaded = True
        while batch_not_loaded:
            blob = data_layer.forward(args, [assign], training_help)
            if (int(gt_placeholders[0].shape[-1]) != blob["assign0"]["gt_map0"].shape[-1]
                    or len(blob["gt_boxes"].shape) != 3):
                print("skipping queue element")
            else:
                batch_not_loaded = False

        # disable all helpers: blob["helper"] is forced to None here, so the
        # helper branch below never runs and only the else path is taken
        blob["helper"] = None

        if blob["helper"] is not None:
            input_data = np.concatenate([blob["data"], blob["helper"]], -1)
            feed_dict = {input: input_data}
        else:
            if len(args.training_help) == 1:
                feed_dict = {input: blob["data"]}
            else:
                # pad input with zeros
                input_data = np.concatenate([blob["data"], blob["data"] * 0],
                                            -1)
                feed_dict = {input: input_data}

        for i in range(len(gt_placeholders)):
            # only one assign; gt maps and masks are fed in reverse map order
            level = len(gt_placeholders) - i - 1
            feed_dict[gt_placeholders[i]] = blob["assign0"]["gt_map" + str(level)]
            feed_dict[mask_placeholders[i]] = blob["assign0"]["mask" + str(level)]

        # train step
        _, loss_fetch = sess.run([optim, loss], feed_dict=feed_dict)

        if itr % args.print_interval == 0 or itr == 1:
            print("loss at itr: " + str(itr))
            print(loss_fetch)

        if itr % args.tensorboard_interval == 0 or itr == 1:
            fetch_list = [scalar_summary_op]
            # fetch sub_predictions: the last len(ds_factors) feature maps, in reverse order
            head_maps = network_heads[assign["stamp_func"][0]][assign["stamp_args"]["loss"]]
            fetch_list.extend(reversed(head_maps[-len(assign["ds_factors"]):]))

            summary = sess.run(fetch_list, feed_dict=feed_dict)
            writer.add_summary(summary[0], float(itr))

            # feed one stitched image to summary op
            gt_visuals = get_gt_visuals(blob,
                                        assign,
                                        0,
                                        pred_boxes=None,
                                        show=False)
            map_visuals = get_map_visuals(summary[1:], assign, show=False)

            stitched_img = get_stitched_tensorboard_image([assign],
                                                          [gt_visuals],
                                                          [map_visuals], blob,
                                                          itr)
            stitched_img = np.expand_dims(stitched_img, 0)
            # feed the stitched image directly; get_images_feed_dict is obsolete here
            images_feed_dict = {images_placeholders[0]: stitched_img}

            # save images to tensorboard
            summary = sess.run([images_summary_op], feed_dict=images_feed_dict)
            writer.add_summary(summary[0], float(itr))

        if itr % args.save_interval == 0:
            print("saving weights")
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            saver.save(sess, os.path.join(checkpoint_dir, checkpoint_name))
            global store_dict
            if store_dict:
                print("Saving dictionary")
                dictionary = args.dict_info
                with open(os.path.join(checkpoint_dir, 'dict.pickle'), 'wb') as handle:
                    pickle.dump(dictionary, handle, protocol=pickle.HIGHEST_PROTOCOL)
                store_dict = False  # we need to save the dict only once

    iteration = (iteration + do_itr)
    if args.prefetch == "True":
        data_layer.kill()

    return iteration
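Since this variant persists args.dict_info next to the checkpoint, reading it back later is a plain pickle round-trip; a minimal sketch, assuming the same checkpoint_dir used above:

import os
import pickle

# reload the dictionary saved alongside the checkpoint
with open(os.path.join(checkpoint_dir, 'dict.pickle'), 'rb') as handle:
    dictionary = pickle.load(handle)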
Example #5
def execute_assign(args, input_placeholder, saver, sess, checkpoint_dir,
                   checkpoint_name, data_layer, writer, network_heads, do_itr,
                   assign, prepped_assign, iteration, training_help):
    loss, optim, gt_placeholders, scalar_summary_op, images_summary_op, images_placeholders, mask_placeholders = prepped_assign

    if args.prefetch == "True":
        data_layer[0] = PrefetchWrapper(data_layer[0].forward,
                                        args.prefetch_len, args, [assign],
                                        training_help)

    print("training on:" + str(assign))
    print("for " + str(do_itr) + " iterations")
    for itr in range(iteration, (iteration + do_itr)):

        # run a training batch
        loss_fetch, feed_dict, blob = run_batch_assign(
            data_layer, args, assign, training_help, input_placeholder,
            gt_placeholders, mask_placeholders, sess, optim, loss)

        if itr % args.print_interval == 0 or itr == 1:
            print("loss at itr: " + str(itr))
            print(loss_fetch)

        if itr % args.tensorboard_interval == 0 or itr == 1:
            fetch_list = [scalar_summary_op[0]]
            # fetch sub_predictions: the last len(ds_factors) feature maps, in reverse order
            head_maps = network_heads[assign["stamp_func"][0]][assign["stamp_args"]["loss"]]
            fetch_list.extend(reversed(head_maps[-len(assign["ds_factors"]):]))

            summary = sess.run(fetch_list, feed_dict=feed_dict)
            writer.add_summary(summary[0], float(itr))

            # feed one stitched image to summary op
            gt_visuals = get_gt_visuals(blob,
                                        assign,
                                        0,
                                        pred_boxes=None,
                                        show=False)
            map_visuals = get_map_visuals(summary[1:], assign, show=False)

            stitched_img = get_stitched_tensorboard_image([assign],
                                                          [gt_visuals],
                                                          [map_visuals], blob,
                                                          itr)
            stitched_img = np.expand_dims(stitched_img, 0)
            # feed the stitched image directly; get_images_feed_dict is obsolete here
            images_feed_dict = {images_placeholders[0]: stitched_img}

            # save images to tensorboard
            summary = sess.run([images_summary_op[0]],
                               feed_dict=images_feed_dict)
            writer.add_summary(summary[0], float(itr))

        if itr % args.validation_loss_task == 0:
            # approximate validation loss
            val_loss = 0
            for i in range(args.validation_loss_task_nr_batch):
                loss_fetch, feed_dict, blob = run_batch_assign(
                    data_layer,
                    args,
                    assign,
                    training_help,
                    input_placeholder,
                    gt_placeholders,
                    mask_placeholders,
                    sess,
                    optim,
                    loss,
                    valid=1)
                val_loss += loss_fetch[0]

            val_loss = val_loss / args.validation_loss_task_nr_batch
            print("Validation loss estimate at itr " + str(itr) + ": " +
                  str(val_loss))

            # post to tensorboard
            fetch_list = [scalar_summary_op[1]]
            # fetch sub_predictions: the last len(ds_factors) feature maps, in reverse order
            head_maps = network_heads[assign["stamp_func"][0]][assign["stamp_args"]["loss"]]
            fetch_list.extend(reversed(head_maps[-len(assign["ds_factors"]):]))

            summary = sess.run(fetch_list, feed_dict=feed_dict)
            writer.add_summary(summary[0], float(itr))

            # feed one stitched image to summary op
            gt_visuals = get_gt_visuals(blob,
                                        assign,
                                        0,
                                        pred_boxes=None,
                                        show=False)
            map_visuals = get_map_visuals(summary[1:], assign, show=False)

            stitched_img = get_stitched_tensorboard_image([assign],
                                                          [gt_visuals],
                                                          [map_visuals], blob,
                                                          itr)
            stitched_img = np.expand_dims(stitched_img, 0)
            # feed the stitched image directly; get_images_feed_dict is obsolete here
            images_feed_dict = {images_placeholders[1]: stitched_img}

            # save images to tensorboard
            summary = sess.run([images_summary_op[1]],
                               feed_dict=images_feed_dict)
            writer.add_summary(summary[0], float(itr))

        if itr % args.validation_loss_final == 0:
            print("do full validation")

        if itr % args.save_interval == 0:
            print("saving weights")
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            saver.save(sess, os.path.join(checkpoint_dir, checkpoint_name))
            global store_dict
            if store_dict:
                print("Saving dictionary")
                dictionary = args.dict_info
                with open(os.path.join(checkpoint_dir, 'dict.pickle'), 'wb') as handle:
                    pickle.dump(dictionary, handle, protocol=pickle.HIGHEST_PROTOCOL)
                store_dict = False  # we need to save the dict only once

    iteration = (iteration + do_itr)
    if args.prefetch == "True":
        data_layer[0].kill()  # the prefetch wrapper was stored in data_layer[0] above

    return iteration
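run_batch_assign is called here but not shown on this page. A sketch reconstructing it from the inline batch logic of Examples #3 and #4; the valid flag, the [train, valid] layout of data_layer, and the exact return contract are assumptions, and the helper-input padding path is omitted for brevity:

def run_batch_assign(data_layer, args, assign, training_help,
                     input_placeholder, gt_placeholders, mask_placeholders,
                     sess, optim, loss, valid=0):
    # load batch - only use batches with content (mirrors the inline loop above)
    batch_not_loaded = True
    while batch_not_loaded:
        # data_layer is assumed to hold [train_layer, valid_layer]
        blob = data_layer[valid].forward(args, [assign], training_help)
        if (int(gt_placeholders[0].shape[-1]) != blob["assign0"]["gt_map0"].shape[-1]
                or len(blob["gt_boxes"].shape) != 3):
            print("skipping queue element")
        else:
            batch_not_loaded = False

    feed_dict = {input_placeholder: blob["data"]}
    for i in range(len(gt_placeholders)):
        level = len(gt_placeholders) - i - 1
        feed_dict[gt_placeholders[i]] = blob["assign0"]["gt_map" + str(level)]
        feed_dict[mask_placeholders[i]] = blob["assign0"]["mask" + str(level)]

    if valid:
        # validation: evaluate the loss without taking an optimizer step,
        # returning a list so the caller can index loss_fetch[0]
        loss_fetch = sess.run([loss], feed_dict=feed_dict)
    else:
        # training: one optimizer step plus the scalar loss
        _, loss_fetch = sess.run([optim, loss], feed_dict=feed_dict)
    return loss_fetch, feed_dict, blob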