def main():
    """Add a codebook for one model to an experiment's existing checkpoint.

    The workspace root is read from the ``AE_WORKSPACE_PATH`` environment
    variable. Command-line arguments:

    * ``experiment_name`` -- ``[group/]name`` of the experiment.
    * ``--at_step`` -- optional integer iteration; when omitted, ``NUM_ITER``
      from the ``[Training]`` section of the experiment config is used.
    * ``--model_path`` -- required; forwarded to
      ``add_new_codebook_to_graph`` and ``update_embedding``.

    The graph is built, the variables that already exist are restored from
    the checkpoint of the chosen iteration, the embedding is updated and the
    whole graph is saved under the joint checkpoint base name.

    Calls ``exit(-1)`` when the environment variable is unset or the config
    file does not exist.
    """
    workspace_path = os.environ.get('AE_WORKSPACE_PATH')

    if workspace_path is None:
        print('Please define a workspace path:\n')
        print('export AE_WORKSPACE_PATH=/path/to/workspace\n')
        exit(-1)

    parser = argparse.ArgumentParser()
    parser.add_argument("experiment_name")
    parser.add_argument('--at_step', default=None, type=int, required=False)
    parser.add_argument('--model_path', type=str, required=True)
    arguments = parser.parse_args()

    # "group/name" -> name is the last component, group the one before it.
    full_name = arguments.experiment_name.split('/')
    experiment_name = full_name.pop()
    experiment_group = full_name.pop() if full_name else ''
    at_step = arguments.at_step
    model_path = arguments.model_path

    cfg_file_path = u.get_config_file_path(workspace_path, experiment_name,
                                           experiment_group)
    log_dir = u.get_log_dir(workspace_path, experiment_name, experiment_group)

    ckpt_dir = u.get_checkpoint_dir(log_dir)
    dataset_path = u.get_dataset_path(workspace_path)

    if not os.path.exists(cfg_file_path):
        print('Could not find config file:\n')
        print('{}\n'.format(cfg_file_path))
        exit(-1)

    args = configparser.ConfigParser()
    args.read(cfg_file_path)
    iteration = args.getint('Training',
                            'NUM_ITER') if at_step is None else at_step

    # Prefer the joint checkpoint; fall back to the non-joint one when the
    # joint checkpoint for this iteration does not exist.
    checkpoint_file_basename = u.get_checkpoint_basefilename(log_dir,
                                                             latest=iteration,
                                                             joint=True)
    if not tf.train.checkpoint_exists(checkpoint_file_basename):
        checkpoint_file_basename = u.get_checkpoint_basefilename(
            log_dir, latest=iteration, joint=False)

    checkpoint_single_encoding = u.get_checkpoint_basefilename(
        log_dir, latest=iteration, model_path=model_path)
    target_checkpoint_file = u.get_checkpoint_basefilename(log_dir, joint=True)

    print(checkpoint_file_basename)
    print(target_checkpoint_file)
    print(ckpt_dir)
    print('#' * 20)

    with tf.variable_scope(experiment_name):
        dataset = factory.build_dataset(dataset_path, args)
        queue = factory.build_queue(dataset, args)
        encoder = factory.build_encoder(queue.x, args)
        codebook_multi = factory.build_codebook_multi(
            encoder, dataset, args, checkpoint_file_basename)
        # Created BEFORE the new codebook is added: a Saver built without an
        # explicit var_list covers the variables that exist at construction
        # time, so this one matches the contents of the old checkpoint.
        restore_saver = tf.train.Saver(save_relative_paths=True,
                                       max_to_keep=100)

        codebook_multi.add_new_codebook_to_graph(model_path)
        # Created AFTER the new codebook is added, so it also saves the
        # variables introduced by add_new_codebook_to_graph.
        saver = tf.train.Saver(save_relative_paths=True, max_to_keep=100)

    # NOTE(security): eval() executes whatever expression the config file
    # contains, so only run this with trusted config files.
    # ast.literal_eval may be a safer drop-in if MODEL_PATH is always a
    # literal list -- TODO confirm against the config format.
    batch_size = args.getint('Training', 'BATCH_SIZE') * len(
        eval(args.get('Paths', 'MODEL_PATH')))
    model = args.get('Dataset', 'MODEL')

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
    config = tf.ConfigProto(gpu_options=gpu_options)

    with tf.Session(config=config) as sess:
        print(ckpt_dir)
        print('#' * 20)

        # Initialise everything first (including the new codebook variables,
        # which are not in the old checkpoint), then overwrite the variables
        # known to restore_saver with their checkpointed values.
        sess.run(tf.global_variables_initializer())
        restore_saver.restore(sess, checkpoint_file_basename)

        print('#' * 20)

        # Best effort: reuse an embedding stored in the per-model checkpoint
        # if one can be read; otherwise hand None to update_embedding.
        try:
            loaded_emb = tf.train.load_variable(
                checkpoint_single_encoding,
                experiment_name + '/embedding_normalized')
            loaded_obj_bbs = tf.train.load_variable(
                checkpoint_single_encoding,
                experiment_name + '/embed_obj_bbs_var')
        except Exception as err:
            # Not a bare except: Ctrl-C / SystemExit must still propagate.
            print('Could not load a stored embedding from {}: {}'.format(
                checkpoint_single_encoding, err))
            loaded_emb = None
            loaded_obj_bbs = None

        if model == 'dsprites':
            codebook_multi.update_embedding_dsprites(sess, args)
        else:
            codebook_multi.update_embedding(sess,
                                            batch_size,
                                            model_path,
                                            loaded_emb=loaded_emb,
                                            loaded_obj_bbs=loaded_obj_bbs)

        print('Saving new checkoint ..')

        saver.save(sess, target_checkpoint_file, global_step=iteration)

        print('done')
