Example #1
0
def train(
    config_yaml,
    displayiters,
    saveiters,
    maxiters,
    max_to_keep=5,
    keepdeconvweights=True,
    allow_growth=False,
):
    """Train a pose network described by the pose config ``config_yaml``.

    The working directory is switched to the folder containing
    ``config_yaml`` (so logs are written next to it). It is restored when
    training finishes, and also when training raises.

    Parameters
    ----------
    config_yaml : str or Path
        Path to the pose configuration YAML.
    displayiters : int or None
        Overrides ``cfg['display_iters']`` (clamped to >= 1) when not None.
    saveiters : int or None
        Overrides ``cfg['save_iters']`` (clamped to >= 1) when not None.
    maxiters : int or None
        Caps the number of iterations. It can only shorten the schedule given
        by the last entry of ``cfg['multi_step']``.
    max_to_keep : int
        How many snapshots the ``Saver`` keeps.
    keepdeconvweights : bool
        If ``init_weights`` is a DLC ``snapshot`` checkpoint, restore all
        variables instead of only the backbone.
    allow_growth : bool
        Set ``gpu_options.allow_growth`` on the TensorFlow session.

    Raises
    ------
    ValueError
        If ImageNet-pretrained weights are requested for a ``net_type``
        whose backbone variable scope is not known here.
    """
    start_path = os.getcwd()
    # switch to folder of config_yaml (for logging)
    os.chdir(str(Path(config_yaml).parents[0]))
    try:
        setup_logging()

        cfg = load_config(config_yaml)
        net_type = cfg['net_type']
        if cfg['dataset_type'] in ("scalecrop", "tensorpack", "deterministic"):
            print(
                "Switching batchsize to 1, as tensorpack/scalecrop/deterministic loaders do not support batches >1. Use imgaug/default loader."
            )
            cfg["batch_size"] = 1  # in case this was edited for analysis.

        dataset = create_dataset(cfg)
        batch_spec = get_batch_spec(cfg)
        batch, enqueue_op, placeholders = setup_preloading(batch_spec)

        losses = pose_net(cfg).train(batch)
        total_loss = losses["total_loss"]

        for k, t in losses.items():
            TF.summary.scalar(k, t)
        merged_summaries = TF.summary.merge_all()

        if "snapshot" in Path(cfg['init_weights']).stem and keepdeconvweights:
            print("Loading already trained DLC with backbone:", net_type)
            variables_to_restore = slim.get_variables_to_restore()
        else:
            print("Loading ImageNet-pretrained", net_type)
            # loading backbone from ResNet, MobileNet etc.
            if "resnet" in net_type:
                variables_to_restore = slim.get_variables_to_restore(
                    include=["resnet_v1"])
            elif "mobilenet" in net_type:
                variables_to_restore = slim.get_variables_to_restore(
                    include=["MobilenetV2"])
            elif "efficientnet" in net_type:
                variables_to_restore = slim.get_variables_to_restore(
                    include=["efficientnet"])
                # Key each graph variable by its name without the
                # "efficientnet/" prefix plus "/ExponentialMovingAverage", so
                # the restorer loads the checkpoint's EMA copies.
                variables_to_restore = {
                    var.op.name.replace("efficientnet/", "") +
                    '/ExponentialMovingAverage': var
                    for var in variables_to_restore
                }
            else:
                print("Wait for DLC 2.3.")
                # Without a restore list the Saver below cannot be built;
                # fail here with a clear message instead of an unbound name.
                raise ValueError(
                    "No ImageNet-pretrained backbone scope known for "
                    "net_type {!r}".format(net_type))

        restorer = TF.train.Saver(variables_to_restore)
        # selects how many snapshots are stored, see https://github.com/AlexEMG/DeepLabCut/issues/8#issuecomment-387404835
        saver = TF.train.Saver(max_to_keep=max_to_keep)

        if allow_growth:
            # Use the same TF namespace as Session/Saver; a bare
            # ``tf.ConfigProto`` is not available at the top level of TF2.
            session_config = TF.ConfigProto()
            session_config.gpu_options.allow_growth = True
            sess = TF.Session(config=session_config)
        else:
            sess = TF.Session()

        coord, thread = start_preloading(sess, enqueue_op, dataset, placeholders)
        train_writer = TF.summary.FileWriter(cfg['log_dir'], sess.graph)

        if cfg.get("freezeencoder", False):
            if 'efficientnet' in net_type:
                print("Freezing ONLY supported MobileNet/ResNet currently!!")
                learning_rate, train_op, tstep = get_optimizer(total_loss, cfg)
            else:
                print("Freezing encoder...")
                learning_rate, _, train_op = get_optimizer_with_freeze(
                    total_loss, cfg)
        else:
            learning_rate, train_op, tstep = get_optimizer(total_loss, cfg)

        sess.run(TF.global_variables_initializer())
        sess.run(TF.local_variables_initializer())

        # Restore variables from disk.
        restorer.restore(sess, cfg['init_weights'])
        if maxiters is None:
            max_iter = int(cfg['multi_step'][-1][1])
        else:
            max_iter = min(int(cfg['multi_step'][-1][1]), int(maxiters))
            print("Max_iters overwritten as", max_iter)

        if displayiters is None:
            display_iters = max(1, int(cfg['display_iters']))
        else:
            display_iters = max(1, int(displayiters))
            print("Display_iters overwritten as", display_iters)

        if saveiters is None:
            save_iters = max(1, int(cfg['save_iters']))
        else:
            save_iters = max(1, int(saveiters))
            print("Save_iters overwritten as", save_iters)

        cum_loss = 0.0
        lr_gen = LearningRate(cfg)

        stats_path = Path(config_yaml).with_name("learning_stats.csv")
        with open(str(stats_path), "w") as lrf:
            print("Training parameter:")
            print(cfg)
            print("Starting training....")
            for it in range(max_iter + 1):
                if 'efficientnet' in net_type:
                    # For efficientnet, learning_rate is evaluated from the
                    # fed step tensor rather than from LearningRate.
                    feed = {tstep: it}
                    current_lr = sess.run(learning_rate, feed_dict=feed)
                else:
                    current_lr = lr_gen.get_lr(it)
                    feed = {learning_rate: current_lr}

                [_, loss_val, summary] = sess.run(
                    [train_op, total_loss, merged_summaries],
                    feed_dict=feed,
                )
                cum_loss += loss_val
                train_writer.add_summary(summary, it)

                if it % display_iters == 0 and it > 0:
                    average_loss = cum_loss / display_iters
                    cum_loss = 0.0
                    logging.info("iteration: {} loss: {} lr: {}".format(
                        it, "{0:.4f}".format(average_loss), current_lr))
                    lrf.write("{}, {:.5f}, {}\n".format(it, average_loss, current_lr))
                    lrf.flush()

                # Save snapshot
                if (it % save_iters == 0 and it != 0) or it == max_iter:
                    model_name = cfg['snapshot_prefix']
                    saver.save(sess, model_name, global_step=it)

        sess.close()
        coord.request_stop()
        coord.join([thread])
    finally:
        # return to original path.
        os.chdir(str(start_path))
