def config_session(self):
        """Build the pose graph for ``self.dlc_cfg`` and open a session with restored weights.

        Two variants are selected by ``self.TFGPUinference``:

        * truthy -- graph ops come from the ``tf`` module, the network's
          ``inference`` heads are built and only ``"pose"`` is exposed;
        * falsy  -- graph ops come from the ``TF`` module, the network's
          ``test`` heads are built; ``"part_prob"`` is always exposed,
          ``"locref"`` is appended when ``location_refinement`` is set, and
          ``"pairwise_pred"`` is appended for ``"multi-animal"`` dataset types
          with ``partaffinityfield_predict`` enabled.

        Side effects: resets the default graph and assigns ``self.inputs``
        (a ``[batch_size, None, None, 3]`` float32 placeholder),
        ``self.outputs`` (list of head tensors) and ``self.sess``; variables
        are restored from ``self.dlc_cfg.init_weights``.
        """
        cfg = self.dlc_cfg
        gpu_inference = self.TFGPUinference
        # The GPU-inference path uses the ``tf`` namespace, the default path ``TF``.
        backend = tf if gpu_inference else TF

        backend.reset_default_graph()
        self.inputs = backend.placeholder(
            tf.float32, shape=[cfg.batch_size, None, None, 3])
        network = pose_net(cfg)

        if gpu_inference:
            heads = network.inference(self.inputs)
            self.outputs = [heads["pose"]]
        else:
            heads = network.test(self.inputs)
            self.outputs = [heads["part_prob"]]
            if cfg.location_refinement:
                self.outputs.append(heads["locref"])
            wants_pafs = ("multi-animal" in cfg.dataset_type
                          and cfg.partaffinityfield_predict)
            if wants_pafs:
                print("Activating extracting of PAFs")
                self.outputs.append(heads["pairwise_pred"])

        saver = backend.train.Saver()
        self.sess = backend.Session()
        self.sess.run(backend.global_variables_initializer())
        self.sess.run(backend.local_variables_initializer())

        # Load the trained variables from disk.
        saver.restore(self.sess, cfg.init_weights)
Ejemplo n.º 2
0
def initialize_resnet(dlc_cfg, nx_in, ny_in, allow_growth=True):
    """Build the backbone feature extractor and return a session with weights loaded.

    Args:
        dlc_cfg: DeepLabCut config; ``dlc_cfg.init_weights`` is the checkpoint
            path. If it contains ``'snapshot'`` every variable is restored,
            otherwise only variables under ``resnet_v1``.
        nx_in, ny_in: dims 1 and 2 of the single-frame input placeholder
            ``[1, nx_in, ny_in, 3]``.
        allow_growth: value for ``gpu_options.allow_growth`` of the session.

    Returns:
        tuple: ``(sess, net, inputs)`` -- the session, the feature tensor
        returned by ``extract_features`` and the input placeholder.
    """
    from deeplabcut.pose_estimation_tensorflow.nnet.net_factory import pose_net

    TF.reset_default_graph()
    frame_in = TF.placeholder(TF.float32, shape=[1, nx_in, ny_in, 3])
    network = pose_net(dlc_cfg)

    # Backbone features only; no prediction heads are built here.
    features, _ = network.extract_features(frame_in)

    if 'snapshot' not in dlc_cfg.init_weights:
        # Not a snapshot: only the ResNet backbone variables are loaded.
        to_restore = slim.get_variables_to_restore(include=["resnet_v1"])
    else:
        print('restoring from snapshot')
        to_restore = slim.get_variables_to_restore()

    loader = TF.train.Saver(to_restore)

    session_cfg = TF.ConfigProto()
    session_cfg.gpu_options.allow_growth = allow_growth
    sess = TF.Session(config=session_cfg)

    sess.run(TF.global_variables_initializer())
    sess.run(TF.local_variables_initializer())

    # Overwrite the freshly initialized variables with the checkpoint values.
    loader.restore(sess, dlc_cfg.init_weights)

    return sess, features, frame_in
Ejemplo n.º 3
0
def setup_pose_prediction(cfg):
    """Build the test-time pose graph and return a session with restored weights.

    The input placeholder has shape ``[cfg.batch_size, None, None, 3]``, or
    ``[cfg.batch_size, None, None, None, 3]`` when ``cfg.using_z_slices`` is
    set (volume / z-slice input).

    Args:
        cfg: config object with attributes ``using_z_slices``, ``batch_size``,
            ``location_refinement``, ``dataset_type``,
            ``partaffinityfield_predict`` and ``init_weights``.

    Returns:
        tuple: ``(sess, inputs, outputs)`` where ``outputs`` lists the
        ``part_prob`` head, then ``locref`` if location refinement is on, then
        ``pairwise_pred`` for multi-animal PAF models.
    """
    TF.reset_default_graph()

    unknown_dims = [None, None]  # 2d input; default
    if cfg.using_z_slices:
        # volume, i.e. z-slices input: one more unknown dimension
        print("Setting up Volume-based evaluation")
        unknown_dims = [None, None, None]
    inputs = TF.placeholder(tf.float32,
                            shape=[cfg.batch_size] + unknown_dims + [3])

    heads = pose_net(cfg).test(inputs)
    outputs = [heads["part_prob"]]
    if cfg.location_refinement:
        outputs.append(heads["locref"])
    if "multi-animal" in cfg.dataset_type and cfg.partaffinityfield_predict:
        print("Activating extracting of PAFs")
        outputs.append(heads["pairwise_pred"])

    saver = TF.train.Saver()
    sess = TF.Session()
    sess.run(TF.global_variables_initializer())
    sess.run(TF.local_variables_initializer())

    # Load trained variables from disk.
    saver.restore(sess, cfg.init_weights)
    return sess, inputs, outputs
Ejemplo n.º 4
0
def setup_pose_prediction(cfg, allow_growth=False):
    """Build the test-time pose graph and return a session with restored weights.

    Args:
        cfg: dict-like config with keys ``"batch_size"``,
            ``"location_refinement"``, ``"dataset_type"``,
            ``"partaffinityfield_predict"`` and ``"init_weights"``.
        allow_growth: when truthy, the session is created with
            ``gpu_options.allow_growth = True``.

    Returns:
        tuple: ``(sess, inputs, outputs)`` -- the session, the
        ``[batch_size, None, None, 3]`` float32 placeholder, and the list of
        head tensors: ``part_prob``, then ``locref`` if location refinement is
        on, then ``pairwise_pred`` for multi-animal PAF models.
    """
    TF.reset_default_graph()
    inputs = TF.placeholder(tf.float32, shape=[cfg["batch_size"], None, None, 3])
    net_heads = pose_net(cfg).test(inputs)
    outputs = [net_heads["part_prob"]]
    if cfg["location_refinement"]:
        outputs.append(net_heads["locref"])

    if ("multi-animal" in cfg["dataset_type"]) and cfg["partaffinityfield_predict"]:
        print("Activating extracting of PAFs")
        outputs.append(net_heads["pairwise_pred"])

    restorer = TF.train.Saver()

    if allow_growth:
        # Take ConfigProto from the same ``TF`` namespace as every other op in
        # this function so graph, session and config all agree.
        config = TF.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = TF.Session(config=config)
    else:
        sess = TF.Session()
    sess.run(TF.global_variables_initializer())
    sess.run(TF.local_variables_initializer())

    # Restore variables from disk.
    restorer.restore(sess, cfg["init_weights"])

    return sess, inputs, outputs
Ejemplo n.º 5
0
def get_resnet_outsize(videofile_path, dlc_cfg):
    """Return the spatial size of the backbone feature map for frames of a video.

    Only the frame size of the video is read; the clip is closed again before
    the graph is built. The default TF graph is reset and a single-frame
    placeholder ``[1, height, width, 3]`` is pushed through
    ``pose_net(dlc_cfg).extract_features`` (no session is run).

    Args:
        videofile_path: path (``str`` or ``Path``) of the video.
        dlc_cfg: DeepLabCut config passed to ``pose_net``.

    Returns:
        tuple: ``(nx_out, ny_out)`` -- dims 1 and 2 of the static shape of the
        feature tensor.
    """
    from deeplabcut.pose_estimation_tensorflow.nnet.net_factory import pose_net

    clip = VideoFileClip(str(videofile_path))
    try:
        # clip.size is (width, height); the placeholder is NHWC, hence the swap.
        ny_in, nx_in = clip.size
    finally:
        # Release the reader behind the clip; only its size was needed.
        clip.close()

    TF.reset_default_graph()
    inputs = TF.placeholder(tf.float32, shape=[1, nx_in, ny_in, 3])
    pn = pose_net(dlc_cfg)
    # extract resnet outputs
    net, _ = pn.extract_features(inputs)
    nx_out, ny_out = net.shape.as_list()[1:3]
    return nx_out, ny_out