# Example no. 2
# 0
def main():
    """Update the codebook embedding of a trained experiment and save it.

    The workspace root is read from the ``AE_WORKSPACE_PATH`` environment
    variable. Command-line arguments:

    * ``experiment_name`` -- ``[group/]name`` of the experiment.
    * ``--at_step`` -- optional; forwarded unchanged to
      ``factory.restore_checkpoint``.

    The graph (dataset, queue, encoder, decoder, autoencoder, codebook) is
    built, a checkpoint is restored from the experiment's checkpoint
    directory, the embedding is updated and a new checkpoint is written at
    the autoencoder's global step.

    Calls ``exit(-1)`` when the environment variable is unset or the config
    file does not exist.
    """
    tf.disable_eager_execution()

    workspace_path = os.environ.get('AE_WORKSPACE_PATH')

    if workspace_path is None:
        print('Please define a workspace path:\n')
        print('export AE_WORKSPACE_PATH=/path/to/workspace\n')
        exit(-1)

    parser = argparse.ArgumentParser()
    parser.add_argument("experiment_name")
    # NOTE(review): no type= is given, so a supplied value arrives as a
    # string -- confirm that factory.restore_checkpoint expects that.
    parser.add_argument('--at_step', default=None, required=False)
    arguments = parser.parse_args()

    # "group/name" -> name is the last component, group the one before it.
    full_name = arguments.experiment_name.split('/')
    experiment_name = full_name.pop()
    experiment_group = full_name.pop() if full_name else ''
    at_step = arguments.at_step

    cfg_file_path = u.get_config_file_path(workspace_path, experiment_name,
                                           experiment_group)
    log_dir = u.get_log_dir(workspace_path, experiment_name, experiment_group)
    checkpoint_file = u.get_checkpoint_basefilename(log_dir)
    ckpt_dir = u.get_checkpoint_dir(log_dir)
    dataset_path = u.get_dataset_path(workspace_path)

    print(checkpoint_file)
    print(ckpt_dir)
    print('#' * 20)

    if not os.path.exists(cfg_file_path):
        print('Could not find config file:\n')
        print('{}\n'.format(cfg_file_path))
        exit(-1)

    args = configparser.ConfigParser()
    args.read(cfg_file_path)

    with tf.variable_scope(experiment_name):
        dataset = factory.build_dataset(dataset_path, args)
        queue = factory.build_queue(dataset, args)
        encoder = factory.build_encoder(queue.x, args)
        decoder = factory.build_decoder(queue.y, encoder, args)
        ae = factory.build_ae(encoder, decoder, args)
        codebook = factory.build_codebook(encoder, dataset, args)
        saver = tf.train.Saver(save_relative_paths=True)

    batch_size = args.getint('Training', 'BATCH_SIZE')
    model = args.get('Dataset', 'MODEL')

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.7)
    config = tf.ConfigProto(gpu_options=gpu_options)

    with tf.Session(config=config) as sess:

        print(ckpt_dir)
        print('#' * 20)

        factory.restore_checkpoint(sess, saver, ckpt_dir, at_step=at_step)

        if model == 'dsprites':
            codebook.update_embedding_dsprites(sess, args)
        else:
            codebook.update_embedding(sess, batch_size)

        print('Saving new checkoint ..')

        saver.save(sess, checkpoint_file, global_step=ae.global_step)

        print('done')