Example #2
0
def train(config_yaml,
          displayiters,
          saveiters,
          maxiters,
          max_to_keep=5,
          keepdeconvweights=True,
          allow_growth=False):
    """Train a pose network described by the pose config ``config_yaml``.

    The working directory is switched to the folder containing
    ``config_yaml`` (for logging). It is restored afterwards, including when
    training raises.

    Parameters
    ----------
    config_yaml : str or Path
        Path to the pose configuration YAML.
    displayiters, saveiters : int or None
        Override ``cfg.display_iters`` / ``cfg.save_iters`` (clamped to >= 1)
        when not None.
    maxiters : int or None
        Caps the number of iterations; it can only shorten the schedule given
        by the last entry of ``cfg.multi_step``.
    max_to_keep : int
        How many snapshots the ``Saver`` keeps.
    keepdeconvweights : bool
        If ``init_weights`` is a DLC ``snapshot`` checkpoint, restore all
        variables instead of only the backbone.
    allow_growth : bool
        Set ``gpu_options.allow_growth`` on the TensorFlow session.

    Raises
    ------
    ValueError
        If ImageNet-pretrained weights are requested for a ``net_type``
        whose backbone variable scope is not known here.
    """
    start_path = os.getcwd()
    # switch to folder of config_yaml (for logging)
    os.chdir(str(Path(config_yaml).parents[0]))
    try:
        setup_logging()

        cfg = load_config(config_yaml)
        if cfg.dataset_type in ('default', 'tensorpack', 'deterministic'):
            print(
                "Switching batchsize to 1, as default/tensorpack/deterministic loaders do not support batches >1. Use imgaug loader."
            )

            cfg['batch_size'] = 1  # in case this was edited for analysis.

        dataset = create_dataset(cfg)
        batch_spec = get_batch_spec(cfg)
        batch, enqueue_op, placeholders = setup_preloading(batch_spec)
        losses = pose_net(cfg).train(batch)
        total_loss = losses['total_loss']

        for k, t in losses.items():
            TF.summary.scalar(k, t)
        merged_summaries = TF.summary.merge_all()

        if 'snapshot' in Path(cfg.init_weights).stem and keepdeconvweights:
            print("Loading already trained DLC with backbone:", cfg.net_type)
            variables_to_restore = slim.get_variables_to_restore()
        else:
            print("Loading ImageNet-pretrained", cfg.net_type)
            # loading backbone from ResNet, MobileNet etc.
            if 'resnet' in cfg.net_type:
                variables_to_restore = slim.get_variables_to_restore(
                    include=["resnet_v1"])
            elif 'mobilenet' in cfg.net_type:
                variables_to_restore = slim.get_variables_to_restore(
                    include=["MobilenetV2"])
            else:
                print("Wait for DLC 2.3.")
                # Without a restore list the Saver below cannot be built;
                # fail here with a clear message instead of an unbound name.
                raise ValueError(
                    "No ImageNet-pretrained backbone scope known for "
                    "net_type {!r}".format(cfg.net_type))

        restorer = TF.train.Saver(variables_to_restore)
        # selects how many snapshots are stored, see https://github.com/AlexEMG/DeepLabCut/issues/8#issuecomment-387404835
        saver = TF.train.Saver(max_to_keep=max_to_keep)

        # Build the session config here; no ``config`` object exists in this
        # function's scope to pass to TF.Session.
        if allow_growth:
            session_config = TF.ConfigProto()
            session_config.gpu_options.allow_growth = True
            sess = TF.Session(config=session_config)
        else:
            sess = TF.Session()
        coord, thread = start_preloading(sess, enqueue_op, dataset, placeholders)
        train_writer = TF.summary.FileWriter(cfg.log_dir, sess.graph)
        learning_rate, train_op = get_optimizer(total_loss, cfg)

        sess.run(TF.global_variables_initializer())
        sess.run(TF.local_variables_initializer())

        # Restore variables from disk.
        restorer.restore(sess, cfg.init_weights)
        if maxiters is None:
            max_iter = int(cfg.multi_step[-1][1])
        else:
            max_iter = min(int(cfg.multi_step[-1][1]), int(maxiters))
            print("Max_iters overwritten as", max_iter)

        if displayiters is None:
            display_iters = max(1, int(cfg.display_iters))
        else:
            display_iters = max(1, int(displayiters))
            print("Display_iters overwritten as", display_iters)

        if saveiters is None:
            save_iters = max(1, int(cfg.save_iters))
        else:
            save_iters = max(1, int(saveiters))
            print("Save_iters overwritten as", save_iters)

        cum_loss = 0.0
        lr_gen = LearningRate(cfg)

        stats_path = Path(config_yaml).with_name('learning_stats.csv')
        with open(str(stats_path), 'w') as lrf:
            print("Training parameter:")
            print(cfg)
            print("Starting training....")
            for it in range(max_iter + 1):
                current_lr = lr_gen.get_lr(it)
                [_, loss_val, summary] = sess.run(
                    [train_op, total_loss, merged_summaries],
                    feed_dict={learning_rate: current_lr})
                cum_loss += loss_val
                train_writer.add_summary(summary, it)

                if it % display_iters == 0 and it > 0:
                    average_loss = cum_loss / display_iters
                    cum_loss = 0.0
                    logging.info("iteration: {} loss: {} lr: {}".format(
                        it, "{0:.4f}".format(average_loss), current_lr))
                    lrf.write("{}, {:.5f}, {}\n".format(it, average_loss, current_lr))
                    lrf.flush()

                # Save snapshot
                if (it % save_iters == 0 and it != 0) or it == max_iter:
                    model_name = cfg.snapshot_prefix
                    saver.save(sess, model_name, global_step=it)

        sess.close()
        coord.request_stop()
        coord.join([thread])
    finally:
        # return to original path.
        os.chdir(str(start_path))