Ejemplo n.º 6
0
def setup_GPUpose_prediction(cfg):
    """Build the GPU-inference pose graph and return a session with restored weights.

    Only the ``"pose"`` head of ``pose_net(cfg).inference`` is exposed.

    Args:
        cfg: dict-like config providing ``'batch_size'`` and ``'init_weights'``.

    Returns:
        tuple: ``(sess, inputs, outputs)`` -- the session, the
        ``[batch_size, None, None, 3]`` float32 placeholder and a one-element
        list holding the pose tensor.
    """
    tf.reset_default_graph()
    batch_size = cfg['batch_size']
    image_batch = tf.placeholder(tf.float32, shape=[batch_size, None, None, 3])
    pose_tensor = pose_net(cfg).inference(image_batch)["pose"]

    saver = tf.train.Saver()
    session = tf.Session()

    # Initialize global then local variables, each op created right before it runs.
    for make_init_op in (tf.global_variables_initializer,
                         tf.local_variables_initializer):
        session.run(make_init_op())

    # Load trained variables from disk.
    saver.restore(session, cfg['init_weights'])

    return session, image_batch, [pose_tensor]
Ejemplo n.º 7
0
def setup_pose_prediction(cfg):
    """Build the test-time pose graph and return a session with restored weights.

    Args:
        cfg: config object with attributes ``batch_size``,
            ``location_refinement`` and ``init_weights``.

    Returns:
        tuple: ``(sess, inputs, outputs)`` -- the session, the
        ``[batch_size, None, None, 3]`` float32 placeholder and the head
        tensors: ``part_prob``, plus ``locref`` when location refinement is on.
    """
    TF.reset_default_graph()
    inputs = TF.placeholder(tf.float32, shape=[cfg.batch_size, None, None, 3])
    heads = pose_net(cfg).test(inputs)

    head_names = ['part_prob']
    if cfg.location_refinement:
        head_names.append('locref')
    outputs = [heads[name] for name in head_names]

    saver = TF.train.Saver()
    sess = TF.Session()
    sess.run(TF.global_variables_initializer())
    sess.run(TF.local_variables_initializer())

    # Load trained variables from disk.
    saver.restore(sess, cfg.init_weights)
    return sess, inputs, outputs
Ejemplo n.º 8
0
def setup_pose_prediction(cfg):
    """Build the test-time pose graph and return a session with restored weights.

    Args:
        cfg: dict-like config with keys ``'batch_size'``,
            ``'location_refinement'``, ``'dataset_type'``,
            ``'partaffinityfield_predict'`` and ``'init_weights'``.

    Returns:
        tuple: ``(sess, inputs, outputs)`` -- the session, the
        ``[batch_size, None, None, 3]`` float32 placeholder and the head
        tensors: ``part_prob``, then ``locref`` if location refinement is on,
        then ``pairwise_pred`` for multi-animal PAF models.
    """
    TF.reset_default_graph()
    inputs = TF.placeholder(tf.float32, shape=[cfg['batch_size'], None, None, 3])
    heads = pose_net(cfg).test(inputs)

    head_names = ["part_prob"]
    if cfg['location_refinement']:
        head_names.append("locref")

    extract_pafs = ("multi-animal" in cfg['dataset_type']
                    and cfg['partaffinityfield_predict'])
    if extract_pafs:
        print("Activating extracting of PAFs")
        head_names.append("pairwise_pred")

    outputs = [heads[name] for name in head_names]

    saver = TF.train.Saver()
    sess = TF.Session()
    sess.run(TF.global_variables_initializer())
    sess.run(TF.local_variables_initializer())

    # Load trained variables from disk.
    saver.restore(sess, cfg['init_weights'])
    return sess, inputs, outputs
Ejemplo n.º 9
0
def setup_GPUpose_prediction(cfg, allow_growth=False):
    """Build the GPU-inference pose graph and return a session with restored weights.

    Only the ``"pose"`` head of ``pose_net(cfg).inference`` is exposed.

    Args:
        cfg: dict-like config providing ``"batch_size"`` and ``"init_weights"``.
        allow_growth: when truthy, the session is created with
            ``gpu_options.allow_growth = True``.

    Returns:
        tuple: ``(sess, inputs, outputs)`` -- the session, the
        ``[batch_size, None, None, 3]`` float32 placeholder and a one-element
        list holding the pose tensor.
    """
    tf.reset_default_graph()
    inputs = tf.placeholder(tf.float32, shape=[cfg["batch_size"], None, None, 3])
    net_heads = pose_net(cfg).inference(inputs)
    outputs = [net_heads["pose"]]

    restorer = tf.train.Saver()

    # Create config and session from the same ``tf`` namespace that built the
    # graph above, rather than mixing ``tf`` and ``TF``.
    if allow_growth:
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
    else:
        sess = tf.Session()

    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # Restore variables from disk.
    restorer.restore(sess, cfg["init_weights"])

    return sess, inputs, outputs
Ejemplo n.º 10
0
    def _compute_pred_dims(self):
        """Compute output dims of dgp prediction layer by pushing fake data through network.

        Resets the default graph, builds the backbone from ``self.dlc_config``
        plus a ``'confidencemap'`` prediction layer, and runs one all-zero
        ``(1, self.nx_in, self.ny_in, 3)`` frame through it. Variables are only
        initialized (no checkpoint is restored), which is enough to obtain
        the output shape.

        Returns:
            tuple: dims 1 and 2 of the prediction layer output.
        """
        from deepgraphpose.models.fitdgp_util import dgp_prediction_layer
        from deeplabcut.pose_estimation_tensorflow.nnet.net_factory import pose_net

        TF.reset_default_graph()

        nc = 3
        inputs = TF.placeholder(TF.float32, shape=[None, self.nx_in, self.ny_in, nc])

        pn = pose_net(self.dlc_config)
        conv_inputs, _ = pn.extract_features(inputs)

        x = dgp_prediction_layer(
            None, None, self.dlc_config, conv_inputs, 'confidencemap', self.nj, 0,
            nc, 1)

        # The session is only needed for this single forward pass; the context
        # manager closes it and releases its resources afterwards.
        with TF.Session(config=TF.ConfigProto()) as sess:
            sess.run(TF.global_variables_initializer())
            sess.run(TF.local_variables_initializer())
            feed_dict = {inputs: np.zeros([1, self.nx_in, self.ny_in, nc])}
            x_np = sess.run(x, feed_dict)

        return x_np.shape[1], x_np.shape[2]