# Example no. 3
# 0
def main():
    """Train the autoencoder of an experiment.

    The workspace root is read from the ``AE_WORKSPACE_PATH`` environment
    variable. Command-line arguments:

    * ``experiment_name`` -- ``[group/]name`` of the experiment.
    * ``-d`` -- debug mode: no training step is run; every iteration shows a
      sample batch next to its reconstruction and target in an OpenCV window
      and waits for a key (key code 27 leaves the loop).
    * ``-gen`` -- exit right after the training images have been prepared.

    The checkpoint, figure and dataset directories are created if missing and
    the config file is copied into the log directory. Training starts from
    the latest checkpoint in the checkpoint directory when there is one,
    otherwise from freshly initialised variables. Every ``SAVE_INTERVAL``
    iterations a checkpoint and a tiled image of inputs, reconstructions and
    targets are written. Ctrl-C (SIGINT) ends the loop after the current
    iteration instead of interrupting it.

    Calls ``exit(-1)`` when the environment variable is unset or the config
    file does not exist.
    """
    tf.disable_eager_execution()

    workspace_path = os.environ.get('AE_WORKSPACE_PATH')

    if workspace_path is None:
        print('Please define a workspace path:\n')
        print('export AE_WORKSPACE_PATH=/path/to/workspace\n')
        exit(-1)

    # One-element list so the nested handler can flip the flag in place.
    # (Previously a NumPy array with dtype=np.bool, an alias that recent
    # NumPy releases no longer provide.)
    gentle_stop = [False]

    def on_ctrl_c(signum, frame):
        gentle_stop[0] = True

    signal.signal(signal.SIGINT, on_ctrl_c)

    parser = argparse.ArgumentParser()
    parser.add_argument("experiment_name")
    parser.add_argument("-d", action='store_true', default=False)
    parser.add_argument("-gen", action='store_true', default=False)
    arguments = parser.parse_args()

    # "group/name" -> name is the last component, group the one before it.
    # The pops consume full_name, so it must not be used for display later.
    full_name = arguments.experiment_name.split('/')
    experiment_name = full_name.pop()
    experiment_group = full_name.pop() if full_name else ''

    debug_mode = arguments.d
    generate_data = arguments.gen

    cfg_file_path = u.get_config_file_path(workspace_path, experiment_name,
                                           experiment_group)
    log_dir = u.get_log_dir(workspace_path, experiment_name, experiment_group)
    checkpoint_file = u.get_checkpoint_basefilename(log_dir)
    ckpt_dir = u.get_checkpoint_dir(log_dir)
    train_fig_dir = u.get_train_fig_dir(log_dir)
    dataset_path = u.get_dataset_path(workspace_path)

    if not os.path.exists(cfg_file_path):
        print('Could not find config file:\n')
        print('{}\n'.format(cfg_file_path))
        exit(-1)

    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    if not os.path.exists(train_fig_dir):
        os.makedirs(train_fig_dir)
    if not os.path.exists(dataset_path):
        os.makedirs(dataset_path)

    args = configparser.ConfigParser()
    args.read(cfg_file_path)

    # Keep a copy of the config next to the logs of this run.
    shutil.copy2(cfg_file_path, log_dir)

    with tf.variable_scope(experiment_name):
        dataset = factory.build_dataset(dataset_path, args)
        queue = factory.build_queue(dataset, args)
        encoder = factory.build_encoder(queue.x, args, is_training=True)
        decoder = factory.build_decoder(queue.y, encoder, args,
                                        is_training=True)
        ae = factory.build_ae(encoder, decoder, args)
        codebook = factory.build_codebook(encoder, dataset, args)
        train_op = factory.build_train_op(ae, args)
        saver = tf.train.Saver(save_relative_paths=True)

    # Debug mode ignores NUM_ITER and uses a fixed upper bound.
    num_iter = args.getint('Training', 'NUM_ITER') if not debug_mode else 100000
    save_interval = args.getint('Training', 'SAVE_INTERVAL')
    model_type = args.get('Dataset', 'MODEL')

    if model_type == 'dsprites':
        dataset.get_sprite_training_images(args)
    else:
        dataset.get_training_images(dataset_path, args)
        dataset.load_bg_images(dataset_path)

    if generate_data:
        print('finished generating synthetic training data for ' + experiment_name)
        print('exiting...')
        exit()

    widgets = ['Training: ', progressbar.Percentage(),
               ' ', progressbar.Bar(),
               ' ', progressbar.Counter(), ' / %s' % num_iter,
               ' ', progressbar.ETA(), ' ']
    bar = progressbar.ProgressBar(maxval=num_iter, widgets=widgets)

    gpu_options = tf.GPUOptions(allow_growth=True,
                                per_process_gpu_memory_fraction=0.8)
    config = tf.ConfigProto(gpu_options=gpu_options)

    with tf.Session(config=config) as sess:

        # Resume from the latest checkpoint if there is one.
        chkpt = tf.train.get_checkpoint_state(ckpt_dir)
        if chkpt and chkpt.model_checkpoint_path:
            saver.restore(sess, chkpt.model_checkpoint_path)
        else:
            sess.run(tf.global_variables_initializer())

        merged_loss_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(ckpt_dir, sess.graph)

        if not debug_mode:
            print('Training with %s model' % args.get('Dataset', 'MODEL'),
                  os.path.basename(args.get('Paths', 'MODEL_PATH')))
            bar.start()

        queue.start(sess)
        # Start at the restored global step so a resumed run continues
        # counting where the previous one stopped.
        for i in range(ae.global_step.eval(), num_iter):
            if not debug_mode:
                sess.run(train_op)
                if i % 10 == 0:
                    loss = sess.run(merged_loss_summary)
                    summary_writer.add_summary(loss, i)

                bar.update(i)
                if (i + 1) % save_interval == 0:
                    saver.save(sess, checkpoint_file,
                               global_step=ae.global_step)

                    this_x, this_y = sess.run([queue.x, queue.y])
                    reconstr_train = sess.run(decoder.x,
                                              feed_dict={queue.x: this_x})
                    train_imgs = np.hstack((u.tiles(this_x, 4, 4),
                                            u.tiles(reconstr_train, 4, 4),
                                            u.tiles(this_y, 4, 4)))
                    cv2.imwrite(
                        os.path.join(train_fig_dir,
                                     'training_images_%s.png' % i),
                        train_imgs * 255)
            else:
                this_x, this_y = sess.run([queue.x, queue.y])
                reconstr_train = sess.run(decoder.x,
                                          feed_dict={queue.x: this_x})
                cv2.imshow('sample batch',
                           np.hstack((u.tiles(this_x, 3, 3),
                                      u.tiles(reconstr_train, 3, 3),
                                      u.tiles(this_y, 3, 3))))
                k = cv2.waitKey(0)
                if k == 27:
                    break

            if gentle_stop[0]:
                break

        queue.stop(sess)
        # Flush pending summaries and release the event file.
        summary_writer.close()
        if not debug_mode:
            bar.finish()
        if not gentle_stop[0] and not debug_mode:
            print('To create the embedding run:\n')
            # Use the argument as typed; full_name has been emptied by the
            # pop() calls above and would print as a list.
            print('ae_embed {}\n'.format(arguments.experiment_name))