def train(config_yaml, displayiters, saveiters, maxiters, max_to_keep=5):
    """Pretrain an ``AutoEncoderNet`` on an ``UnsupDataset``.

    Only the ``resnet_v1`` variables are restored from ``cfg.init_weights``,
    unless ``init_weights == 'He'``, in which case nothing is restored. At
    every display step a 2x2 figure (input, clipped output, target, mask of
    the first batch element) is written to ``imgs/`` next to the config. The
    working directory is switched to the folder containing ``config_yaml``
    and restored afterwards, including when training raises.

    Parameters
    ----------
    config_yaml : str or Path
        Path to the configuration YAML.
    displayiters, saveiters : int or None
        Override ``cfg.display_iters`` / ``cfg.save_iters`` (clamped to >= 1)
        when not None.
    maxiters : int or None
        Caps the number of iterations; it can only shorten the schedule given
        by the last entry of ``cfg.multi_step``.
    max_to_keep : int
        How many snapshots the ``Saver`` keeps.
    """
    start_path = os.getcwd()
    # switch to folder of config_yaml (for logging)
    os.chdir(str(Path(config_yaml).parents[0]))
    try:
        setup_logging()

        cfg = load_config(config_yaml)
        cfg['batch_size'] = 1  # in case this was edited for analysis.
        dataset = UnsupDataset(cfg)
        batch_spec = get_batch_spec(cfg)
        batch, enqueue_op, placeholders = setup_preloading(batch_spec)
        auto = AutoEncoderNet(cfg)
        losses = auto.train(batch)
        total_loss = losses['total_loss']

        for k, t in losses.items():
            tf.summary.scalar(k, t)
        merged_summaries = tf.summary.merge_all()

        variables_to_restore = slim.get_variables_to_restore(include=["resnet_v1"])
        restorer = tf.train.Saver(variables_to_restore)
        # selects how many snapshots are stored, see https://github.com/AlexEMG/DeepLabCut/issues/8#issuecomment-387404835
        saver = tf.train.Saver(max_to_keep=max_to_keep)

        sess = tf.Session()
        coord, thread = start_preloading(sess, enqueue_op, dataset, placeholders)
        train_writer = tf.summary.FileWriter(cfg.log_dir, sess.graph)
        learning_rate, train_op = get_optimizer(total_loss, cfg)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        # Restore variables from disk.
        if cfg.init_weights == 'He':
            # Default in ResNet
            print("Random weight initalization using He.")
        else:
            print("Pretrained weight initalization.")
            restorer.restore(sess, cfg.init_weights)
        if maxiters is None:
            max_iter = int(cfg.multi_step[-1][1])
        else:
            max_iter = min(int(cfg.multi_step[-1][1]), int(maxiters))
            print("Max_iters overwritten as", max_iter)

        if displayiters is None:
            display_iters = max(1, int(cfg.display_iters))
        else:
            display_iters = max(1, int(displayiters))
            print("Display_iters overwritten as", display_iters)

        if saveiters is None:
            save_iters = max(1, int(cfg.save_iters))
        else:
            save_iters = max(1, int(saveiters))
            print("Save_iters overwritten as", save_iters)

        cum_loss = 0.0
        lr_gen = LearningRate(cfg)

        stats_path = Path(config_yaml).with_name('learning_stats.csv')
        imgs_path = Path(config_yaml).parents[0] / 'imgs'
        # exist_ok: re-running training in the same folder must not raise
        # FileExistsError.
        imgs_path.mkdir(parents=True, exist_ok=True)

        with open(str(stats_path), 'w') as lrf:
            print("Training parameter:")
            print(cfg)
            print("Starting training....")
            import matplotlib.pyplot as plt
            for it in range(max_iter + 1):
                current_lr = lr_gen.get_lr(it)
                [_, loss_val, summary, _inp, _outp, _targ, _mask] = sess.run(
                    [
                        train_op, total_loss, merged_summaries, auto.input,
                        auto.output, auto.target, auto.mask
                    ],
                    feed_dict={learning_rate: current_lr})
                cum_loss += loss_val
                train_writer.add_summary(summary, it)

                # NOTE: this also fires at it == 0, when cum_loss holds a single
                # step's loss, so the first logged average is scaled down by
                # display_iters.
                if it % display_iters == 0:  # and it > 0:
                    average_loss = cum_loss / display_iters
                    cum_loss = 0.0
                    logging.info("iteration: {} loss: {} lr: {}".format(
                        it, "{0:.4f}".format(average_loss), current_lr))
                    lrf.write("{}, {:.5f}, {}\n".format(it, average_loss, current_lr))
                    lrf.flush()
                    fig, axs = plt.subplots(2, 2)
                    # Input divided by 255 for imshow; presumably pixel values
                    # in [0, 255] -- TODO confirm against UnsupDataset.
                    axs[0][0].imshow(_inp[0, :, :, :] / 255)
                    axs[0][1].imshow(np.clip(_outp[0, :, :, :], 0, 1))
                    axs[1][0].imshow(_targ[0, :, :, :])
                    axs[1][1].imshow(_mask[0, :, :, :])
                    # Save into the directory created above rather than a
                    # cwd-relative path.
                    plt.savefig(str(imgs_path / ('pretrain_iter' + str(it) + '.png')),
                                bbox_inches='tight')
                    plt.close(fig)

                # Save snapshot
                if (it % save_iters == 0 and it != 0) or it == max_iter:
                    model_name = cfg.snapshot_prefix
                    saver.save(sess, model_name, global_step=it)

        sess.close()
        coord.request_stop()
        coord.join([thread])
    finally:
        # return to original path.
        os.chdir(str(start_path))
