Example No. 1
import pickle

import matplotlib.pyplot as plt


def test_classifier():
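    # Score kNN, SVM and MLP on the same embeddings, then plot the average
    # score (across the three evaluation modes) for each classifier.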
    with open('embeddings.pkl', 'rb') as f:
        embeddings = pickle.load(f)
    scores = []
    classifiers = ['knn', 'svm', 'mlp']
    for classifier in classifiers:
        success, data = run_evaluation(embeddings,
                                       selected_feature_areas=[2, 3, 4, 5, 6],
                                       classifier=classifier)
        if not success:
            print(data)
            return
        _, averages = data
        score = 0
        for model in ['across_areas', 'within_areas', 'individual_languages']:
            score += averages[model]['total']['score']
        score /= 3
        scores.append(score)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.bar(range(1, 4), scores)
    ax.set_xticks(range(1, 4))
    ax.set_xticklabels(classifiers)
    ax.set_xlabel('Classifier')
    ax.set_ylabel('Score')
    plt.show()
Example No. 2
import pickle

import matplotlib.pyplot as plt


def test_mlp_graph():
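    # For MLPs with 1, 2 and 3 hidden layers, sweep the layer size from 10 to
    # 100 and plot the average score of each configuration.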
    with open('featurevectors.pkl', 'rb') as f:
        embeddings = pickle.load(f)
    scores = [[], [], []]
    for i in range(0, 3):
        for size in range(10, 101, 10):
            arg = [size for _ in range(i + 1)]
            print(arg)
            success, data = run_evaluation(
                embeddings,
                selected_feature_areas=[2, 3, 4, 5, 6],
                classifier_args=arg,
                classifier='mlp')
            if not success:
                print(data)
                return
            _, averages = data
            score = 0
            for model in [
                    'across_areas', 'within_areas', 'individual_languages'
            ]:
                score += averages[model]['total']['score']
            score /= 3
            scores[i].append(score)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(range(10, 101, 10), scores[0], label='1 layer')
    ax.plot(range(10, 101, 10), scores[1], label='2 layers')
    ax.plot(range(10, 101, 10), scores[2], label='3 layers')
    ax.set_xticks(range(10, 101, 10))
    ax.legend()
    ax.set_xlabel('Size of layers')
    ax.set_ylabel('Score')
    plt.show()
Example No. 3
import pickle
import random as r


def run_tests():
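    # Stress-test run_evaluation with random language subsets, random feature
    # areas and random classifier arguments until a run raises an exception.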
    with open('embeddings.pkl', 'rb') as f:
        embeddings = pickle.load(f)

    langs = list(embeddings.keys())

    test_num = 1

    while True:
        print('Test number {}'.format(test_num))
        test_num += 1
        r.shuffle(langs)
        number_of_langs = r.randrange(0, len(embeddings))
        random_embeddings = {}
        for i, lang in enumerate(langs):
            random_embeddings[lang] = embeddings[lang]
            if i > number_of_langs:
                break

        feature_area_chance = r.random()
        feature_areas = [
            val for val in range(0, 13) if r.random() > feature_area_chance
        ]

        classifier = ['knn', 'svm', 'mlp'][r.randrange(0, 3)]
        if classifier == 'knn':
            classifier_arg = [r.randrange(1, 50)]
        else:
            classifier_arg = [
                r.randrange(1, 200) for _ in range(r.randrange(1, 10))
            ]
        try:
            run_evaluation(random_embeddings, True, True, classifier_arg,
                           classifier, feature_areas, None)
        except Exception as e:
            print(e)
            print('Embeddings:')
            print(random_embeddings)
            print('')
            print('Feature areas:')
            print(feature_areas)
            print('Classifier: {}'.format(classifier))
            print('Classifier arg: {}'.format(classifier_arg))
            return
Example No. 4
import pickle


def start_evaluation(args, session):
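    # Run the evaluation and pickle either the result or the exception into
    # the session's temp directory for later retrieval.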
    try:
        data = run_evaluation(args['embeddings'], False, True,
                              args['classifier_arg'], args['classifier'],
                              args['feature_groups'], args['features'],
                              'temp/' + session + '/')

        with open('temp/' + session + '/data.pkl', 'wb') as f:
            pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
    except Exception as e:
        with open('temp/' + session + '/data.pkl', 'wb') as f:
            pickle.dump((False, e), f, pickle.HIGHEST_PROTOCOL)
Example No. 5
def run_test(args):
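    # Embed the query and gallery sets, then evaluate retrieval either at
    # track level or at image level.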

    print("Starting evaluation with the following parameters:")
    args.toString()
    query_embs = embed.run_embedding(args, args.query_dataset)
    gallery_embs = embed.run_embedding(args, args.gallery_dataset)

    if args.image_to_track:
        [mAP, rank1] = evaluate_tracks.run_evaluation(args, query_embs,
                                                      gallery_embs)
    else:
        [mAP, rank1] = evaluate.run_evaluation(args, query_embs, gallery_embs)

    return [mAP, rank1]
Example No. 6
import os


def run_test(stored_args):
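    # Build the gallery/query list paths for the chosen dataset, pick the
    # matching excluder, and append the resulting mAP/rank-1 scores to a file.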
    filename = stored_args.filename
    stored_args.gallery_dataset = os.path.join(
        'datasets', stored_args.dataset + '_gallery.txt')
    stored_args.query_dataset = os.path.join(
        'datasets', stored_args.dataset + '_query.txt')
    if stored_args.dataset == 'vehicle':
        stored_args.excluder = 'diagonal'
    elif stored_args.dataset == 'cuhk03':
        stored_args.excluder = 'PVUD'
    else:
        stored_args.excluder = stored_args.dataset
    query_embs = embed.run_embedding(stored_args, stored_args.query_dataset)
    stored_args.filename = filename
    gallery_embs = embed.run_embedding(stored_args,
                                       stored_args.gallery_dataset)
    [mAP, rank1] = evaluate.run_evaluation(stored_args, query_embs,
                                           gallery_embs)
    print("mAP: " + str(mAP) + "; rank-1: " + str(rank1))
    output_file.write("checkpoint: " + "mAP: " + str("%0.2f" % (mAP * 100)) +
                      "; rank-1: " + str("%0.2f" % (rank1 * 100)) + "\n")
    output_file.close()
    return [mAP, rank1]
Example No. 7
import pickle

import matplotlib.pyplot as plt


def test_k_graph():
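    # Sweep the classifier argument from 1 to 50 (presumably k for the default
    # kNN classifier) and plot the average score for each value.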
    with open('featurevectors.pkl', 'rb') as f:
        embeddings = pickle.load(f)
    scores = []
    for i in range(1, 51):
        print(i)
        success, data = run_evaluation(embeddings,
                                       selected_feature_areas=[2, 3, 4, 5, 6],
                                       classifier_args=[i])
        if not success:
            print(data)
            return
        _, averages = data
        score = 0
        for model in ['across_areas', 'within_areas', 'individual_languages']:
            score += averages[model]['total']['score']
        score /= 3
        scores.append(score)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(range(1, 51), scores)
    ax.set_xlabel('k')
    ax.set_ylabel('Score')
    plt.show()
Example No. 8
def run():
    path_out = f'{ROOT_PATH}results/evaluation_collaborative.csv'
    run_evaluation(path_out, get_user_representation, get_comment_section_representation)
Example No. 9
def run():
    path_out = f'{ROOT_PATH}results/evaluation_random.csv'
    run_evaluation(path_out, None, None, random=True)
Example No. 10
def train(sess, sv, model, char_embedding_model, train_batches, dev_batches,
          num_train_examples, num_dev_examples, train_batcher,
          labels_str_id_map, labels_id_str_map, train_op, frontend_saver,
          vocab_str_id_map):
    '''
    Runs the training loop for the desired number of epochs.

    :param sess: TensorFlow session
    :param sv: TensorFlow supervisor managing the session
    :param model: the sequence model being trained
    :param char_embedding_model: character embedding model (or None if char_dim == 0)
    :param train_batches: batches used for evaluation on the training set
    :param dev_batches: batches used for evaluation on the dev set
    :param num_train_examples: number of training examples
    :param num_dev_examples: number of dev examples
    :param train_batcher: batcher that yields training batches
    :param labels_str_id_map: label string -> id map
    :param labels_id_str_map: label id -> string map
    :param train_op: op that performs one training step
    :param frontend_saver: saver used to serialize the best model
    :param vocab_str_id_map: vocabulary string -> id map
    :return: best dev score, number of iterations run, and average training
        speed (examples/second)
    '''

    # print out logging information after every fifth of the training set
    log_every = int(max(100, num_train_examples / 5))

    # start the training loop
    print("Training on %d examples)" % (num_train_examples))
    sys.stdout.flush()
    start_time = time.time()
    train_batcher._step = 1.0
    converged = False

    # keep track of how many training examples we've seen
    examples = 0
    log_every_running = log_every
    epoch_loss = 0.0
    num_lower = 0
    training_iteration = 0
    speed_num = 0.0
    speed_denom = 0.0
    total_iterations = 0
    max_lower = 6
    min_iters = 20
    best_score = 0
    update_frontend = True

    # start training loop, run until max epochs is exceeded or until loss has "converged"
    while (not sv.should_stop() and training_iteration < FLAGS.max_epochs
           and not (FLAGS.until_convergence and converged)):
        # if we've gone through the entire dataset, update the epoch count (epoch = iteration here)
        if examples >= num_train_examples:
            training_iteration += 1
            print("iteration %d" % training_iteration)
            sys.stdout.flush()

            if FLAGS.train_eval:
                evaluation.run_evaluation(
                    sess, model, char_embedding_model, train_batches,
                    labels_str_id_map, labels_id_str_map,
                    "TRAIN (iteration %d)" % training_iteration)
                print()
            weighted_f1, accuracy, preds, labels = evaluation.run_evaluation(
                sess, model, char_embedding_model, dev_batches,
                labels_str_id_map, labels_id_str_map,
                "TEST (iteration %d)" % training_iteration)
            print()
            print("Avg training speed: %f examples/second" %
                  (speed_num / speed_denom))

            # keep track of best weighted F1 score on the dev set, save the model associated with it
            if weighted_f1 > best_score:
                best_score = weighted_f1
                num_lower = 0
                if FLAGS.model_dir != '':
                    if update_frontend:
                        save_path = frontend_saver.save(
                            sess, FLAGS.model_dir + "-frontend.tf")
                        print("Serialized model: %s" % save_path)
            else:
                num_lower += 1

            # if we've done the minimum number of iterations, check to see if the best score has converged
            if num_lower > max_lower and training_iteration > min_iters:
                converged = True

            # update per-epoch variables
            log_every_running = log_every
            examples = 0
            epoch_loss = 0.0
            start_time = time.time()

        # print out logging information
        if examples > log_every_running:
            speed_denom += time.time() - start_time
            speed_num += examples
            evaluation.print_training_error(examples, start_time, epoch_loss,
                                            train_batcher._step)
            sys.stdout.flush()
            log_every_running += log_every

        # train iteration
        # if we're not through an epoch yet, do training as usual
        (label_batch, token_batch, shape_batch, char_batch, seq_len_batch,
         tok_lengths_batch, widths_batch, heights_batch, wh_ratios_batch,
         x_coords_batch, y_coords_batch, page_ids_batch, line_ids_batch,
         zone_ids_batch, place_scores_batch, department_scores_batch,
         university_scores_batch, person_scores_batch) = \
            train_batcher.next_batch() if FLAGS.memmap_train else sess.run(
                train_batcher.next_batch_op)

        # apply word dropout
        # create word dropout mask
        word_probs = np.random.random(token_batch.shape)
        drop_indices = np.where((word_probs > FLAGS.word_dropout))
        token_batch[drop_indices[0],
                    drop_indices[1]] = vocab_str_id_map["<OOV>"]

        # (optional: print batch shapes and max sequence length here when debugging)

        # reshape the geometric features to 3d tensors: (batch size) x (seq len) x 1
        widths_batch = np.expand_dims(widths_batch, axis=2)
        heights_batch = np.expand_dims(heights_batch, axis=2)
        wh_ratios_batch = np.expand_dims(wh_ratios_batch, axis=2)
        x_coords_batch = np.expand_dims(x_coords_batch, axis=2)
        y_coords_batch = np.expand_dims(y_coords_batch, axis=2)
        page_ids_batch = np.expand_dims(page_ids_batch, axis=2)
        line_ids_batch = np.expand_dims(line_ids_batch, axis=2)
        zone_ids_batch = np.expand_dims(zone_ids_batch, axis=2)

        # make mask out of seq lens
        batch_size, batch_seq_len = token_batch.shape

        # pad the character batch so each token occupies a fixed-width slot of max_char_len
        char_lens = np.sum(tok_lengths_batch, axis=1)
        max_char_len = np.max(tok_lengths_batch)
        padded_char_batch = np.zeros(
            (batch_size, max_char_len * batch_seq_len))
        for b in range(batch_size):
            char_indices = [
                item for sublist in [
                    range(i * max_char_len, i * max_char_len + d)
                    for i, d in enumerate(tok_lengths_batch[b])
                ] for item in sublist
            ]
            padded_char_batch[b, char_indices] = char_batch[b][:char_lens[b]]

        pad_width = 0

        # create masks for each example based on sequence lengths
        mask_batch = np.zeros((batch_size, batch_seq_len))
        for i, seq_lens in enumerate(seq_len_batch):
            start = pad_width
            for seq_len in seq_lens:
                mask_batch[i, start:start + seq_len] = 1
                start += seq_len  #+ (2 if FLAGS.start_end else 1) * pad_width
        examples += batch_size

        # MODEL FEEDS
        char_embedding_feeds = {} if FLAGS.char_dim == 0 else {
            char_embedding_model.input_chars: padded_char_batch,
            char_embedding_model.batch_size: batch_size,
            char_embedding_model.max_seq_len: batch_seq_len,
            char_embedding_model.token_lengths: tok_lengths_batch,
            char_embedding_model.max_tok_len: max_char_len,
            char_embedding_model.input_dropout_keep_prob:
            FLAGS.char_input_dropout
        }

        lstm_feed = {
            model.input_x1: token_batch,
            model.input_x2: shape_batch,
            model.input_y: label_batch,
            model.input_mask: mask_batch,
            model.sequence_lengths: seq_len_batch,
            model.max_seq_len: batch_seq_len,
            model.batch_size: batch_size,
            model.hidden_dropout_keep_prob: FLAGS.hidden_dropout,
            model.middle_dropout_keep_prob: FLAGS.middle_dropout,
            model.input_dropout_keep_prob: FLAGS.input_dropout,
            model.l2_penalty: FLAGS.l2,
            model.drop_penalty: FLAGS.regularize_drop_penalty
        }

        geometric_feats_feeds = {
            model.widths: widths_batch,
            model.heights: heights_batch,
            model.wh_ratios: wh_ratios_batch,
            model.x_coords: x_coords_batch,
            model.y_coords: y_coords_batch,
            model.pages: page_ids_batch,
            model.lines: line_ids_batch,
            model.zones: zone_ids_batch,
        }

        lexicon_feats_feeds = {
            model.place_scores: place_scores_batch,
            model.department_scores: department_scores_batch,
            model.university_scores: university_scores_batch,
            model.person_scores: person_scores_batch
        }

        lstm_feed.update(char_embedding_feeds)

        if FLAGS.use_geometric_feats:
            lstm_feed.update(geometric_feats_feeds)

        if FLAGS.use_lexicons:
            lstm_feed.update(lexicon_feats_feeds)

        # print("Running training op:")
        sys.stdout.flush()
        # tf.Print(model.flat_sequence_lengths, [model.flat_sequence_lengths])
        _, loss = sess.run([train_op, model.loss], feed_dict=lstm_feed)

        epoch_loss += loss
        train_batcher._step += 1

    return best_score, training_iteration, speed_num / speed_denom
Example No. 11
import pickle


def test():
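    # Smoke test: write an empty embeddings pickle, then run the evaluation
    # pipeline on that empty dict.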
    embeddings = {}
    with open('embeddings-test.pkl', 'wb') as f:
        pickle.dump(embeddings, f, pickle.HIGHEST_PROTOCOL)
    run_evaluation(embeddings, True, True, 7, 'knn', [3])
Example No. 12
                                losses_dict,
                                checkpoint_fp=checkpoint_fp)

                experiments = [(load_casi, 'casi'), (load_mimic, 'mimic'),
                               (load_columbia, 'columbia')]
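                # Evaluate the current checkpoint on each acronym dataset
                # without further training (epochs is temporarily set to 0).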
                for loader, dataset in experiments:
                    args.lm_type = 'bsg'
                    args.lm_experiment = args.experiment
                    args.ckpt = None
                    args.device = device_str
                    prev_epoch_ct = args.epochs
                    args.epochs = 0
                    block_print()
                    metrics = run_evaluation(args,
                                             BSGAcronymExpander,
                                             loader,
                                             restore_model,
                                             train_frac=0)
                    enable_print()
                    args.epochs = prev_epoch_ct
                    metrics['dataset'] = dataset
                    metrics['hours'] = duration_in_hours
                    metrics['examples'] = full_example_ct
                    metrics['epoch'] = epoch
                    metrics['lm_recon'] = losses_dict['losses']['recon']
                    metrics['lm_kl'] = losses_dict['losses']['kl']
                    row = [metrics[col] for col in metric_cols]
                    metrics_writer.writerow(row)
                    print(metric_cols)
                    print(row)
                    metrics_file.flush()
Example No. 13
def run():
    path_out = ROOT_PATH + '/results/evaluation_data_titles.csv'
    run_evaluation(path_out, get_user_representation, get_comment_section_representation, rep_transform)
Example No. 14
def run():
    path_out = ROOT_PATH + '/results/evaluation_bow.csv'
    run_evaluation(path_out, get_author_tfidf_representation, get_comment_section_representation)
Example No. 15
    params = parse_args()
    eval_bool_dict = get_evaluated_booleans()
    print(params)

    # Vary the number of particles
    task = params.config[params.config.find('-') + 1:-5]
    assert task in ['tracking', 'localization', 'global', 'tworoom']
    if task == 'tracking':
        iters['num_particles'] = [50, 100, 300]
    else:
        iters['num_particles'] = [100, 300, 600]

    iter_values = list(iters.values())

    for count, t in enumerate(itertools.product(*iter_values)):
        print(count, t)
        params.model = t[2]
        params.num_particles = t[1]
        params.testfiles = t[0]

        # Read from result file to resume evaluation
        filename = result_file(params)
        comb = get_combinations(filename)
        comb_bools = eval_bool_dict.get(comb, np.zeros(num_instances))
        run_evaluation(params, comb_bools)
Example No. 16
def train(args):
    """ Test to make sure project transform correctly maps points """

    N = args.n_frames
    model = RaftSLAM(args)
    model.cuda()
    model.train()

    if args.ckpt is not None:
        model.load_state_dict(torch.load(args.ckpt))

    db = dataset_factory(args.datasets, n_frames=N, fmin=16.0, fmax=96.0)
    train_loader = DataLoader(db, batch_size=args.batch, shuffle=True, num_workers=4)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, 
        args.lr, args.steps, pct_start=0.01, cycle_momentum=False)

    logger = Logger(args.name, scheduler)
    should_keep_training = True
    total_steps = 0

    while should_keep_training:
        for i_batch, item in enumerate(train_loader):
            optimizer.zero_grad()

            graph = OrderedDict()
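            # Connect each frame to all frames at most two time steps away.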
            for i in range(N):
                graph[i] = [j for j in range(N) if i!=j and abs(i-j) <= 2]
            
            images, poses, depths, intrinsics = [x.to('cuda') for x in item]
            
            # convert poses w2c -> c2w
            Ps = SE3(poses).inv()
            Gs = SE3.Identity(Ps.shape, device='cuda')

            images = normalize_images(images)
            Gs, residuals = model(Gs, images, depths, intrinsics, graph, num_steps=args.iters)

            geo_loss, geo_metrics = geodesic_loss(Ps, Gs, graph)
            res_loss, res_metrics = residual_loss(residuals)

            metrics = {}
            metrics.update(geo_metrics)
            metrics.update(res_metrics)

            loss = args.w1 * geo_loss + args.w2 * res_loss
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            scheduler.step()
            
            logger.push(metrics)
            total_steps += 1

            if total_steps % 10000 == 0:
                PATH = 'checkpoints/%s_%06d.pth' % (args.name, total_steps)
                torch.save(model.state_dict(), PATH)

                run_evaluation(PATH)

            if total_steps >= args.steps:
                should_keep_training = False
                break

    return model
Example No. 17
def run_train():
    '''
    Creates the TensorFlow graph and session and runs the training loop.
    '''

    # load preprocessed token, label, shape, char maps
    labels_str_id_map, labels_id_str_map, vocab_str_id_map, vocab_id_str_map, \
    shape_str_id_map, shape_id_str_map, char_str_id_map, char_id_str_map = load_intmaps(FLAGS.train_dir)

    # create intmaps for label types and bio (used later for evaluation, calculating F1 scores, etc.)
    # TODO right now these aren't used
    type_int_int_map, bilou_int_int_map, type_set, bilou_set = create_type_maps(
        labels_str_id_map)

    # load the embeddings
    embeddings = load_embeddings(vocab_str_id_map)

    labels_size = len(labels_str_id_map)
    char_domain_size = len(char_id_str_map)
    vocab_size = len(vocab_str_id_map)
    shape_domain_size = len(shape_id_str_map)

    # create TF graph
    with tf.Graph().as_default():
        # create batchers
        train_batcher = Batcher(
            FLAGS.train_dir,
            FLAGS.batch_size) if FLAGS.memmap_train else SeqBatcher(
                FLAGS.train_dir, FLAGS.batch_size)
        dev_batcher = SeqBatcher(FLAGS.dev_dir,
                                 FLAGS.batch_size,
                                 num_buckets=0,
                                 num_epochs=1)

        train_eval_batcher = SeqBatcher(FLAGS.train_dir,
                                        FLAGS.batch_size,
                                        num_buckets=0,
                                        num_epochs=1)

        # create character embedding model
        if FLAGS.char_dim > 0 and FLAGS.char_model == "lstm":
            print("creating and training character embeddings")
            char_embedding_model = BiLSTMChar(char_domain_size, FLAGS.char_dim,
                                              int(FLAGS.char_tok_dim / 2))
        # elif FLAGS.char_dim > 0 and FLAGS.char_model == "cnn":
        #     char_embedding_model = CNNChar(char_domain_size, FLAGS.char_dim, FLAGS.char_tok_dim, layers_map[0][1]['width'])
        else:
            char_embedding_model = None
        char_embeddings = char_embedding_model.outputs if char_embedding_model is not None else None

        # create BiLSTM model
        if FLAGS.model == 'bilstm':
            model = BiLSTM(
                num_classes=labels_size,
                vocab_size=vocab_size,
                shape_domain_size=shape_domain_size,
                char_domain_size=char_domain_size,
                char_size=FLAGS.char_dim,
                embedding_size=FLAGS.embed_dim,
                shape_size=FLAGS.shape_dim,
                lex_size=FLAGS.lex_dim,
                nonlinearity=FLAGS.nonlinearity,
                viterbi=False,  #viterbi=FLAGS.viterbi,
                hidden_dim=FLAGS.lstm_dim,
                char_embeddings=char_embeddings,
                embeddings=embeddings,
                use_geometric_feats=FLAGS.use_geometric_feats,
                use_lexicons=FLAGS.use_lexicons)
        # elif FLAGS.model == 'lstm': an analogous plain-LSTM model could be
        # constructed here with the same hyperparameters

        # Define Training procedure
        global_step = tf.Variable(0, name='global_step', trainable=False)

        optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.lr,
                                           beta1=FLAGS.beta1,
                                           beta2=FLAGS.beta2,
                                           epsilon=FLAGS.epsilon,
                                           name="optimizer")

        model_vars = [
            v for v in tf.all_variables() if 'context_agg' not in v.name
        ]

        train_op = optimizer.minimize(model.loss,
                                      global_step=global_step,
                                      var_list=model_vars)

        print("model vars: %d" % len(model_vars))
        print(map(lambda v: v.name, model_vars))
        print()
        sys.stdout.flush()
        get_trainable_params()

        tf.initialize_all_variables()

        frontend_opt_vars = [
            optimizer.get_slot(s, n) for n in optimizer.get_slot_names()
            for s in model_vars if optimizer.get_slot(s, n) is not None
        ]

        model_vars += frontend_opt_vars

        # load pretrained model if one is provided
        if FLAGS.load_dir:
            reader = tf.train.NewCheckpointReader(FLAGS.load_dir + ".tf")
            saved_var_map = reader.get_variable_to_shape_map()
            intersect_vars = [
                k for k in tf.all_variables()
                if k.name.split(':')[0] in saved_var_map
                and k.get_shape() == saved_var_map[k.name.split(':')[0]]
            ]
            leftovers = [
                k for k in tf.all_variables()
                if k.name.split(':')[0] not in saved_var_map
                or k.get_shape() != saved_var_map[k.name.split(':')[0]]
            ]
            print("WARNING: Loading pretrained frontend, but not loading: ",
                  map(lambda v: v.name, leftovers))
            frontend_loader = tf.train.Saver(var_list=intersect_vars)

        else:
            frontend_loader = tf.train.Saver(var_list=model_vars)

        frontend_saver = tf.train.Saver(var_list=model_vars)

        # create a supervisor
        sv = tf.train.Supervisor(
            logdir=FLAGS.model_dir if FLAGS.model_dir != '' else None,
            global_step=global_step,
            saver=None,
            save_model_secs=0,
            save_summaries_secs=0)

        training_start_time = time.time()

        # create session
        with sv.managed_session(
                FLAGS.master,
                config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            print("session created")
            sys.stdout.flush()

            # start queue runner threads
            threads = tf.train.start_queue_runners(sess=sess)

            # load model if applicable
            if FLAGS.load_dir != '':
                print("Deserializing model: " + FLAGS.load_dir + ".tf")
                frontend_loader.restore(sess, FLAGS.load_dir + ".tf")

            # load batches
            print()
            dev_batches, train_batches, num_dev_examples, num_train_examples \
                = load_batches(sess, train_batcher, train_eval_batcher, dev_batcher)

            # just run the evaluation if applicable
            if FLAGS.evaluate_only:
                if FLAGS.train_eval:
                    w_f1, accuracy, preds, labels = evaluation.run_evaluation(
                        sess, model, char_embedding_model, train_batches,
                        labels_str_id_map, labels_id_str_map, "TRAIN")
                print()
                w_f1, accuracy, preds, labels = evaluation.run_evaluation(
                    sess, model, char_embedding_model, dev_batches,
                    labels_str_id_map, labels_id_str_map, "TEST", True,
                    vocab_str_id_map, vocab_id_str_map)

                # write test set predictions to disk (for further analysis)
                print("writing predictions to disk:")
                np.save(FLAGS.model_dir + os.sep + "test_preds.npy", preds)
                np.save(FLAGS.model_dir + os.sep + "test_labels.npy", labels)

            # train a model
            else:
                best_score = 0
                total_iterations = 0

                # always train the front-end unless load dir was passed
                if FLAGS.load_dir == '' or (FLAGS.load_dir != ''
                                            and FLAGS.layers2 == ''):
                    best_score, training_iteration, train_speed = train(
                        sess, sv, model, char_embedding_model, train_batches,
                        dev_batches, num_train_examples, num_dev_examples,
                        train_batcher, labels_str_id_map, labels_id_str_map,
                        train_op, frontend_saver, vocab_str_id_map)
                    total_iterations += training_iteration
                    if FLAGS.model_dir:
                        print("Deserializing model: " + FLAGS.model_dir +
                              "-frontend.tf")
                        frontend_saver.restore(
                            sess, FLAGS.model_dir + "-frontend.tf")

        sv.coord.request_stop()
        sv.coord.join(threads)
        sess.close()

    total_time = time.time() - training_start_time
    if FLAGS.evaluate_only:
        print("Testing time: %d seconds" % (total_time))
    else:
        print(
            "Training time: %d minutes, %d iterations (%3.2f minutes/iteration)"
            % (total_time / 60, total_iterations, total_time /
               (60 * total_iterations)))
        print("Avg training speed: %f examples/second" % (train_speed))
        print("Best dev F1: %2.2f" % (best_score * 100))