Ejemplo n.º 11
0
def train(
    config_yaml,
    displayiters,
    saveiters,
    maxiters,
    max_to_keep=5,
    keepdeconvweights=True,
    allow_growth=False,
):
    """Train a (multi-animal) DeepLabCut network described by ``config_yaml``.

    The working directory is switched to the folder containing
    ``config_yaml`` for the duration of training (for logging) and is always
    restored afterwards, even if training raises. The session and the
    preloading thread are likewise shut down on every exit path.

    Args:
        config_yaml: path to the training config file.
        displayiters: log every this many iterations; ``None`` uses
            ``cfg["display_iters"]``.
        saveiters: snapshot every this many iterations; ``None`` uses
            ``cfg["save_iters"]``.
        maxiters: upper bound on iterations, capped by the last step of
            ``cfg["multi_step"]``; ``None`` uses that last step.
        max_to_keep: number of snapshots the ``Saver`` retains.
        keepdeconvweights: if true and ``init_weights`` names a snapshot,
            restore every variable; otherwise only backbone variables.
        allow_growth: enable ``gpu_options.allow_growth`` for the session.

    Raises:
        ValueError: if the backbone in ``cfg["net_type"]`` is not supported
            when loading pretrained (non-snapshot) weights.
    """
    start_path = os.getcwd()
    # switch to folder of config_yaml (for logging)
    os.chdir(str(Path(config_yaml).parents[0]))
    try:
        setup_logging()

        cfg = load_config(config_yaml)
        if cfg["optimizer"] != "adam":
            print(
                "Setting batchsize to 1! Larger batchsize not supported for this loader:",
                cfg["dataset_type"],
            )
            cfg["batch_size"] = 1

        # the PAF code currently just hijacks the pairwise net stuff (for the
        # batch feeding via Batch.pairwise_targets: 5)
        if cfg["partaffinityfield_predict"] and "multi-animal" in cfg["dataset_type"]:
            print("Activating limb prediction...")
            cfg["pairwise_predict"] = True

        dataset = create_dataset(cfg)
        batch_spec = get_batch_spec(cfg)
        batch, enqueue_op, placeholders = setup_preloading(batch_spec)

        losses = pose_net(cfg).train(batch)
        total_loss = losses["total_loss"]

        for k, t in losses.items():
            TF.summary.scalar(k, t)
        merged_summaries = TF.summary.merge_all()
        net_type = cfg["net_type"]

        if "snapshot" in Path(cfg["init_weights"]).stem and keepdeconvweights:
            print("Loading already trained DLC with backbone:", net_type)
            variables_to_restore = slim.get_variables_to_restore()
        else:
            print("Loading ImageNet-pretrained", net_type)
            # loading backbone from ResNet, MobileNet etc.
            if "resnet" in net_type:
                variables_to_restore = slim.get_variables_to_restore(include=["resnet_v1"])
            elif "mobilenet" in net_type:
                variables_to_restore = slim.get_variables_to_restore(
                    include=["MobilenetV2"]
                )
            elif "efficientnet" in net_type:
                variables_to_restore = slim.get_variables_to_restore(
                    include=["efficientnet"]
                )
                # Map each graph variable to the checkpoint name with the
                # "efficientnet/" prefix stripped and
                # "/ExponentialMovingAverage" appended.
                variables_to_restore = {
                    var.op.name.replace("efficientnet/", "")
                    + "/ExponentialMovingAverage": var
                    for var in variables_to_restore
                }
            else:
                print("Wait for DLC 2.3.")
                # Without a backbone mapping there is nothing to restore; fail
                # here with a clear message instead of an UnboundLocalError.
                raise ValueError(
                    "Unsupported net_type for training: {}".format(net_type)
                )

        restorer = TF.train.Saver(variables_to_restore)
        # selects how many snapshots are stored, see https://github.com/AlexEMG/DeepLabCut/issues/8#issuecomment-387404835
        saver = TF.train.Saver(max_to_keep=max_to_keep)

        if allow_growth:
            config = TF.ConfigProto()
            config.gpu_options.allow_growth = True
            sess = TF.Session(config=config)
        else:
            sess = TF.Session()

        coord, thread = start_preloading(sess, enqueue_op, dataset, placeholders)
        try:
            train_writer = TF.summary.FileWriter(cfg["log_dir"], sess.graph)
            learning_rate, train_op, tstep = get_optimizer(total_loss, cfg)

            sess.run(TF.global_variables_initializer())
            sess.run(TF.local_variables_initializer())

            restorer.restore(sess, cfg["init_weights"])
            if maxiters is None:
                max_iter = int(cfg["multi_step"][-1][1])
            else:
                max_iter = min(int(cfg["multi_step"][-1][1]), int(maxiters))
                print("Max_iters overwritten as", max_iter)

            if displayiters is None:
                display_iters = max(1, int(cfg["display_iters"]))
            else:
                display_iters = max(1, int(displayiters))
                print("Display_iters overwritten as", display_iters)

            if saveiters is None:
                save_iters = max(1, int(cfg["save_iters"]))
            else:
                save_iters = max(1, int(saveiters))
                print("Save_iters overwritten as", save_iters)

            cumloss, partloss, locrefloss, pwloss = 0.0, 0.0, 0.0, 0.0
            lr_gen = LearningRate(cfg)
            stats_path = Path(config_yaml).with_name("learning_stats.csv")
            with open(str(stats_path), "w") as lrf:
                print("Training parameters:")
                print(cfg)
                print("Starting multi-animal training....")
                for it in range(max_iter + 1):
                    if "efficientnet" in net_type:
                        # EfficientNet: the learning rate is computed in-graph from the step.
                        feed_dict = {tstep: it}
                        current_lr = sess.run(learning_rate, feed_dict=feed_dict)
                    else:
                        current_lr = lr_gen.get_lr(it)
                        feed_dict = {learning_rate: current_lr}

                    [_, alllosses, loss_val, summary] = sess.run(
                        [train_op, losses, total_loss, merged_summaries],
                        feed_dict=feed_dict,
                    )

                    partloss += alllosses["part_loss"]  # scoremap loss
                    if cfg["location_refinement"]:
                        locrefloss += alllosses["locref_loss"]
                    if cfg["pairwise_predict"]:  # paf loss
                        pwloss += alllosses["pairwise_loss"]

                    cumloss += loss_val
                    train_writer.add_summary(summary, it)

                    if it % display_iters == 0 and it > 0:
                        logging.info(
                            "iteration: {} loss: {} scmap loss: {} locref loss: {} limb loss: {} lr: {}".format(
                                it,
                                "{0:.4f}".format(cumloss / display_iters),
                                "{0:.4f}".format(partloss / display_iters),
                                "{0:.4f}".format(locrefloss / display_iters),
                                "{0:.4f}".format(pwloss / display_iters),
                                current_lr,
                            )
                        )

                        lrf.write(
                            "iteration: {}, loss: {}, scmap loss: {}, locref loss: {}, limb loss: {}, lr: {}\n".format(
                                it,
                                "{0:.4f}".format(cumloss / display_iters),
                                "{0:.4f}".format(partloss / display_iters),
                                "{0:.4f}".format(locrefloss / display_iters),
                                "{0:.4f}".format(pwloss / display_iters),
                                current_lr,
                            )
                        )

                        cumloss, partloss, locrefloss, pwloss = 0.0, 0.0, 0.0, 0.0
                        lrf.flush()

                    # Save snapshot
                    if (it % save_iters == 0 and it != 0) or it == max_iter:
                        model_name = cfg["snapshot_prefix"]
                        saver.save(sess, model_name, global_step=it)
        finally:
            sess.close()
            coord.request_stop()
            coord.join([thread])
    finally:
        # return to original path.
        os.chdir(str(start_path))
Ejemplo n.º 12
0
def train(config_yaml, displayiters, saveiters, maxiters, max_to_keep=5):
    """Train a single-animal DeepLabCut network with the session hidden from GPUs.

    The working directory is switched to the folder containing
    ``config_yaml`` (for logging) and always restored afterwards, even if
    training raises. The session is created with
    ``device_count={'GPU': 0}``.

    Args:
        config_yaml: path to the training config file.
        displayiters: log every this many iterations; ``None`` uses
            ``cfg.display_iters``.
        saveiters: snapshot every this many iterations; ``None`` uses
            ``cfg.save_iters``.
        maxiters: upper bound on iterations, capped by the last step of
            ``cfg.multi_step``; ``None`` uses that last step.
        max_to_keep: number of snapshots the ``Saver`` retains.
    """
    start_path = os.getcwd()
    os.chdir(str(Path(config_yaml).parents[0]))  # switch to folder of config_yaml (for logging)
    try:
        setup_logging()

        cfg = load_config(config_yaml)
        cfg['batch_size'] = 1  # in case this was edited for analysis.

        dataset = create_dataset(cfg)
        batch_spec = get_batch_spec(cfg)
        batch, enqueue_op, placeholders = setup_preloading(batch_spec)
        losses = pose_net(cfg).train(batch)
        total_loss = losses['total_loss']

        for k, t in losses.items():
            TF.summary.scalar(k, t)
        merged_summaries = TF.summary.merge_all()

        variables_to_restore = slim.get_variables_to_restore(include=["resnet_v1"])
        restorer = TF.train.Saver(variables_to_restore)
        # selects how many snapshots are stored, see https://github.com/AlexEMG/DeepLabCut/issues/8#issuecomment-387404835
        saver = TF.train.Saver(max_to_keep=max_to_keep)

        # GPUs are hidden from this session (device_count GPU = 0).
        sess = TF.Session(config=TF.ConfigProto(device_count={'GPU': 0}))
        coord, thread = start_preloading(sess, enqueue_op, dataset, placeholders)
        try:
            train_writer = TF.summary.FileWriter(cfg.log_dir, sess.graph)
            learning_rate, train_op = get_optimizer(total_loss, cfg)

            sess.run(TF.global_variables_initializer())
            sess.run(TF.local_variables_initializer())

            # Restore variables from disk.
            restorer.restore(sess, cfg.init_weights)
            if maxiters is None:
                max_iter = int(cfg.multi_step[-1][1])
            else:
                max_iter = min(int(cfg.multi_step[-1][1]), int(maxiters))
                print("\n\nMax_iters overwritten as", max_iter)

            if displayiters is None:
                display_iters = max(1, int(cfg.display_iters))
            else:
                display_iters = max(1, int(displayiters))
                print("Display_iters overwritten as", display_iters)

            if saveiters is None:
                save_iters = max(1, int(cfg.save_iters))
            else:
                save_iters = max(1, int(saveiters))
                print("Save_iters overwritten as", save_iters)

            cum_loss = 0.0
            n_accumulated = 0  # iterations summed into cum_loss since the last report
            lr_gen = LearningRate(cfg)

            stats_path = Path(config_yaml).with_name('learning_stats.csv')
            with open(str(stats_path), 'w') as lrf:
                print("\nTraining parameter:\n")
                pprint.pprint(cfg)
                print("\n\nStarting training....")
                start = time.time()
                print("\nStarting time of training:  {} \n".format(
                    datetime.datetime.now()))
                for it in range(max_iter + 1):
                    current_lr = lr_gen.get_lr(it)
                    [_, loss_val,
                     summary] = sess.run([train_op, total_loss, merged_summaries],
                                         feed_dict={learning_rate: current_lr})
                    cum_loss += loss_val
                    n_accumulated += 1
                    train_writer.add_summary(summary, it)

                    if it % display_iters == 0:
                        # Average over the iterations actually accumulated; at
                        # it == 0 that is a single iteration, not display_iters.
                        average_loss = cum_loss / n_accumulated
                        cum_loss = 0.0
                        n_accumulated = 0
                        hours, rem = divmod(time.time() - start, 3600)
                        minutes, seconds = divmod(rem, 60)
                        logging.info(
                            "iteration: {}/{},    loss:  {:.4f},    lr: {},  |   Elapsed Time:  {:0>2}:{:0>2}:{:05.2f},    Time:  {}"
                            .format(it, max_iter, average_loss, current_lr, int(hours),
                                    int(minutes), seconds,
                                    datetime.datetime.now().strftime("%H:%M")))
                        lrf.write("{}, {:.5f}, {}\n".format(it, average_loss, current_lr))
                        lrf.flush()

                    # Save snapshot
                    if (it % save_iters == 0 and it != 0) or it == max_iter:
                        model_name = cfg.snapshot_prefix
                        saver.save(sess, model_name, global_step=it)
        finally:
            sess.close()
            coord.request_stop()
            coord.join([thread])
    finally:
        # return to original path.
        os.chdir(str(start_path))