Example #4
0
def train(
    config_yaml,
    displayiters,
    saveiters,
    maxiters,
    max_to_keep=5,
    keepdeconvweights=True,
    allow_growth=False,
):
    """Train a multi-animal pose network described by ``config_yaml``.

    Besides the total loss, the scoremap, location-refinement and pairwise
    (limb/PAF) losses are accumulated and logged as per-window averages. The
    working directory is switched to the folder containing ``config_yaml``
    (for logging) and restored afterwards, including when training raises.

    Parameters
    ----------
    config_yaml : str or Path
        Path to the pose configuration YAML.
    displayiters, saveiters : int or None
        Override ``cfg["display_iters"]`` / ``cfg["save_iters"]`` (clamped
        to >= 1) when not None.
    maxiters : int or None
        Caps the number of iterations; it can only shorten the schedule given
        by the last entry of ``cfg["multi_step"]``.
    max_to_keep : int
        How many snapshots the ``Saver`` keeps.
    keepdeconvweights : bool
        If ``init_weights`` is a DLC ``snapshot`` checkpoint, restore all
        variables instead of only the backbone.
    allow_growth : bool
        Set ``gpu_options.allow_growth`` on the TensorFlow session.

    Raises
    ------
    ValueError
        If ImageNet-pretrained weights are requested for a ``net_type``
        whose backbone variable scope is not known here.
    """
    start_path = os.getcwd()
    # switch to folder of config_yaml (for logging)
    os.chdir(str(Path(config_yaml).parents[0]))
    try:
        setup_logging()

        cfg = load_config(config_yaml)
        if cfg["optimizer"] != "adam":
            print(
                "Setting batchsize to 1! Larger batchsize not supported for this loader:",
                cfg["dataset_type"],
            )
            cfg["batch_size"] = 1

        if (
            cfg["partaffinityfield_predict"] and "multi-animal" in cfg["dataset_type"]
        ):  # the PAF code currently just hijacks the pairwise net stuff (for the batch feeding via Batch.pairwise_targets: 5)
            print("Activating limb prediction...")
            cfg["pairwise_predict"] = True

        dataset = create_dataset(cfg)
        batch_spec = get_batch_spec(cfg)
        batch, enqueue_op, placeholders = setup_preloading(batch_spec)

        losses = pose_net(cfg).train(batch)
        total_loss = losses["total_loss"]

        for k, t in losses.items():
            TF.summary.scalar(k, t)
        merged_summaries = TF.summary.merge_all()
        net_type = cfg["net_type"]

        if "snapshot" in Path(cfg["init_weights"]).stem and keepdeconvweights:
            print("Loading already trained DLC with backbone:", net_type)
            variables_to_restore = slim.get_variables_to_restore()
        else:
            print("Loading ImageNet-pretrained", net_type)
            # loading backbone from ResNet, MobileNet etc.
            if "resnet" in net_type:
                variables_to_restore = slim.get_variables_to_restore(include=["resnet_v1"])
            elif "mobilenet" in net_type:
                variables_to_restore = slim.get_variables_to_restore(
                    include=["MobilenetV2"]
                )
            elif "efficientnet" in net_type:
                variables_to_restore = slim.get_variables_to_restore(
                    include=["efficientnet"]
                )
                # Key each graph variable by its name without the
                # "efficientnet/" prefix plus "/ExponentialMovingAverage", so
                # the restorer loads the checkpoint's EMA copies.
                variables_to_restore = {
                    var.op.name.replace("efficientnet/", "")
                    + "/ExponentialMovingAverage": var
                    for var in variables_to_restore
                }
            else:
                print("Wait for DLC 2.3.")
                # Without a restore list the Saver below cannot be built;
                # fail here with a clear message instead of an unbound name.
                raise ValueError(
                    "No ImageNet-pretrained backbone scope known for "
                    "net_type {!r}".format(net_type)
                )

        restorer = TF.train.Saver(variables_to_restore)
        # selects how many snapshots are stored, see https://github.com/AlexEMG/DeepLabCut/issues/8#issuecomment-387404835
        saver = TF.train.Saver(max_to_keep=max_to_keep)

        if allow_growth:
            # Use the same TF namespace as Session/Saver; a bare
            # ``tf.ConfigProto`` is not available at the top level of TF2.
            session_config = TF.ConfigProto()
            session_config.gpu_options.allow_growth = True
            sess = TF.Session(config=session_config)
        else:
            sess = TF.Session()

        coord, thread = start_preloading(sess, enqueue_op, dataset, placeholders)
        train_writer = TF.summary.FileWriter(cfg["log_dir"], sess.graph)
        learning_rate, train_op, tstep = get_optimizer(total_loss, cfg)

        sess.run(TF.global_variables_initializer())
        sess.run(TF.local_variables_initializer())

        # Restore variables from disk.
        restorer.restore(sess, cfg["init_weights"])
        if maxiters is None:
            max_iter = int(cfg["multi_step"][-1][1])
        else:
            max_iter = min(int(cfg["multi_step"][-1][1]), int(maxiters))
            print("Max_iters overwritten as", max_iter)

        if displayiters is None:
            display_iters = max(1, int(cfg["display_iters"]))
        else:
            display_iters = max(1, int(displayiters))
            print("Display_iters overwritten as", display_iters)

        if saveiters is None:
            save_iters = max(1, int(cfg["save_iters"]))
        else:
            save_iters = max(1, int(saveiters))
            print("Save_iters overwritten as", save_iters)

        cumloss, partloss, locrefloss, pwloss = 0.0, 0.0, 0.0, 0.0
        lr_gen = LearningRate(cfg)
        stats_path = Path(config_yaml).with_name("learning_stats.csv")
        with open(str(stats_path), "w") as lrf:
            print("Training parameters:")
            print(cfg)
            print("Starting multi-animal training....")
            for it in range(max_iter + 1):
                if "efficientnet" in net_type:
                    # For efficientnet, learning_rate is evaluated from the
                    # fed step tensor rather than from LearningRate.
                    feed = {tstep: it}
                    current_lr = sess.run(learning_rate, feed_dict=feed)
                else:
                    current_lr = lr_gen.get_lr(it)
                    feed = {learning_rate: current_lr}

                [_, alllosses, loss_val, summary] = sess.run(
                    [train_op, losses, total_loss, merged_summaries], feed_dict=feed
                )

                partloss += alllosses["part_loss"]  # scoremap loss
                if cfg["location_refinement"]:
                    locrefloss += alllosses["locref_loss"]
                if cfg["pairwise_predict"]:  # paf loss
                    pwloss += alllosses["pairwise_loss"]

                cumloss += loss_val
                train_writer.add_summary(summary, it)

                if it % display_iters == 0 and it > 0:
                    # Window averages, formatted once and shared by the log
                    # line and the stats file.
                    averages = [
                        "{0:.4f}".format(total / display_iters)
                        for total in (cumloss, partloss, locrefloss, pwloss)
                    ]
                    logging.info(
                        "iteration: {} loss: {} scmap loss: {} locref loss: {} limb loss: {} lr: {}".format(
                            it, *averages, current_lr
                        )
                    )

                    lrf.write(
                        "iteration: {}, loss: {}, scmap loss: {}, locref loss: {}, limb loss: {}, lr: {}\n".format(
                            it, *averages, current_lr
                        )
                    )

                    cumloss, partloss, locrefloss, pwloss = 0.0, 0.0, 0.0, 0.0
                    lrf.flush()

                # Save snapshot
                if (it % save_iters == 0 and it != 0) or it == max_iter:
                    model_name = cfg["snapshot_prefix"]
                    saver.save(sess, model_name, global_step=it)

        sess.close()
        coord.request_stop()
        coord.join([thread])
    finally:
        # return to original path.
        os.chdir(str(start_path))
Example #5
0
def train(config_yaml, displayiters, saveiters, maxiters, max_to_keep=5):
    """Train a pose network from ``config_yaml`` on the CPU.

    The session is created with ``device_count={'GPU': 0}``, so no GPU is
    used. Each display step logs the averaged loss, the learning rate, the
    elapsed wall-clock time (HH:MM:SS.ss) and the current time of day. The
    working directory is switched to the folder containing ``config_yaml``
    (for logging) and restored afterwards, including when training raises.

    Parameters
    ----------
    config_yaml : str or Path
        Path to the pose configuration YAML.
    displayiters, saveiters : int or None
        Override ``cfg.display_iters`` / ``cfg.save_iters`` (clamped to >= 1)
        when not None.
    maxiters : int or None
        Caps the number of iterations; it can only shorten the schedule given
        by the last entry of ``cfg.multi_step``.
    max_to_keep : int
        How many snapshots the ``Saver`` keeps.
    """
    start_path = os.getcwd()
    # switch to folder of config_yaml (for logging)
    os.chdir(str(Path(config_yaml).parents[0]))
    try:
        setup_logging()

        cfg = load_config(config_yaml)
        cfg['batch_size'] = 1  # in case this was edited for analysis.

        dataset = create_dataset(cfg)
        batch_spec = get_batch_spec(cfg)
        batch, enqueue_op, placeholders = setup_preloading(batch_spec)
        losses = pose_net(cfg).train(batch)
        total_loss = losses['total_loss']

        for k, t in losses.items():
            TF.summary.scalar(k, t)
        merged_summaries = TF.summary.merge_all()

        variables_to_restore = slim.get_variables_to_restore(include=["resnet_v1"])
        restorer = TF.train.Saver(variables_to_restore)
        # selects how many snapshots are stored, see https://github.com/AlexEMG/DeepLabCut/issues/8#issuecomment-387404835
        saver = TF.train.Saver(max_to_keep=max_to_keep)

        # GPU explicitly disabled for this session.
        sess = TF.Session(config=TF.ConfigProto(device_count={'GPU': 0}))
        coord, thread = start_preloading(sess, enqueue_op, dataset, placeholders)
        train_writer = TF.summary.FileWriter(cfg.log_dir, sess.graph)
        learning_rate, train_op = get_optimizer(total_loss, cfg)

        sess.run(TF.global_variables_initializer())
        sess.run(TF.local_variables_initializer())

        # Restore variables from disk.
        restorer.restore(sess, cfg.init_weights)
        if maxiters is None:
            max_iter = int(cfg.multi_step[-1][1])
        else:
            max_iter = min(int(cfg.multi_step[-1][1]), int(maxiters))
            print("\n\nMax_iters overwritten as", max_iter)

        if displayiters is None:
            display_iters = max(1, int(cfg.display_iters))
        else:
            display_iters = max(1, int(displayiters))
            print("Display_iters overwritten as", display_iters)

        if saveiters is None:
            save_iters = max(1, int(cfg.save_iters))
        else:
            save_iters = max(1, int(saveiters))
            print("Save_iters overwritten as", save_iters)

        cum_loss = 0.0
        lr_gen = LearningRate(cfg)

        stats_path = Path(config_yaml).with_name('learning_stats.csv')
        with open(str(stats_path), 'w') as lrf:
            print("\nTraining parameter:\n")
            pprint.pprint(cfg)
            print("\n\nStarting training....")
            start = time.time()
            print("\nStarting time of training:  {} \n".format(
                datetime.datetime.now()))
            for it in range(max_iter + 1):
                current_lr = lr_gen.get_lr(it)
                [_, loss_val, summary] = sess.run(
                    [train_op, total_loss, merged_summaries],
                    feed_dict={learning_rate: current_lr})
                cum_loss += loss_val
                train_writer.add_summary(summary, it)

                # NOTE: this also fires at it == 0, when cum_loss holds a single
                # step's loss.
                if it % display_iters == 0:
                    end = time.time()
                    hours, rem = divmod(end - start, 3600)
                    minutes, seconds = divmod(rem, 60)
                    average_loss = cum_loss / display_iters
                    cum_loss = 0.0
                    logging.info(
                        "iteration: {}/{},    loss:  {:.4f},    lr: {},  |   Elapsed Time:  {:0>2}:{:0>2}:{:05.2f},    Time:  {}"
                        .format(it, max_iter, average_loss, current_lr, int(hours),
                                int(minutes), seconds,
                                datetime.datetime.now().strftime("%H:%M")))
                    lrf.write("{}, {:.5f}, {}\n".format(it, average_loss, current_lr))
                    lrf.flush()

                # Save snapshot
                if (it % save_iters == 0 and it != 0) or it == max_iter:
                    model_name = cfg.snapshot_prefix
                    saver.save(sess, model_name, global_step=it)

        sess.close()
        coord.request_stop()
        coord.join([thread])
    finally:
        # return to original path.
        os.chdir(str(start_path))