Ejemplo n.º 13
0
def train(config_yaml, displayiters, saveiters, max_to_keep=5):
    """Train a single-animal DeepLabCut network for the full ``multi_step`` schedule.

    The number of iterations is the last step of ``cfg.multi_step``. The
    working directory is switched to the folder containing ``config_yaml``
    (for logging) and always restored afterwards, even if training raises.

    Args:
        config_yaml: path to the training config file.
        displayiters: log every this many iterations; ``None`` uses
            ``cfg.display_iters``.
        saveiters: snapshot every this many iterations; ``None`` uses
            ``cfg.save_iters``.
        max_to_keep: number of snapshots the ``Saver`` retains.
    """
    start_path = os.getcwd()
    os.chdir(str(Path(config_yaml).parents[0]))  # switch to folder of config_yaml (for logging)
    try:
        setup_logging()

        cfg = load_config(config_yaml)
        cfg['batch_size'] = 1  # in case this was edited for analysis.

        dataset = create_dataset(cfg)
        batch_spec = get_batch_spec(cfg)
        batch, enqueue_op, placeholders = setup_preloading(batch_spec)
        losses = pose_net(cfg).train(batch)
        total_loss = losses['total_loss']

        for k, t in losses.items():
            tf.summary.scalar(k, t)
        merged_summaries = tf.summary.merge_all()

        variables_to_restore = slim.get_variables_to_restore(include=["resnet_v1"])
        restorer = tf.train.Saver(variables_to_restore)
        # selects how many snapshots are stored, see https://github.com/AlexEMG/DeepLabCut/issues/8#issuecomment-387404835
        saver = tf.train.Saver(max_to_keep=max_to_keep)

        sess = tf.Session()
        coord, thread = start_preloading(sess, enqueue_op, dataset, placeholders)
        try:
            train_writer = tf.summary.FileWriter(cfg.log_dir, sess.graph)
            learning_rate, train_op = get_optimizer(total_loss, cfg)

            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())

            # Restore variables from disk.
            restorer.restore(sess, cfg.init_weights)

            max_iter = int(cfg.multi_step[-1][1])

            if displayiters is None:
                display_iters = max(1, int(cfg.display_iters))
            else:
                display_iters = max(1, int(displayiters))
                print("Display_iters overwritten as", display_iters)

            if saveiters is None:
                save_iters = max(1, int(cfg.save_iters))
            else:
                save_iters = max(1, int(saveiters))
                print("Save_iters overwritten as", save_iters)

            cum_loss = 0.0
            lr_gen = LearningRate(cfg)

            stats_path = Path(config_yaml).with_name('learning_stats.csv')
            with open(str(stats_path), 'w') as lrf:
                print("Training parameter:")
                print(cfg)
                print("Starting training....")
                for it in range(max_iter + 1):
                    current_lr = lr_gen.get_lr(it)
                    [_, loss_val,
                     summary] = sess.run([train_op, total_loss, merged_summaries],
                                         feed_dict={learning_rate: current_lr})
                    cum_loss += loss_val
                    train_writer.add_summary(summary, it)

                    if it % display_iters == 0 and it > 0:
                        average_loss = cum_loss / display_iters
                        cum_loss = 0.0
                        logging.info("iteration: {} loss: {} lr: {}".format(
                            it, "{0:.4f}".format(average_loss), current_lr))
                        lrf.write("{}, {:.5f}, {}\n".format(it, average_loss, current_lr))
                        lrf.flush()

                    # Save snapshot
                    if (it % save_iters == 0 and it != 0) or it == max_iter:
                        model_name = cfg.snapshot_prefix
                        saver.save(sess, model_name, global_step=it)
        finally:
            sess.close()
            coord.request_stop()
            coord.join([thread])
    finally:
        # return to original path.
        os.chdir(str(start_path))
Ejemplo n.º 14
0
def store_resnet_output(task,
                        date,
                        shuffle,
                        overwrite_snapshot=None,
                        allow_growth=True,
                        videofile_path=None,
                        resnet_output_dir=None):
    """Pass every frame of a video through the backbone and save each output as a ``.mat`` file.

    All frames are first decoded into memory (converted BGR -> RGB, stored as
    ``uint8``). Each frame is then fed through
    ``pose_net(dlc_cfg).extract_features``, and the result is written to
    ``<resnet_output_dir>/resnet_output_mat/<snapshot name>/resnet_output_<ii>.mat``
    under the key ``"net_output"``.

    Args:
        task: task name passed to ``DataLoader`` and ``get_model_config``.
        date: model date passed to ``get_model_config``.
        shuffle: shuffle index passed to ``get_train_config``.
        overwrite_snapshot: snapshot to load instead of the configured one;
            when given, the loaded snapshot is asserted to equal it.
        allow_growth: value for ``gpu_options.allow_growth`` of the session.
        videofile_path: video to process (``str`` or ``Path``); defaults to
            ``DataLoader(task).videofile_path``.
        resnet_output_dir: base output directory (``str`` or ``Path``);
            defaults to the folder containing ``dlc_cfg.init_weights``.

    Example argument values (from the original scratch notes)::

        task = 'ibl1'
        date = '2020-01-25'
        shuffle = 1
        overwrite_snapshot = 5000
        ibl_chunk_path = '/data/libraries/deepgraphpose/etc/lab_ekb/debug_va_semi_pipeline/run_long_video_aqweight/movies'
        videofile_path = Path(ibl_chunk_path) / 'movie_chunk_00.mp4'
        resnet_output_dir = videofile_path.parent / videofile_path.stem
        allow_growth = True
    """
    from deeplabcut.pose_estimation_tensorflow.nnet.net_factory import pose_net

    if isinstance(videofile_path, str):
        videofile_path = Path(videofile_path)
    if isinstance(resnet_output_dir, str):
        resnet_output_dir = Path(resnet_output_dir)

    data_info = DataLoader(task)
    cfg = get_model_config(task,
                           data_info.model_data_dir,
                           scorer=data_info.scorer,
                           date=date)
    dlc_cfg = get_train_config(cfg, shuffle)
    trainingsnapshot_name, trainingsnapshot, dlc_cfg = load_dlc_snapshot(
        dlc_cfg, overwrite_snapshot=overwrite_snapshot)
    if overwrite_snapshot is not None:
        assert trainingsnapshot == overwrite_snapshot

    # Update dlc_cfg just to init the network (one frame per batch).
    dlc_cfg["batch_size"] = 1
    dlc_cfg["num_outputs"] = cfg.get("num_outputs", 1)
    dlc_cfg["deterministic"] = True

    # Load data
    if videofile_path is None:
        videofile_path = Path(data_info.videofile_path)

    cap = cv2.VideoCapture(str(videofile_path))
    try:
        nframes = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # We want to pass all frames through network; arrays are (rows, cols).
        nx_in = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        ny_in = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))

        # this keeps all frames in memory at once
        frames = np.zeros((nframes, nx_in, ny_in, 3), dtype="ubyte")
        counter = 0
        # Advance the bar one frame at a time so its total is exact and short
        # videos (fewer than 3 frames) cannot cause a zero step.
        with tqdm(total=nframes) as pbar:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    print("the end")
                    break
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames[counter] = img_as_ubyte(frame)
                counter += 1
                pbar.update(1)
    finally:
        cap.release()
    # read all frames
    assert counter == nframes

    TF.reset_default_graph()
    inputs = TF.placeholder(tf.float32, shape=[1, nx_in, ny_in, 3])
    pn = pose_net(dlc_cfg)
    # extract features using resnet
    net, _ = pn.extract_features(inputs)

    # Snapshot index 0 restores only resnet_v1 variables; any other snapshot
    # restores every variable in the graph.
    if trainingsnapshot == 0:
        variables_to_restore = slim.get_variables_to_restore(
            include=["resnet_v1"])
    else:
        variables_to_restore = slim.get_variables_to_restore()
    restorer = TF.train.Saver(variables_to_restore)

    # Init session
    config_TF = TF.ConfigProto()
    config_TF.gpu_options.allow_growth = allow_growth
    sess = TF.Session(config=config_TF)
    try:
        sess.run(TF.global_variables_initializer())
        sess.run(TF.local_variables_initializer())

        # Restore the variables from disk
        restorer.restore(sess, dlc_cfg.init_weights)

        if resnet_output_dir is None:
            resnet_output_dir = Path(dlc_cfg.init_weights).parent
            print(resnet_output_dir)
        resnet_outdir = resnet_output_dir / "resnet_output_mat" / (
            "{}".format(trainingsnapshot_name))
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        resnet_outdir.mkdir(parents=True, exist_ok=True)

        for ii in range(nframes):
            if ii % 10 == 0:
                print("iter {}/{}".format(ii, nframes))
            ff = np.expand_dims(frames[ii], axis=0)
            [net_output] = sess.run([net], feed_dict={inputs: ff})
            ss = resnet_outdir / ("resnet_output_{:03d}.mat".format(ii))
            sio.savemat(str(ss), {"net_output": net_output})
        print("Stored resnet outputs in:\n{}".format(resnet_outdir))
    finally:
        sess.close()