Example #6
0
def train(config_yaml, displayiters, saveiters, max_to_keep=5, maxiters=None):
    """Train a pose network from ``config_yaml`` with a ResNet backbone.

    Only the ``resnet_v1`` variables are restored from ``cfg.init_weights``.
    The working directory is switched to the folder containing
    ``config_yaml`` (for logging) and restored afterwards, including when
    training raises.

    Parameters
    ----------
    config_yaml : str or Path
        Path to the pose configuration YAML.
    displayiters, saveiters : int or None
        Override ``cfg.display_iters`` / ``cfg.save_iters`` (clamped to >= 1)
        when not None.
    max_to_keep : int
        How many snapshots the ``Saver`` keeps.
    maxiters : int or None
        Optional cap on the number of iterations. It can only shorten the
        schedule given by the last entry of ``cfg.multi_step``. The default
        None keeps the full schedule.
    """
    start_path = os.getcwd()
    # switch to folder of config_yaml (for logging)
    os.chdir(str(Path(config_yaml).parents[0]))
    try:
        setup_logging()

        cfg = load_config(config_yaml)
        cfg['batch_size'] = 1  # in case this was edited for analysis.

        dataset = create_dataset(cfg)
        batch_spec = get_batch_spec(cfg)
        batch, enqueue_op, placeholders = setup_preloading(batch_spec)
        losses = pose_net(cfg).train(batch)
        total_loss = losses['total_loss']

        for k, t in losses.items():
            tf.summary.scalar(k, t)
        merged_summaries = tf.summary.merge_all()

        variables_to_restore = slim.get_variables_to_restore(include=["resnet_v1"])
        restorer = tf.train.Saver(variables_to_restore)
        # selects how many snapshots are stored, see https://github.com/AlexEMG/DeepLabCut/issues/8#issuecomment-387404835
        saver = tf.train.Saver(max_to_keep=max_to_keep)

        sess = tf.Session()
        coord, thread = start_preloading(sess, enqueue_op, dataset, placeholders)
        train_writer = tf.summary.FileWriter(cfg.log_dir, sess.graph)
        learning_rate, train_op = get_optimizer(total_loss, cfg)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        # Restore variables from disk.
        restorer.restore(sess, cfg.init_weights)

        if maxiters is None:
            max_iter = int(cfg.multi_step[-1][1])
        else:
            max_iter = min(int(cfg.multi_step[-1][1]), int(maxiters))
            print("Max_iters overwritten as", max_iter)

        if displayiters is None:
            display_iters = max(1, int(cfg.display_iters))
        else:
            display_iters = max(1, int(displayiters))
            print("Display_iters overwritten as", display_iters)

        if saveiters is None:
            save_iters = max(1, int(cfg.save_iters))
        else:
            save_iters = max(1, int(saveiters))
            print("Save_iters overwritten as", save_iters)

        cum_loss = 0.0
        lr_gen = LearningRate(cfg)

        stats_path = Path(config_yaml).with_name('learning_stats.csv')
        with open(str(stats_path), 'w') as lrf:
            print("Training parameter:")
            print(cfg)
            print("Starting training....")
            for it in range(max_iter + 1):
                current_lr = lr_gen.get_lr(it)
                [_, loss_val, summary] = sess.run(
                    [train_op, total_loss, merged_summaries],
                    feed_dict={learning_rate: current_lr})
                cum_loss += loss_val
                train_writer.add_summary(summary, it)

                if it % display_iters == 0 and it > 0:
                    average_loss = cum_loss / display_iters
                    cum_loss = 0.0
                    logging.info("iteration: {} loss: {} lr: {}".format(
                        it, "{0:.4f}".format(average_loss), current_lr))
                    lrf.write("{}, {:.5f}, {}\n".format(it, average_loss, current_lr))
                    lrf.flush()

                # Save snapshot
                if (it % save_iters == 0 and it != 0) or it == max_iter:
                    model_name = cfg.snapshot_prefix
                    saver.save(sess, model_name, global_step=it)

        sess.close()
        coord.request_stop()
        coord.join([thread])
    finally:
        # return to original path.
        os.chdir(str(start_path))
def train(cfg,
          pose_config_yaml,
          displayiters,
          saveiters,
          maxiters,
          max_to_keep=5,
          early_stopping_thresh=50):
    """Train a ``PoseNet`` with validation-based snapshotting and early stopping.

    At every save point (each ``save_iters`` iterations and at the last
    iteration), ``validator.validate(sess, it)`` is called. A snapshot is
    written only when the validation error improves on the best seen so far.
    Training stops once more than ``early_stopping_thresh`` consecutive
    validations fail to improve. The working directory is switched to the
    folder containing ``pose_config_yaml`` (for logging) and restored
    afterwards, including when training raises.

    Parameters
    ----------
    cfg
        Project config, passed through to ``Validator``.
    pose_config_yaml : str or Path
        Path to the pose configuration YAML.
    displayiters, saveiters : int or None
        Override ``pose_cfg.display_iters`` / ``pose_cfg.save_iters``
        (clamped to >= 1) when not None.
    maxiters : int or None
        Caps the number of iterations; it can only shorten the schedule given
        by the last entry of ``pose_cfg.multi_step``.
    max_to_keep : int
        How many snapshots the ``Saver`` keeps.
    early_stopping_thresh : int
        Number of consecutive non-improving validations tolerated before
        training is stopped (stopping happens on the next one).
    """
    start_path = os.getcwd()
    # switch to folder of config_yaml (for logging)
    os.chdir(str(Path(pose_config_yaml).parents[0]))
    try:
        setup_logging()

        pose_cfg = load_config(pose_config_yaml)
        pose_cfg['batch_size'] = 1  # in case this was edited for analysis.

        # TODO:: Cleanup (Setting up validation)
        validator = Validator(cfg, pose_cfg, pose_config_yaml)

        dataset = create_dataset(pose_cfg)
        batch_spec = get_batch_spec(pose_cfg)
        batch, enqueue_op, placeholders = setup_preloading(batch_spec)
        losses = (pose_net.PoseNet(pose_cfg)).train(batch)
        total_loss = losses['total_loss']

        for k, t in losses.items():
            tf.summary.scalar(k, t)
        merged_summaries = tf.summary.merge_all()

        variables_to_restore = slim.get_variables_to_restore(include=["resnet_v1"])
        restorer = tf.train.Saver(variables_to_restore)
        # selects how many snapshots are stored, see https://github.com/AlexEMG/DeepLabCut/issues/8#issuecomment-387404835
        saver = tf.train.Saver(max_to_keep=max_to_keep)

        sess = tf.Session()
        coord, thread = start_preloading(sess, enqueue_op, dataset, placeholders)
        train_writer = tf.summary.FileWriter(pose_cfg.log_dir, sess.graph)
        learning_rate, train_op = get_optimizer(total_loss, pose_cfg)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        # Restore variables from disk.
        if pose_cfg.init_weights == 'He':
            # Default in ResNet
            print("Random weight initalization using He.")
        else:
            print("Pretrained weight initalization.")
            restorer.restore(sess, pose_cfg.init_weights)
        if maxiters is None:
            max_iter = int(pose_cfg.multi_step[-1][1])
        else:
            max_iter = min(int(pose_cfg.multi_step[-1][1]), int(maxiters))
            print("Max_iters overwritten as", max_iter)

        if displayiters is None:
            display_iters = max(1, int(pose_cfg.display_iters))
        else:
            display_iters = max(1, int(displayiters))
            print("Display_iters overwritten as", display_iters)

        if saveiters is None:
            save_iters = max(1, int(pose_cfg.save_iters))
        else:
            save_iters = max(1, int(saveiters))
            print("Save_iters overwritten as", save_iters)

        cum_loss = 0.0
        lr_gen = LearningRate(pose_cfg)
        validerror_min = float('inf')
        # Count of consecutive validations without improvement.
        last_min = 0

        stats_path = Path(pose_config_yaml).with_name('learning_stats.csv')
        with open(str(stats_path), 'w') as lrf:
            print("Training parameter:")
            print(pose_cfg)
            print("Starting training....")
            for it in range(max_iter + 1):
                current_lr = lr_gen.get_lr(it)
                [_, loss_val, summary] = sess.run(
                    [train_op, total_loss, merged_summaries],
                    feed_dict={learning_rate: current_lr})
                cum_loss += loss_val
                train_writer.add_summary(summary, it)

                if it % display_iters == 0 and it > 0:
                    average_loss = cum_loss / display_iters
                    cum_loss = 0.0
                    logging.info("iteration: {} loss: {} lr: {}".format(
                        it, "{0:.4f}".format(average_loss), current_lr))
                    lrf.write("{}, {:.5f}, {}\n".format(it, average_loss, current_lr))
                    lrf.flush()

                # Save snapshot -- only when validation error improves, so the
                # final iteration is not saved unless it is the best so far.
                if (it % save_iters == 0 and it != 0) or it == max_iter:
                    model_name = pose_cfg.snapshot_prefix
                    print("Calculating validation performance...")
                    validerror = validator.validate(sess, it)
                    if validerror < validerror_min:
                        validerror_min = validerror
                        last_min = 0
                        saver.save(sess, model_name, global_step=it)
                    else:
                        last_min += 1
                        if last_min > early_stopping_thresh:
                            print(
                                "Early stopping because early_stopping_thresh has been exceeded."
                            )
                            break
        validator.close()
        sess.close()
        coord.request_stop()
        coord.join([thread])
    finally:
        # return to original path.
        os.chdir(str(start_path))