Ejemplo n.º 15
0
def pass_video_through_resnet(videofile_path,
                              dlc_cfg,
                              allow_growth=True,
                              indices=None,
                              nc=None,
                              step_pbar=500):
    """Pass (selected) video frames through the backbone network and collect the features.

    Args:
        videofile_path: path of the video, read with ``VideoFileClip``.
        dlc_cfg: DeepLabCut config; ``init_weights`` is restored. If it
            contains ``'snapshot'`` every variable is restored, otherwise only
            ``resnet_v1`` variables.
        allow_growth: value for ``gpu_options.allow_growth`` of the session.
        indices: frame indices to process. Defaults to every frame, with the
            count estimated from ``clip.duration * clip.fps``.
        nc: channel selection. ``None`` keeps every channel of the feature
            tensor (its static channel count, falling back to 2048 if
            unknown), an integer ``k`` keeps the first ``k`` channels, and a
            NumPy array selects those channel indices.
        step_pbar: minimum number of frames between progress-bar updates.

    Returns:
        np.ndarray: float32 array ``(nframes, nx_out, ny_out, nchannels)``.

    Raises:
        Exception: if ``nc`` is not ``None``, an integer or a NumPy array.
    """
    from deeplabcut.pose_estimation_tensorflow.nnet.net_factory import pose_net

    clip = VideoFileClip(str(videofile_path))
    sess = None
    try:
        # clip.size is (width, height); the placeholder is NHWC, hence the swap.
        ny_in, nx_in = clip.size
        if indices is None:
            nframes = clip.duration * clip.fps
            nframes_fsec = nframes - int(nframes)
            if nframes_fsec < 1 / clip.fps:
                nframes = np.floor(nframes).astype('int')
            else:
                nframes = np.ceil(nframes).astype('int')
                print('Warning. Check the number of frames')
            indices = np.arange(nframes)
        else:
            nframes = len(indices)

        TF.reset_default_graph()
        inputs = TF.placeholder(tf.float32, shape=[1, nx_in, ny_in, 3])
        pn = pose_net(dlc_cfg)
        # extract resnet outputs
        net, _ = pn.extract_features(inputs)

        # restore from snapshot
        if 'snapshot' in dlc_cfg.init_weights:
            variables_to_restore = slim.get_variables_to_restore()
        else:
            variables_to_restore = slim.get_variables_to_restore(
                include=["resnet_v1"])
        restorer = TF.train.Saver(variables_to_restore)

        # Init session
        config_TF = TF.ConfigProto()
        config_TF.gpu_options.allow_growth = allow_growth
        sess = TF.Session(config=config_TF)
        sess.run(TF.global_variables_initializer())
        sess.run(TF.local_variables_initializer())

        # Restore the variables from disk
        restorer.restore(sess, dlc_cfg.init_weights)

        nx_out, ny_out, n_features = net.shape.as_list()[1:4]
        if nc is None:
            # load all channels the feature tensor actually has
            nc = n_features or 2048
            nchannels = nc
        elif isinstance(nc, (int, np.integer)):
            nchannels = nc
        elif isinstance(nc, np.ndarray):
            nchannels = len(nc)
        else:
            raise Exception('Check nc')
        select_by_index = isinstance(nc, np.ndarray)

        resnet_outputs = np.zeros((nframes, nx_out, ny_out, nchannels),
                                  dtype="float32")

        step = max(1, int(max(step_pbar, nframes // 3)))
        processed = 0
        with tqdm(total=nframes, desc='Read video frames') as pbar:
            for counter, index in enumerate(indices):
                ff = img_as_ubyte(clip.get_frame(index * 1. / clip.fps))
                # net_output: 1 x nx_out x ny_out x n_features
                [net_output] = sess.run([net], feed_dict={inputs: ff[None, :, :, :]})
                if select_by_index:
                    resnet_outputs[counter] = net_output[0, :, :, nc]
                else:
                    resnet_outputs[counter] = net_output[0, :, :, :nc]
                processed += 1
                if processed % step == 0:
                    pbar.update(step)
            # Report frames processed since the last full step.
            pbar.update(processed % step)

        assert processed == nframes
    finally:
        if sess is not None:
            sess.close()
        clip.close()
    return resnet_outputs

def store_test_resnet_output_chunks(dlc_cfg,
                                    nc=200,
                                    chunk_size=1000,
                                    allow_growth=True,
                                    debug_key=""):
    """Split a test video into chunks and store backbone features per chunk.

    For every block of ``chunk_size`` frames of ``dlc_cfg.video_path`` this
    writes the matching sub-clip to
    ``test_data.get_video_data_chunks_fname(chunk_id)`` and an HDF5 file
    (``test_data.get_resnet_output_chunks_fname(chunk_id)``) holding the
    feature maps of those frames (dataset ``"resnet_out"``) together with the
    chunk's frame range (datasets ``"start"`` and ``"stop"``).

    Parameters
    ----------
    dlc_cfg
        DeepLabCut config. ``video_path`` and ``init_weights`` are read here;
        the whole object is passed to ``pose_net`` and ``TestDataLoader``.
    nc : None, int or numpy.ndarray
        Channel selection: ``None`` keeps 2048 channels, an int keeps the
        first ``nc`` channels, an array keeps the listed channel indices.
    chunk_size : int
        Number of frames per chunk (the last chunk may be shorter).
    allow_growth : bool
        Value for ``gpu_options.allow_growth`` of the TF session.
    debug_key : str
        Forwarded to ``TestDataLoader``.

    Raises
    ------
    Exception
        If ``nc`` has an unsupported type, or a sub-clip does not contain the
        expected number of frames.
    """
    from deeplabcut.pose_estimation_tensorflow.nnet.net_factory import pose_net
    import tensorflow.contrib.slim as slim
    from tqdm import tqdm
    from skimage.util import img_as_ubyte
    from PoseDataLoader import TestDataLoader

    # Resolve the channel selection once, up front, so an invalid ``nc``
    # fails before any video decoding or model loading happens.
    if nc is None:
        # load all channels
        nchannels = 2048
        channel_sel = slice(None, nchannels)
    elif isinstance(nc, (int, np.integer)):
        nchannels = int(nc)
        channel_sel = slice(None, nchannels)
    elif isinstance(nc, np.ndarray):
        nchannels = len(nc)
        channel_sel = nc
    else:
        raise Exception('Check nc')

    clip = VideoFileClip(str(dlc_cfg.video_path))
    sess = None
    try:
        # moviepy reports (width, height); the placeholder below is
        # (1, height, width, 3).
        ny_raw, nx_raw = clip.size
        fps = clip.fps

        # Number of whole frames in the video.
        nframes = clip.duration * fps
        nframes_fsec = nframes - int(nframes)
        if nframes_fsec < 1 / fps:
            nframes = np.floor(nframes).astype('int')
        else:
            nframes = np.ceil(nframes).astype('int')
            print('Warning. Check the number of frames')

        # Build graph to pass frames through resnet
        TF.reset_default_graph()
        inputs = TF.placeholder(tf.float32, shape=[1, nx_raw, ny_raw, 3])
        pn = pose_net(dlc_cfg)
        net, end_points = pn.extract_features(inputs)

        # Restore all variables from a DLC snapshot, otherwise only the
        # resnet_v1 backbone variables.
        if 'snapshot' in dlc_cfg.init_weights:
            variables_to_restore = slim.get_variables_to_restore()
        else:
            variables_to_restore = slim.get_variables_to_restore(
                include=["resnet_v1"])
        restorer = TF.train.Saver(variables_to_restore)

        config_TF = TF.ConfigProto()
        config_TF.gpu_options.allow_growth = allow_growth
        sess = TF.Session(config=config_TF)
        sess.run(TF.global_variables_initializer())
        sess.run(TF.local_variables_initializer())
        restorer.restore(sess, dlc_cfg.init_weights)

        nx_out, ny_out = net.shape.as_list()[1:3]

        num_chunks_tvideo = int(np.ceil(nframes / chunk_size))
        print('Video is split in {} resnet_out files'.format(num_chunks_tvideo))

        # Output locations come from the test dataset helper.
        test_data = TestDataLoader(dlc_cfg, debug_key=debug_key)
        os.makedirs(test_data.video_data_chunks_dir, exist_ok=True)
        os.makedirs(test_data.resnet_output_chunks_dir, exist_ok=True)

        for chunk_id, video_start in enumerate(
                np.arange(0, nframes, chunk_size)):
            video_end = min(video_start + chunk_size, nframes)
            nvideoframes = video_end - video_start

            # Cut and save the sub-clip for this chunk. Timestamps are rounded
            # to 5 decimals; if that makes the span longer than
            # ``nvideoframes`` frames, pull the end time back.
            start_sec = np.round(video_start / fps, 5)
            end_sec = np.round(video_end / fps, 5)
            bonus = (nvideoframes - (end_sec - start_sec) * fps) / fps
            if bonus < 0:
                end_sec += 2 * bonus

            mini_clip = clip.subclip(t_start=start_sec, t_end=end_sec)
            n_frames = sum(1 for _ in mini_clip.iter_frames())
            if n_frames != nvideoframes:
                raise Exception(
                    'Chunk {}: sub-clip has {} frames, expected {}'.format(
                        chunk_id, n_frames, nvideoframes))

            video_fname = test_data.get_video_data_chunks_fname(chunk_id)
            mini_clip.write_videofile(str(video_fname))

            # Pass every frame of the chunk through the backbone and keep the
            # selected channels.
            resnet_outputs = np.zeros(
                (nvideoframes, nx_out, ny_out, nchannels), dtype="float32")
            with tqdm(total=nvideoframes,
                      desc='Pass through resnet chunk {}'.format(
                          chunk_id)) as pbar:
                for counter, index in enumerate(
                        np.arange(video_start, video_end)):
                    ff = img_as_ubyte(clip.get_frame(index * 1. / fps))
                    [net_output] = sess.run(
                        [net], feed_dict={inputs: ff[None, :, :, :]})
                    resnet_outputs[counter] = net_output[0, :, :, channel_sel]
                    pbar.update(1)

            resnet_fname = test_data.get_resnet_output_chunks_fname(chunk_id)
            with h5py.File(str(resnet_fname), 'w') as f:
                f.create_dataset("resnet_out", data=resnet_outputs)
                f.create_dataset("start", data=video_start)
                f.create_dataset("stop", data=video_end)
    finally:
        # Release the TF session and the video reader even on failure.
        if sess is not None:
            sess.close()
        clip.close()

    return
def store_prediction_layer(task,
                           date,
                           shuffle,
                           overwrite_snapshot=None,
                           allow_growth=True):
    """Save the DLC prediction-layer parameters of a snapshot to a .mat file.

    Builds the DeepLabCut training graph for the model identified by
    ``task``/``date``/``shuffle``, restores the selected snapshot, runs one
    ``train_op`` step and stores the ``pose/part_pred/block4`` weights and
    biases (plus the ``pose/locref_pred/block4`` ones when
    ``location_refinement`` is enabled) in
    ``<init_weights dir>/dlc_params_mat/<snapshot name>/dlc_params.mat``
    under the keys ``weight``, ``bias``, ``weight_locref`` and
    ``bias_locref``.

    Parameters
    ----------
    task, date, shuffle
        Identify the model; passed to ``DataLoader``, ``get_model_config``
        and ``get_train_config``.
    overwrite_snapshot
        Forwarded to ``load_dlc_snapshot`` to choose the snapshot.
    allow_growth : bool
        Value for ``gpu_options.allow_growth`` of the TF session.

    Raises
    ------
    KeyError
        If one of the expected prediction-layer variables is not in the graph.
    """
    from deeplabcut.pose_estimation_tensorflow.dataset.factory import (
        create as create_dataset, )
    from deeplabcut.pose_estimation_tensorflow.dataset.pose_dataset import Batch
    from deeplabcut.pose_estimation_tensorflow.nnet.net_factory import pose_net
    from deeplabcut.pose_estimation_tensorflow.nnet import losses
    from deeplabcut.pose_estimation_tensorflow.train import (
        get_batch_spec,
        setup_preloading,
        start_preloading,
        get_optimizer,
        LearningRate,
    )
    from deepgraphpose.PoseDataLoader import DataLoader
    from deepgraphpose.utils_model import load_dlc_snapshot, get_train_config, \
        get_model_config

    # Resolve the model config and the snapshot to restore.
    data_info = DataLoader(task)
    cfg = get_model_config(task,
                           data_info.model_data_dir,
                           scorer=data_info.scorer,
                           date=date)
    dlc_cfg = get_train_config(cfg, shuffle)
    trainingsnapshot_name, trainingsnapshot, dlc_cfg = load_dlc_snapshot(
        dlc_cfg, overwrite_snapshot=overwrite_snapshot)

    # Build the training graph: preloading queue -> backbone -> heads.
    TF.reset_default_graph()
    dataset = create_dataset(dlc_cfg)
    batch_spec = get_batch_spec(dlc_cfg)
    batch, enqueue_op, placeholders = setup_preloading(batch_spec)
    pn = pose_net(dlc_cfg)
    inputs = batch[Batch.inputs]
    net, end_points = pn.extract_features(inputs)
    heads = pn.prediction_layers(net, end_points)

    # Losses; ``get_optimizer`` needs the total loss to build ``train_op``.
    multi_class_labels = batch[Batch.part_score_targets]
    part_score_weights = (batch[Batch.part_score_weights]
                          if dlc_cfg.weigh_part_predictions else 1.0)

    def add_part_loss(multi_class_labels, logits, part_score_weights):
        return tf.losses.sigmoid_cross_entropy(multi_class_labels, logits,
                                               part_score_weights)

    loss = {}
    loss['part_loss'] = add_part_loss(multi_class_labels, heads['part_pred'],
                                      part_score_weights)
    total_loss = loss['part_loss']

    if dlc_cfg.intermediate_supervision:
        loss['part_loss_interm'] = add_part_loss(multi_class_labels,
                                                 heads['part_loss_interm'],
                                                 part_score_weights)
        total_loss = total_loss + loss['part_loss_interm']

    if dlc_cfg.location_refinement:
        locref_pred = heads['locref']
        locref_targets = batch[Batch.locref_targets]
        locref_weights = batch[Batch.locref_mask]
        loss_func = (losses.huber_loss if dlc_cfg.locref_huber_loss else
                     tf.losses.mean_squared_error)
        loss['locref_loss'] = dlc_cfg.locref_loss_weight * loss_func(
            locref_targets, locref_pred, locref_weights)
        total_loss = total_loss + loss['locref_loss']

    loss['total_loss'] = total_loss

    for k, t in loss.items():
        TF.summary.scalar(k, t)
    TF.summary.merge_all()

    # Snapshot 0 restores only the resnet_v1 backbone; later snapshots
    # restore every variable.
    if trainingsnapshot == 0:
        variables_to_restore = slim.get_variables_to_restore(
            include=["resnet_v1"])
    else:
        variables_to_restore = slim.get_variables_to_restore()
    restorer = TF.train.Saver(variables_to_restore)

    config_TF = TF.ConfigProto()
    config_TF.gpu_options.allow_growth = allow_growth
    sess = TF.Session(config=config_TF)

    coord, thread = start_preloading(sess, enqueue_op, dataset, placeholders)
    summary_writer = TF.summary.FileWriter(dlc_cfg.log_dir, sess.graph)
    learning_rate, train_op = get_optimizer(total_loss, dlc_cfg)

    sess.run(TF.global_variables_initializer())
    sess.run(TF.local_variables_initializer())

    # Restore the snapshot weights from disk.
    restorer.restore(sess, dlc_cfg.init_weights)
    print('Restored variables from\n{}\n'.format(dlc_cfg.init_weights))

    lr_gen = LearningRate(dlc_cfg)

    dlc_params_outdir = (Path(dlc_cfg.init_weights).parent / 'dlc_params_mat' /
                         '{}'.format(trainingsnapshot_name))
    os.makedirs(dlc_params_outdir, exist_ok=True)
    print(dlc_params_outdir)

    # Look up the prediction-layer variables by name; a missing one raises
    # KeyError naming the variable.
    graph_vars = {v.name: v for v in tf.global_variables()}
    fetches = {
        'weight': graph_vars["pose/part_pred/block4/weights:0"],
        'bias': graph_vars["pose/part_pred/block4/biases:0"],
    }
    if dlc_cfg.location_refinement:
        fetches['weight_locref'] = graph_vars[
            "pose/locref_pred/block4/weights:0"]
        fetches['bias_locref'] = graph_vars["pose/locref_pred/block4/biases:0"]

    # NOTE(review): ``train_op`` runs in the same ``sess.run`` as the
    # parameter fetches, so TF does not guarantee whether the saved values are
    # from before or after this update step. Confirm this is intended.
    current_lr = lr_gen.get_lr(0)
    _, params_out = sess.run([train_op, fetches],
                             feed_dict={learning_rate: current_lr})

    ss = os.path.join(dlc_params_outdir, 'dlc_params.mat')
    sio.savemat(ss, params_out)
    print('\nStored output in\n{}\n'.format(str(ss)))

    # Shut down: writer, session, then the preloading thread.
    summary_writer.close()
    sess.close()
    coord.request_stop()
    coord.join([thread])
    return
Ejemplo n.º 18
0
def train(config_yaml,
          displayiters,
          saveiters,
          maxiters,
          max_to_keep=5,
          keepdeconvweights=True):
    """Train a DeepLabCut pose network described by ``config_yaml``.

    The working directory is switched to the folder containing
    ``config_yaml`` while training (``learning_stats.csv`` and logs go
    there) and is restored afterwards, even if training raises.

    Parameters
    ----------
    config_yaml : str or Path
        Path to the pose config file; relative paths are resolved against the
        caller's working directory.
    displayiters, saveiters : int or None
        Override ``display_iters`` / ``save_iters`` from the config
        (``None`` keeps the config value).
    maxiters : int or None
        Caps the last ``multi_step`` iteration from the config (it can only
        lower it); ``None`` keeps the config value.
    max_to_keep : int
        Number of snapshots kept by the ``Saver``.
    keepdeconvweights : bool
        When ``init_weights`` is a DLC snapshot, restore all variables rather
        than only the backbone.

    Raises
    ------
    ValueError
        If the backbone must be loaded from ImageNet weights and ``net_type``
        is neither a ResNet nor a MobileNet.
    """
    # Resolve before chdir so a relative path keeps pointing at the same file.
    config_yaml = os.path.abspath(str(config_yaml))
    start_path = os.getcwd()
    os.chdir(str(Path(config_yaml).parents[0])
             )  # switch to folder of config_yaml (for logging)
    try:
        setup_logging()

        cfg = load_config(config_yaml)
        if cfg.dataset_type in ('default', 'tensorpack', 'deterministic'):
            print(
                "Switching batchsize to 1, as default/tensorpack/deterministic loaders do not support batches >1. Use imgaug loader."
            )
            cfg['batch_size'] = 1  # in case this was edited for analysis.-

        dataset = create_dataset(cfg)
        batch_spec = get_batch_spec(cfg)
        batch, enqueue_op, placeholders = setup_preloading(batch_spec)
        losses = pose_net(cfg).train(batch)
        total_loss = losses['total_loss']

        for k, t in losses.items():
            TF.summary.scalar(k, t)
        merged_summaries = TF.summary.merge_all()

        if 'snapshot' in Path(cfg.init_weights).stem and keepdeconvweights:
            print("Loading already trained DLC with backbone:", cfg.net_type)
            variables_to_restore = slim.get_variables_to_restore()
        else:
            print("Loading ImageNet-pretrained", cfg.net_type)
            # loading backbone from ResNet, MobileNet etc.
            if 'resnet' in cfg.net_type:
                variables_to_restore = slim.get_variables_to_restore(
                    include=["resnet_v1"])
            elif 'mobilenet' in cfg.net_type:
                variables_to_restore = slim.get_variables_to_restore(
                    include=["MobilenetV2"])
            else:
                print("Wait for DLC 2.3.")
                # Without this, ``variables_to_restore`` would be unbound below.
                raise ValueError(
                    "Unsupported net_type {!r} for ImageNet-pretrained "
                    "initialization.".format(cfg.net_type))

        restorer = TF.train.Saver(variables_to_restore)
        saver = TF.train.Saver(
            max_to_keep=max_to_keep
        )  # selects how many snapshots are stored, see https://github.com/AlexEMG/DeepLabCut/issues/8#issuecomment-387404835

        # NOTE(review): ``config`` is not defined in this function; presumably
        # a module-level ConfigProto. Confirm it exists.
        sess = TF.Session(config=config)
        coord, thread = start_preloading(sess, enqueue_op, dataset,
                                         placeholders)
        train_writer = TF.summary.FileWriter(cfg.log_dir, sess.graph)
        learning_rate, train_op = get_optimizer(total_loss, cfg)

        sess.run(TF.global_variables_initializer())
        sess.run(TF.local_variables_initializer())

        # Restore variables from disk.
        restorer.restore(sess, cfg.init_weights)
        if maxiters is None:
            max_iter = int(cfg.multi_step[-1][1])
        else:
            max_iter = min(int(cfg.multi_step[-1][1]), int(maxiters))
            print("Max_iters overwritten as", max_iter)

        if displayiters is None:
            display_iters = max(1, int(cfg.display_iters))
        else:
            display_iters = max(1, int(displayiters))
            print("Display_iters overwritten as", display_iters)

        if saveiters is None:
            save_iters = max(1, int(cfg.save_iters))
        else:
            save_iters = max(1, int(saveiters))
            print("Save_iters overwritten as", save_iters)

        cum_loss = 0.0
        lr_gen = LearningRate(cfg)

        stats_path = Path(config_yaml).with_name('learning_stats.csv')
        with open(str(stats_path), 'w') as lrf:
            print("Training parameter:")
            print(cfg)
            print("Starting training....")
            for it in range(max_iter + 1):
                current_lr = lr_gen.get_lr(it)
                [_, loss_val, summary] = sess.run(
                    [train_op, total_loss, merged_summaries],
                    feed_dict={learning_rate: current_lr})
                cum_loss += loss_val
                train_writer.add_summary(summary, it)

                if it % display_iters == 0 and it > 0:
                    average_loss = cum_loss / display_iters
                    cum_loss = 0.0
                    logging.info("iteration: {} loss: {} lr: {}".format(
                        it, "{0:.4f}".format(average_loss), current_lr))
                    lrf.write("{}, {:.5f}, {}\n".format(
                        it, average_loss, current_lr))
                    lrf.flush()

                # Save snapshot
                if (it % save_iters == 0 and it != 0) or it == max_iter:
                    model_name = cfg.snapshot_prefix
                    saver.save(sess, model_name, global_step=it)

        sess.close()
        coord.request_stop()
        coord.join([thread])
    finally:
        # return to original path, even if training raised.
        os.chdir(str(start_path))
Ejemplo n.º 19
0
def train(
    config_yaml,
    displayiters,
    saveiters,
    maxiters,
    max_to_keep=5,
    keepdeconvweights=True,
    allow_growth=False,
):
    """Train a DeepLabCut pose network described by ``config_yaml``.

    The working directory is switched to the folder containing
    ``config_yaml`` while training (``learning_stats.csv`` and logs go
    there) and is restored afterwards, even if training raises.

    Parameters
    ----------
    config_yaml : str or Path
        Path to the pose config file; relative paths are resolved against the
        caller's working directory.
    displayiters, saveiters : int or None
        Override ``display_iters`` / ``save_iters`` from the config
        (``None`` keeps the config value).
    maxiters : int or None
        Caps the last ``multi_step`` iteration from the config (it can only
        lower it); ``None`` keeps the config value.
    max_to_keep : int
        Number of snapshots kept by the ``Saver``.
    keepdeconvweights : bool
        When ``init_weights`` is a DLC snapshot, restore all variables rather
        than only the backbone.
    allow_growth : bool
        If true, the TF session is created with
        ``gpu_options.allow_growth = True``.

    Raises
    ------
    ValueError
        If the backbone must be loaded from ImageNet weights and ``net_type``
        is not a ResNet, MobileNet or EfficientNet.
    """
    # Resolve before chdir so a relative path keeps pointing at the same file.
    config_yaml = os.path.abspath(str(config_yaml))
    start_path = os.getcwd()
    os.chdir(str(Path(config_yaml).parents[0])
             )  # switch to folder of config_yaml (for logging)
    try:
        setup_logging()

        cfg = load_config(config_yaml)
        net_type = cfg['net_type']
        if cfg['dataset_type'] in ("scalecrop", "tensorpack", "deterministic"):
            print(
                "Switching batchsize to 1, as tensorpack/scalecrop/deterministic loaders do not support batches >1. Use imgaug/default loader."
            )
            cfg["batch_size"] = 1  # in case this was edited for analysis.-

        dataset = create_dataset(cfg)
        batch_spec = get_batch_spec(cfg)
        batch, enqueue_op, placeholders = setup_preloading(batch_spec)

        losses = pose_net(cfg).train(batch)
        total_loss = losses["total_loss"]

        for k, t in losses.items():
            TF.summary.scalar(k, t)
        merged_summaries = TF.summary.merge_all()

        if "snapshot" in Path(cfg['init_weights']).stem and keepdeconvweights:
            print("Loading already trained DLC with backbone:", net_type)
            variables_to_restore = slim.get_variables_to_restore()
        else:
            print("Loading ImageNet-pretrained", net_type)
            # loading backbone from ResNet, MobileNet etc.
            if "resnet" in net_type:
                variables_to_restore = slim.get_variables_to_restore(
                    include=["resnet_v1"])
            elif "mobilenet" in net_type:
                variables_to_restore = slim.get_variables_to_restore(
                    include=["MobilenetV2"])
            elif "efficientnet" in net_type:
                variables_to_restore = slim.get_variables_to_restore(
                    include=["efficientnet"])
                # Map graph variables onto the checkpoint's EMA names.
                variables_to_restore = {
                    var.op.name.replace("efficientnet/", "") +
                    '/ExponentialMovingAverage': var
                    for var in variables_to_restore
                }
            else:
                print("Wait for DLC 2.3.")
                # Without this, ``variables_to_restore`` would be unbound below.
                raise ValueError(
                    "Unsupported net_type {!r} for ImageNet-pretrained "
                    "initialization.".format(net_type))

        restorer = TF.train.Saver(variables_to_restore)
        saver = TF.train.Saver(
            max_to_keep=max_to_keep
        )  # selects how many snapshots are stored, see https://github.com/AlexEMG/DeepLabCut/issues/8#issuecomment-387404835

        if allow_growth:
            config = TF.ConfigProto()
            config.gpu_options.allow_growth = True
            sess = TF.Session(config=config)
        else:
            sess = TF.Session()

        coord, thread = start_preloading(sess, enqueue_op, dataset,
                                         placeholders)
        train_writer = TF.summary.FileWriter(cfg['log_dir'], sess.graph)

        if cfg.get("freezeencoder", False):
            if 'efficientnet' in net_type:
                # Freezing is not available for EfficientNet: use the regular
                # optimizer, which also provides ``tstep`` for the loop below.
                print("Freezing ONLY supported MobileNet/ResNet currently!!")
                learning_rate, train_op, tstep = get_optimizer(total_loss, cfg)
            else:
                print("Freezing encoder...")
                learning_rate, _, train_op = get_optimizer_with_freeze(
                    total_loss, cfg)
        else:
            learning_rate, train_op, tstep = get_optimizer(total_loss, cfg)

        sess.run(TF.global_variables_initializer())
        sess.run(TF.local_variables_initializer())

        # Restore variables from disk.
        restorer.restore(sess, cfg['init_weights'])
        if maxiters is None:
            max_iter = int(cfg['multi_step'][-1][1])
        else:
            max_iter = min(int(cfg['multi_step'][-1][1]), int(maxiters))
            print("Max_iters overwritten as", max_iter)

        if displayiters is None:
            display_iters = max(1, int(cfg['display_iters']))
        else:
            display_iters = max(1, int(displayiters))
            print("Display_iters overwritten as", display_iters)

        if saveiters is None:
            save_iters = max(1, int(cfg['save_iters']))
        else:
            save_iters = max(1, int(saveiters))
            print("Save_iters overwritten as", save_iters)

        cum_loss = 0.0
        lr_gen = LearningRate(cfg)

        stats_path = Path(config_yaml).with_name("learning_stats.csv")
        with open(str(stats_path), "w") as lrf:
            print("Training parameter:")
            print(cfg)
            print("Starting training....")
            for it in range(max_iter + 1):
                if 'efficientnet' in net_type:
                    # EfficientNet computes its learning rate in-graph from
                    # the step counter.
                    feed = {tstep: it}
                    current_lr = sess.run(learning_rate, feed_dict=feed)
                else:
                    current_lr = lr_gen.get_lr(it)
                    feed = {learning_rate: current_lr}

                [_, loss_val, summary] = sess.run(
                    [train_op, total_loss, merged_summaries],
                    feed_dict=feed,
                )
                cum_loss += loss_val
                train_writer.add_summary(summary, it)

                if it % display_iters == 0 and it > 0:
                    average_loss = cum_loss / display_iters
                    cum_loss = 0.0
                    logging.info("iteration: {} loss: {} lr: {}".format(
                        it, "{0:.4f}".format(average_loss), current_lr))
                    lrf.write("{}, {:.5f}, {}\n".format(
                        it, average_loss, current_lr))
                    lrf.flush()

                # Save snapshot
                if (it % save_iters == 0 and it != 0) or it == max_iter:
                    model_name = cfg['snapshot_prefix']
                    saver.save(sess, model_name, global_step=it)

        sess.close()
        coord.request_stop()
        coord.join([thread])
    finally:
        # return to original path, even if training raised.
        os.chdir(str(start_path))
Ejemplo n.º 20
0
def train(config_yaml,
          displayiters,
          saveiters,
          maxiters,
          max_to_keep=5,
          projection_matrices=None,
          multiview_step=None,
          snapshot_index=None):
    """Train a (multiview) DeepLabCut pose network described by ``config_yaml``.

    The working directory is switched to the folder containing
    ``config_yaml`` while training (``learning_stats.csv`` and logs go
    there) and is restored afterwards, even if training raises. The batch
    size is always forced to 1.

    Parameters
    ----------
    config_yaml : str or Path
        Path to the pose config file; relative paths are resolved against the
        caller's working directory.
    displayiters, saveiters : int or None
        Override ``display_iters`` / ``save_iters`` from the config
        (``None`` keeps the config value).
    maxiters : int or None
        Caps the last ``multi_step`` iteration from the config (it can only
        lower it); ``None`` keeps the config value.
    max_to_keep : int
        Number of snapshots kept by the ``Saver``.
    projection_matrices, multiview_step
        Stored in the config as ``projection_matrices`` / ``multiview_step``.
        ``multiview_step == 2`` also disables scale jitter and switches to
        Adam with learning rate 1e-4.
    snapshot_index : int or None
        ``None`` restores only the ``resnet_v1`` variables from
        ``cfg.init_weights``. Otherwise all variables except those matching
        ``.*reweighting.*`` are restored from ``snapshot-<snapshot_index>``
        next to ``config_yaml``.
    """
    # Resolve before chdir so a relative path keeps pointing at the same file
    # (this also keeps the snapshot path built below correct).
    config_yaml = os.path.abspath(str(config_yaml))
    start_path = os.getcwd()
    os.chdir(str(Path(config_yaml).parents[0])
             )  # switch to folder of config_yaml (for logging)
    try:
        setup_logging()

        cfg = load_config(config_yaml)
        cfg['batch_size'] = 1  # in case this was edited for analysis.

        cfg['projection_matrices'] = projection_matrices
        cfg['multiview_step'] = multiview_step
        # At this step, jittering the image sizes won't help. If the sizes
        # were jittered we would have to undo it before projecting to 3D, so
        # keep the image size constant.
        if multiview_step == 2:
            cfg.global_scale = 1.0
            cfg.scale_jitter_lo = 1.0
            cfg.scale_jitter_up = 1.0
            # also found best results with this optimizer and lr
            print('switching to hardcoded Adam optimizer for this step')
            cfg.optimizer = 'adam'
            cfg.adam_lr = 0.0001

        dataset = create_dataset(cfg)
        batch_spec = get_batch_spec(cfg)
        batch, enqueue_op, placeholders = setup_preloading(batch_spec)
        losses = pose_net(cfg).train(batch)
        total_loss = losses['total_loss']

        for k, t in losses.items():
            tf.summary.scalar(k, t)
        merged_summaries = tf.summary.merge_all()

        if snapshot_index is None:
            variables_to_restore = slim.get_variables_to_restore(
                include=["resnet_v1"])
        else:
            variables_to_restore = slim.get_variables_to_restore(exclude=[
                op.name
                for op in tf.global_variables(scope='.*reweighting.*')
            ])
            cfg.init_weights = os.path.join(os.path.dirname(config_yaml),
                                            'snapshot-%d' % snapshot_index)

        restorer = tf.train.Saver(variables_to_restore)
        saver = tf.train.Saver(
            max_to_keep=max_to_keep
        )  # selects how many snapshots are stored, see https://github.com/AlexEMG/DeepLabCut/issues/8#issuecomment-387404835

        sess = tf.Session()
        coord, thread = start_preloading(sess, enqueue_op, dataset,
                                         placeholders)
        train_writer = tf.summary.FileWriter(cfg.log_dir, sess.graph)
        learning_rate, train_op = get_optimizer(total_loss, cfg)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        # Restore variables from disk.
        restorer.restore(sess, cfg.init_weights)
        if maxiters is None:
            max_iter = int(cfg.multi_step[-1][1])
        else:
            max_iter = min(int(cfg.multi_step[-1][1]), int(maxiters))
            print("Max_iters overwritten as", max_iter)

        if displayiters is None:
            display_iters = max(1, int(cfg.display_iters))
        else:
            display_iters = max(1, int(displayiters))
            print("Display_iters overwritten as", display_iters)

        if saveiters is None:
            save_iters = max(1, int(cfg.save_iters))
        else:
            save_iters = max(1, int(saveiters))
            print("Save_iters overwritten as", save_iters)

        cum_loss = 0.0
        lr_gen = LearningRate(cfg)

        stats_path = Path(config_yaml).with_name('learning_stats.csv')
        with open(str(stats_path), 'w') as lrf:
            print("Training parameter:")
            print(cfg)
            print("Starting training....")
            for it in range(max_iter + 1):
                current_lr = lr_gen.get_lr(it)
                [_, loss_val, summary] = sess.run(
                    [train_op, total_loss, merged_summaries],
                    feed_dict={learning_rate: current_lr})
                cum_loss += loss_val
                train_writer.add_summary(summary, it)

                if it % display_iters == 0 and it > 0:
                    average_loss = cum_loss / display_iters
                    cum_loss = 0.0
                    logging.info("iteration: {} loss: {} lr: {}".format(
                        it, "{0:.4f}".format(average_loss), current_lr))
                    lrf.write("{}, {:.5f}, {}\n".format(
                        it, average_loss, current_lr))
                    lrf.flush()

                # Save snapshot
                if (it % save_iters == 0 and it != 0) or it == max_iter:
                    model_name = cfg.snapshot_prefix
                    saver.save(sess, model_name, global_step=it)

        sess.close()
        coord.request_stop()
        coord.join([thread])
    finally:
        # return to original path, even if training raised.
        os.chdir(str(start_path))