Example #8
0
def train(config_yaml,
          displayiters,
          saveiters,
          maxiters,
          max_to_keep=5,
          projection_matrices=None,
          multiview_step=None,
          snapshot_index=None):
    """Train the pose network described by ``config_yaml`` (multiview-capable variant).

    The working directory is switched to the folder containing ``config_yaml``
    for the duration of training (so logs are written there). It is always
    restored afterwards, even if training raises or is interrupted. The
    preloading thread, the TF session, the summary writer and the
    ``learning_stats.csv`` file are likewise always released.

    Args:
        config_yaml: Path to the pose config file. A relative path is resolved
            against the caller's working directory before the directory switch.
        displayiters: Logging interval in iterations. ``None`` uses
            ``cfg.display_iters``. The value is clamped to at least 1.
        saveiters: Snapshot interval in iterations. ``None`` uses
            ``cfg.save_iters``. The value is clamped to at least 1.
        maxiters: Upper bound on the number of iterations. ``None`` uses the
            last step of ``cfg.multi_step``. Otherwise the smaller of the two
            is used.
        max_to_keep: Number of snapshots retained by the ``tf.train.Saver``.
        projection_matrices: Stored in ``cfg['projection_matrices']``.
            Presumably consumed by the multiview dataset/network code; TODO
            confirm.
        multiview_step: Stored in ``cfg['multiview_step']``. When it equals 2,
            scale jitter is disabled and the optimizer is forced to Adam with
            learning rate 1e-4.
        snapshot_index: If ``None``, only ``resnet_v1`` variables are restored
            from ``cfg.init_weights``. Otherwise every variable except those
            in ``.*reweighting.*`` scopes is restored from
            ``snapshot-<snapshot_index>`` next to ``config_yaml``.
    """
    import contextlib

    # Resolve before changing directory: after os.chdir a relative
    # config_yaml would be looked up relative to its own folder and fail.
    config_yaml = os.path.abspath(config_yaml)
    start_path = os.getcwd()

    with contextlib.ExitStack() as cleanup:
        # Switch to the folder of config_yaml (for logging). The callback is
        # registered first, so it runs last and always restores the original cwd.
        os.chdir(str(Path(config_yaml).parents[0]))
        cleanup.callback(os.chdir, str(start_path))
        setup_logging()

        cfg = load_config(config_yaml)
        cfg['batch_size'] = 1  # in case this was edited for analysis.

        cfg['projection_matrices'] = projection_matrices
        cfg['multiview_step'] = multiview_step
        if multiview_step == 2:
            # At this step, jittering the image sizes won't help. Jittered
            # sizes would also have to be undone before projecting to 3D,
            # so keep the image size constant.
            cfg.global_scale = 1.0
            cfg.scale_jitter_lo = 1.0
            cfg.scale_jitter_up = 1.0
            # Best results were found with this optimizer and learning rate.
            print('switching to hardcoded Adam optimizer for this step')
            cfg.optimizer = 'adam'
            cfg.adam_lr = 0.0001

        dataset = create_dataset(cfg)
        batch_spec = get_batch_spec(cfg)
        batch, enqueue_op, placeholders = setup_preloading(batch_spec)
        losses = pose_net(cfg).train(batch)
        total_loss = losses['total_loss']

        for k, t in losses.items():
            tf.summary.scalar(k, t)
        merged_summaries = tf.summary.merge_all()

        if snapshot_index is None:
            # Fresh start: only the backbone weights come from init_weights.
            variables_to_restore = slim.get_variables_to_restore(
                include=["resnet_v1"])
        else:
            # Resume from a snapshot: restore everything except variables in
            # '.*reweighting.*' scopes.
            variables_to_restore = slim.get_variables_to_restore(exclude=[
                op.name for op in tf.global_variables(scope='.*reweighting.*')
            ])
            cfg.init_weights = os.path.join(os.path.dirname(config_yaml),
                                            'snapshot-%d' % snapshot_index)

        restorer = tf.train.Saver(variables_to_restore)
        # Selects how many snapshots are stored, see
        # https://github.com/AlexEMG/DeepLabCut/issues/8#issuecomment-387404835
        saver = tf.train.Saver(max_to_keep=max_to_keep)

        sess = tf.Session()
        coord, thread = start_preloading(sess, enqueue_op, dataset,
                                         placeholders)
        # ExitStack runs callbacks last-in-first-out. These therefore execute
        # as sess.close() -> coord.request_stop() -> coord.join([thread]).
        # Closing the session first cancels any enqueue the loader thread may
        # be blocked on.
        cleanup.callback(coord.join, [thread])
        cleanup.callback(coord.request_stop)
        cleanup.callback(sess.close)

        train_writer = tf.summary.FileWriter(cfg.log_dir, sess.graph)
        # Flush summaries still buffered when training ends.
        cleanup.callback(train_writer.close)
        learning_rate, train_op = get_optimizer(total_loss, cfg)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        # Restore variables from disk.
        restorer.restore(sess, cfg.init_weights)

        max_iter = int(cfg.multi_step[-1][1])
        if maxiters is not None:
            max_iter = min(max_iter, int(maxiters))
            print("Max_iters overwritten as", max_iter)

        if displayiters is None:
            display_iters = max(1, int(cfg.display_iters))
        else:
            display_iters = max(1, int(displayiters))
            print("Display_iters overwritten as", display_iters)

        if saveiters is None:
            save_iters = max(1, int(cfg.save_iters))
        else:
            save_iters = max(1, int(saveiters))
            print("Save_iters overwritten as", save_iters)

        cum_loss = 0.0
        lr_gen = LearningRate(cfg)

        stats_path = Path(config_yaml).with_name('learning_stats.csv')
        lrf = cleanup.enter_context(open(str(stats_path), 'w'))

        print("Training parameter:")
        print(cfg)
        print("Starting training....")
        for it in range(max_iter + 1):
            current_lr = lr_gen.get_lr(it)
            [_, loss_val,
             summary] = sess.run([train_op, total_loss, merged_summaries],
                                 feed_dict={learning_rate: current_lr})
            cum_loss += loss_val
            train_writer.add_summary(summary, it)

            if it % display_iters == 0 and it > 0:
                average_loss = cum_loss / display_iters
                cum_loss = 0.0
                logging.info("iteration: {} loss: {} lr: {}".format(
                    it, "{0:.4f}".format(average_loss), current_lr))
                lrf.write("{}, {:.5f}, {}\n".format(it, average_loss,
                                                     current_lr))
                lrf.flush()

            # Save snapshot (periodically, and always at the final iteration).
            if (it % save_iters == 0 and it != 0) or it == max_iter:
                model_name = cfg.snapshot_prefix
                saver.save(sess, model_name, global_step=it)
        # Leaving the ExitStack closes the stats file, closes the summary
        # writer, shuts down the session/loader thread and restores the cwd.