예제 #1
0
 def record_tvd(self, sess):
     """Compute the tvd/mvd statistics and publish them as summaries.

     Calls ``calc_tvd`` with ``sess``, ``self.gen`` and ``self.data``,
     forwards the returned ``(step, tvd, mvd)`` to ``self.log_tvd`` and
     writes one summary per statistic, tagged ``<data name>_tvd`` and
     ``<data name>_mvd``, to ``self.summary_writer`` at ``step``.
     """
     step, tvd, mvd = calc_tvd(sess, self.gen, self.data)
     self.log_tvd(step, tvd, mvd)

     # Build both summary protos first, then write them and flush once.
     summaries = [
         make_summary(self.data.name + suffix, value)
         for suffix, value in (('_tvd', tvd), ('_mvd', mvd))
     ]
     for summ in summaries:
         self.summary_writer.add_summary(summ, step)
     self.summary_writer.flush()
예제 #2
0
def start_summary():
    """Create and save a summary for every place in the lat/long table.

    For each key returned by ``utils.read_latlong()``, the place's data is
    loaded with ``utils.load_year``, summarised with ``utils.make_summary``
    and stored with ``utils.save_summary``.
    """
    # The returned path is not used here; the call itself is kept
    # (presumably it prepares the summary output -- confirm in utils).
    utils.initialize_summary()

    # Snapshot the keys up front, exactly as the original list() copy did.
    for place in list(utils.read_latlong().keys()):
        yearly_data = utils.load_year(place)
        utils.save_summary(utils.make_summary(yearly_data), place)
예제 #3
0
파일: train.py 프로젝트: statX/cinc17
def run_validation(model, data_loader, session, summarizer):
    """Evaluate ``model`` on ``data_loader.val`` and report the results.

    Runs ``model.probs`` and ``model.loss`` for every validation batch,
    writes "Dev Accuracy" and "Dev Loss" summaries at the model's current
    iteration, logs loss, accuracy and the ``utils.cinc_score`` value
    (reported as "Macro F1"), and returns the accuracy.
    """
    step = model.it.eval(session)

    predicted, truth, batch_losses = [], [], []
    for batch in data_loader.val:
        probs, batch_loss = session.run(
            [model.probs, model.loss], feed_dict=model.feed_dict(*batch))
        # The predicted class is the arg-max over axis 1 of the probabilities.
        predicted += np.argmax(probs, axis=1).tolist()
        truth.extend(batch[1])
        batch_losses.append(batch_loss)

    mean_loss = np.mean(batch_losses)
    accuracy = skm.accuracy_score(truth, predicted)
    macro_f1 = utils.cinc_score(truth, predicted)

    for tag, value in (("Dev Accuracy", accuracy), ("Dev Loss", mean_loss)):
        summarizer.add_summary(utils.make_summary(tag, float(value)),
                               global_step=step)

    logger.info("Validation: Loss {:.3f}, Acc {:.3f}, Macro F1 {:.3f}".format(
        mean_loss, accuracy, macro_f1))
    return accuracy
def main():
    """Train the I3D RGB model on a single device, validating every epoch.

    Builds the train/validation ``tf.data`` pipelines from the text files in
    ``TXT_DATA_DIR``, constructs the ``InceptionI3d`` graph plus a dense
    ``NUM_CLASS``-way output layer, restores the ``inception_i3d`` variables
    from ``CHECKPOINT_PATH`` and then trains for ``EPOCH_NUM`` epochs.
    Training statistics are written to TensorBoard and the log, checkpoints
    are saved every ``SAVE_FREQ`` steps once the running train accuracy
    reaches 0.8, and the validation set is evaluated after each epoch.
    """
    # manually shuffle the train txt file because tf.data.shuffle is soooo slow!
    shuffle_and_overwrite(os.path.join(TXT_DATA_DIR, TRAIN_FILE))
    # dataset loading using tf.data module
    with tf.device('/cpu:0'):
        train_dataset = tf.data.Dataset.from_tensor_slices(
            [os.path.join(TXT_DATA_DIR, TRAIN_FILE)])
        train_dataset = train_dataset.apply(
            tf.contrib.data.parallel_interleave(
                lambda x: tf.data.TextLineDataset(x).map(
                    lambda x: tf.string_split([x], delimiter=' ').values),
                cycle_length=NUM_THREADS,
                block_length=1))
        train_dataset = train_dataset.apply(
            tf.contrib.data.map_and_batch(lambda x: tuple(
                tf.py_func(get_data_func_train, [x, IMAGE_SIZE],
                           [tf.float32, tf.int64])),
                                          batch_size=BATCH_SIZE,
                                          num_parallel_batches=NUM_THREADS))
        # Dataset transformations return a new dataset; the result must be
        # kept, otherwise no prefetching takes place.
        train_dataset = train_dataset.prefetch(PREFETCH_BUFFER)

        val_dataset = tf.data.Dataset.from_tensor_slices(
            [os.path.join(TXT_DATA_DIR, VAL_FILE)])
        val_dataset = val_dataset.shuffle(VAL_LEN)
        val_dataset = val_dataset.apply(
            tf.contrib.data.parallel_interleave(
                lambda x: tf.data.TextLineDataset(x).map(
                    lambda x: tf.string_split([x], delimiter=' ').values),
                cycle_length=NUM_THREADS,
                block_length=1))
        val_dataset = val_dataset.apply(
            tf.contrib.data.map_and_batch(lambda x: tuple(
                tf.py_func(get_data_func_val, [x, IMAGE_SIZE],
                           [tf.float32, tf.int64])),
                                          batch_size=BATCH_SIZE,
                                          num_parallel_batches=NUM_THREADS))
        val_dataset = val_dataset.prefetch(PREFETCH_BUFFER)

        train_iterator = train_dataset.make_initializable_iterator()
        val_iterator = val_dataset.make_initializable_iterator()

        # A feedable string handle lets one graph read from either iterator.
        train_handle = train_iterator.string_handle()
        val_handle = val_iterator.string_handle()
        handle_flag = tf.placeholder(tf.string, [],
                                     name='iterator_handle_flag')
        dataset_iterator = tf.data.Iterator.from_string_handle(
            handle_flag, train_dataset.output_types,
            train_dataset.output_shapes)

        batch_vid, batch_label = dataset_iterator.get_next()
        # tf.py_func outputs carry no static shape, so set it explicitly.
        batch_vid.set_shape([None, None, IMAGE_SIZE, IMAGE_SIZE, 3])

    train_flag = tf.placeholder(dtype=tf.bool, name='train_flag')
    dropout_flag = tf.placeholder(dtype=tf.float32, name='dropout_flag')

    # define model here
    with tf.variable_scope('RGB'):
        model = i3d.InceptionI3d(num_classes=400,
                                 spatial_squeeze=True,
                                 final_endpoint='Logits')
        logits, _ = model(inputs=batch_vid,
                          is_training=train_flag,
                          dropout_keep_prob=dropout_flag)
        logits_dropout = tf.nn.dropout(logits, keep_prob=dropout_flag)
        out = tf.layers.dense(logits_dropout,
                              NUM_CLASS,
                              activation=None,
                              use_bias=True)

        is_in_top_K = tf.nn.in_top_k(predictions=out,
                                     targets=batch_label,
                                     k=TOP_K)

        # maintain a variable map to restore from the ckpt
        variable_map = {}
        for var in tf.global_variables():
            var_name_split = var.name.split('/')
            # Guard the index so a variable without a scope prefix cannot
            # raise IndexError.
            if (len(var_name_split) > 1
                    and var_name_split[1] == 'inception_i3d'
                    and 'dense' not in var_name_split[1]):
                # var.name ends with ':0'; strip it to get the ckpt key.
                variable_map[var.name[:-2]] = var
            if var_name_split[-1][:-2] == 'w' or var_name_split[
                    -1][:-2] == 'kernel':
                tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                                     tf.nn.l2_loss(var))

        # optional: print to check the variable names
        # pprint(variable_map)

        regularization_loss = tf.losses.get_regularization_loss(
            name='regularization_loss')  # sum of l2 loss
        loss_cross_entropy = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=batch_label, logits=out, name='cross_entropy'))
        total_loss = tf.add(loss_cross_entropy,
                            L2_PARAM * regularization_loss,
                            name='total_loss')
        tf.summary.scalar('batch_statistics/total_loss', total_loss)
        tf.summary.scalar('batch_statistics/cross_entropy_loss',
                          loss_cross_entropy)
        tf.summary.scalar('batch_statistics/l2_loss', regularization_loss)
        tf.summary.scalar('batch_statistics/loss_ratio',
                          regularization_loss / loss_cross_entropy)

        saver_to_restore = tf.train.Saver(var_list=variable_map, reshape=True)

        # Floor division: the value is used as a range() bound and as the
        # tqdm total, so it must be an int on Python 3 as well.
        batch_num = TRAIN_LEN // BATCH_SIZE

        # Kept in LOCAL_VARIABLES, so it is not part of tf.global_variables().
        global_step = tf.Variable(GLOBAL_STEP_INIT,
                                  trainable=False,
                                  collections=[tf.GraphKeys.LOCAL_VARIABLES])
        learning_rate = config_learning_rate(global_step, batch_num)
        tf.summary.scalar('learning_rate', learning_rate)

        # set dependencies for BN ops
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer = config_optimizer(OPTIMIZER, learning_rate, OPT_EPSILON)
            train_op = optimizer.minimize(total_loss, global_step=global_step)

        # NOTE: if you don't want to save the params of the optimizer into the checkpoint,
        # you can place this line before the `update_ops` line
        saver_to_save = tf.train.Saver(max_to_keep=40)

        with tf.Session() as sess:
            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer()
            ])
            train_handle_value, val_handle_value = sess.run(
                [train_handle, val_handle])
            sess.run(train_iterator.initializer)
            saver_to_restore.restore(sess, CHECKPOINT_PATH)
            merged = tf.summary.merge_all()
            train_writer = tf.summary.FileWriter(
                os.path.join(TENSORBOARD_LOG_DIR, 'train'), sess.graph)

            sys.stdout.write('\n----------- start to train -----------\n')

            # [number of top-K hits, summed cross-entropy] since last report
            intermediate_train_info = [0., 0.]
            # Defined up front so the checkpoint gate below never reads an
            # unbound name when a save step precedes the first report step.
            intermediate_train_acc = 0.
            for epoch in range(EPOCH_NUM):
                epoch_acc, epoch_loss = 0., 0.
                pbar = tqdm(total=batch_num,
                            desc='Epoch {}'.format(epoch),
                            unit=' batch (batch_size: {})'.format(BATCH_SIZE))
                for i in range(batch_num):
                    _, _loss_cross_entropy, _is_in_top_K, summary, _global_step, lr = sess.run(
                        [
                            train_op, loss_cross_entropy, is_in_top_K, merged,
                            global_step, learning_rate
                        ],
                        feed_dict={
                            train_flag: True,
                            dropout_flag: DROPOUT_KEEP_PRAM,
                            handle_flag: train_handle_value
                        })
                    train_writer.add_summary(summary, global_step=_global_step)

                    intermediate_train_info[0] += np.sum(_is_in_top_K)
                    intermediate_train_info[1] += _loss_cross_entropy
                    epoch_acc += np.sum(_is_in_top_K)
                    epoch_loss += _loss_cross_entropy

                    # intermediate evaluation for the training dataset
                    if _global_step % SHOW_TRAIN_INFO_FREQ == 0:
                        intermediate_train_acc = float(
                            intermediate_train_info[0]) / (
                                SHOW_TRAIN_INFO_FREQ * BATCH_SIZE)
                        intermediate_train_loss = intermediate_train_info[
                            1] / SHOW_TRAIN_INFO_FREQ

                        step_log_info = 'Epoch:{}, global_step:{}, step_train_acc:{:.4f}, step_train_loss:{:4f}, lr:{:.7g}'.format(
                            epoch, _global_step, intermediate_train_acc,
                            intermediate_train_loss, lr)
                        sys.stdout.write('\n' + step_log_info + '\n')
                        sys.stdout.flush()
                        logging.info(step_log_info)
                        train_writer.add_summary(make_summary(
                            'accumulated_statistics/train_acc',
                            intermediate_train_acc),
                                                 global_step=_global_step)
                        train_writer.add_summary(make_summary(
                            'accumulated_statistics/train_loss',
                            intermediate_train_loss),
                                                 global_step=_global_step)
                        intermediate_train_info = [0., 0.]

                    # save a checkpoint, gated on the most recently reported
                    # running train accuracy
                    if _global_step % SAVE_FREQ == 0:
                        if intermediate_train_acc >= 0.8:
                            saver_to_save.save(
                                sess,
                                SAVE_DIR + '/model_step_{}_lr_{:.7g}'.format(
                                    _global_step, lr))

                    pbar.update(1)
                pbar.close()

                # start to validate on the validation dataset
                sess.run(val_iterator.initializer)
                iter_num = int(np.ceil(float(VAL_LEN) / BATCH_SIZE))
                correct_cnt, loss_cnt = 0, 0
                pbar = tqdm(total=iter_num,
                            desc='EVAL train_epoch:{}'.format(epoch),
                            unit=' batch(batch_size={})'.format(BATCH_SIZE))
                for _ in range(iter_num):
                    _is_in_top_K, _loss_cross_entropy = sess.run(
                        [is_in_top_K, loss_cross_entropy],
                        feed_dict={
                            handle_flag: val_handle_value,
                            train_flag: False,
                            dropout_flag: 1.0
                        })
                    correct_cnt += np.sum(_is_in_top_K)
                    loss_cnt += _loss_cross_entropy
                    pbar.update(1)
                pbar.close()
                val_acc = float(correct_cnt) / VAL_LEN
                val_loss = float(loss_cnt) / iter_num

                log_info = '==>> Epoch:{}, global_step:{}, val_acc:{:.4f}, val_loss:{:4f}, lr:{:.7g}'.format(
                    epoch, _global_step, val_acc, val_loss, lr)
                logging.info(log_info)
                sys.stdout.write('\n' + log_info + '\n')
                sys.stdout.flush()

                # manually shuffle the data with python for better performance
                shuffle_and_overwrite(os.path.join(TXT_DATA_DIR, TRAIN_FILE))
                sess.run(train_iterator.initializer)

                epoch_acc = float(epoch_acc) / TRAIN_LEN
                epoch_loss = float(epoch_loss) / batch_num
                log_info = '==========Epoch:{}, whole_train_acc:{:.4f}, whole_train_loss:{:4f}, lr:{:.7g}=========='.format(
                    epoch, epoch_acc, epoch_loss, lr)
                logging.info(log_info)
                sys.stdout.write('\n' + log_info + '\n')
                sys.stdout.flush()

        train_writer.close()
예제 #5
0
  def eval_loop(self, last_global_step):
    """Run the evaluation loop once.

    Args:
      last_global_step: global step of the checkpoint handled by the
        previous call.

    Returns:
      The global step of the checkpoint that was evaluated, or
      ``last_global_step`` unchanged when there is no checkpoint or the
      latest one has the same global step (in which case the method
      sleeps for ``self.wait`` seconds first).
    """

    latest_checkpoint, global_step = self.get_checkpoint(
      last_global_step)
    logging.info("latest_checkpoint: {}".format(latest_checkpoint))

    # Nothing new to evaluate: wait and report the unchanged step.
    if latest_checkpoint is None or global_step == last_global_step:
      time.sleep(self.wait)
      return last_global_step

    with tf.Session(config=self.config) as sess:
      logging.info("Loading checkpoint for eval: {}".format(latest_checkpoint))

      # Restores from checkpoint
      self.saver.restore(sess, latest_checkpoint)
      sess.run(tf.local_variables_initializer())

      epoch = get_epoch(
                global_step,
                FLAGS.train_num_gpu,
                FLAGS.train_batch_size,
                self.reader.n_train_files)

      fetches = OrderedDict(
         loss_update_op=tf_get('loss_update_op'),
         acc_update_op=tf_get('acc_update_op'),
         loss=tf_get('loss'),
         accuracy=tf_get('accuracy'))

      # Results of the most recent successful batch; stays None when the
      # input is exhausted before a single batch could be evaluated.
      values = None

      while True:
        try:

          batch_start_time = time.time()
          values = sess.run(list(fetches.values()))
          values = dict(zip(fetches.keys(), values))
          seconds_per_batch = time.time() - batch_start_time
          examples_per_second = self.batch_size / seconds_per_batch

          message = MessageBuilder()
          message.add('epoch', epoch, format='.2f')
          message.add('step', global_step)
          message.add('accuracy', values['accuracy'], format='.5f')
          message.add('avg loss', values['loss'], format='.5f')
          message.add('imgs/sec', examples_per_second, format='.0f')
          logging.info(message.get_message())

        except tf.errors.OutOfRangeError:
          # The input pipeline is exhausted: the last fetched values are the
          # final ones for this checkpoint.

          if values is None:
            # No batch was evaluated, so there is nothing to report; avoid
            # reading the unbound results below.
            logging.warning(
              "No batches were evaluated for checkpoint: {}".format(
                latest_checkpoint))
            break

          if self.best_accuracy is None or self.best_accuracy < values['accuracy']:
            self.best_global_step = global_step
            self.best_accuracy = values['accuracy']

          make_summary("accuracy", values['accuracy'], self.summary_writer, global_step)
          make_summary("loss", values['loss'], self.summary_writer, global_step)
          make_summary("epoch", epoch, self.summary_writer, global_step)
          self.summary_writer.flush()

          message = MessageBuilder()
          message.add('final: epoch', epoch, format='.2f')
          message.add('step', global_step)
          message.add('accuracy', values['accuracy'], format='.5f')
          message.add('avg loss', values['loss'], format='.5f')
          logging.info(message.get_message())
          logging.info("Done with batched inference.")

          if self.stopped_at_n:
            self.counter += 1

          break

      return global_step
def train():
    """Build the TextCNN graph, train it and save/export the result.

    The graph, input queues and optimizer are configured from ``FLAGS``.
    The model is trained for ``FLAGS.num_epochs`` steps, with a validation
    pass every ``FLAGS.eval_every`` steps. Afterwards the model is
    optionally exported (``FLAGS.online_model == "yes"``) and a checkpoint
    is written to ``FLAGS.checkpoint_dir``. The queue-runner threads,
    summary writer and session are released even if a step fails.
    """
    with tf.Graph().as_default():
        #build model graph
        model = TextCNN(sentence_length=FLAGS.sentence_length,
                        vocab_size=FLAGS.word_size,
                        num_classes=FLAGS.num_class,
                        embedding_size=FLAGS.embedding_size,
                        filter_sizes=list(
                            map(int, FLAGS.filter_size.split(","))),
                        num_filters=FLAGS.num_filters,
                        l2_reg_lambda=0.1)
        #construct data loader graph
        init_words, init_embeddings = load_word_vector(
            file_name=FLAGS.word_embedding_path,
            vocab_size=FLAGS.word_size,
            embedding_size=FLAGS.embedding_size,
            field_delim=FLAGS.field_delim,
            words=model.vocab,
            embeddings=model.words_embedding)

        train_file_queue = tf.train.string_input_producer(
            [FLAGS.train_data_path], shuffle=True, name="train_load_queue")
        train_reader = tf.TextLineReader()
        train_words, train_labels = load_train_words_and_lables(
            reader=train_reader,
            file_queue=train_file_queue,
            max_sentence_len=FLAGS.sentence_length,
            batch_size=FLAGS.batch_size,
            num_classes=FLAGS.num_class,
            field_delim=FLAGS.field_delim,
            word_delim=FLAGS.word_delim)

        valid_words, valid_labels = load_valid_words_and_labels(
            file_name=FLAGS.predict_data_path,
            valid_size=FLAGS.eval_size,
            max_sentence_len=FLAGS.sentence_length,
            num_classes=FLAGS.num_class,
            field_delim=FLAGS.field_delim,
            word_delim=FLAGS.word_delim)

        train_logits = model.inference(train_words)
        train_loss = model.loss(logits=train_logits, input_y=train_labels)
        train_acc = model.accuracy(logits=train_logits, input_y=train_labels)

        valid_logits = model.inference(valid_words)
        valid_loss = model.loss(logits=valid_logits, input_y=valid_labels)
        valid_acc = model.accuracy(logits=valid_logits, input_y=valid_labels)

        #construct saver
        # NOTE(review): the saver is built before global_step and the
        # optimizer exist, so it only covers variables created above --
        # confirm this is intended.
        saver = tf.train.Saver(max_to_keep=FLAGS.num_checkpoints)

        #define training procedure, note:batchnorm
        global_step = tf.Variable(0, name="global_step", trainable=False)
        learning_rate = tf.train.exponential_decay(
            learning_rate=3e-4,
            global_step=global_step,
            decay_steps=FLAGS.decay_steps,
            decay_rate=FLAGS.decay_rate,
            staircase=True)
        optimizer = tf.train.AdamOptimizer(learning_rate)
        grads_and_vars = optimizer.compute_gradients(train_loss)
        train_op = optimizer.apply_gradients(grads_and_vars=grads_and_vars,
                                             global_step=global_step)
        #gradients, variables = zip(*optimizer.compute_gradients(train_loss))
        #gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
        #update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        #with tf.control_dependencies(update_ops):
        #    train_op = optimizer.apply_gradients(zip(gradients, variables), global_step=global_step)

        #train summary
        train_summary_op = make_summary(grads_and_vars, train_loss, train_acc)
        # session config
        #session_conf = tf.ConfigProto(allow_soft_placement=True,
        #                              log_device_placement=FLAGS.log_device_placement)
        session_conf = tf.ConfigProto()
        sess = tf.Session(config=session_conf)

        with sess.as_default():
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            train_summary_writer = None

            # Everything after the queue runners start is inside try/finally
            # so the threads are stopped even if setup or training raises.
            try:
                #init words and embeddings
                sess.run([init_words, init_embeddings])
                #initialize hash table for looking up word
                tf.tables_initializer().run()

                #feed dict
                param_feed_train = {
                    model.dropout_keep_prob: FLAGS.drop_keep_prob,
                    model.is_training: True
                }
                param_feed_valid = {
                    model.dropout_keep_prob: 1.0,
                    model.is_training: False
                }

                timestamp = str(int(time.time()))
                train_summary_dir = os.path.join(FLAGS.summary_dir,
                                                 "train_summaries", timestamp)
                train_summary_writer = tf.summary.FileWriter(
                    train_summary_dir, sess.graph)

                # Each iteration runs a single training step.
                for i in range(FLAGS.num_epochs):
                    _, _step, _summaries, _loss, _acc = sess.run(
                        [
                            train_op, global_step, train_summary_op,
                            train_loss, train_acc
                        ],
                        feed_dict=param_feed_train)
                    _time_str = datetime.datetime.now().isoformat()
                    print("{}: step {}, loss {:g}, acc {:g}".format(
                        _time_str, _step, _loss, _acc))
                    train_summary_writer.add_summary(_summaries, _step)

                    if (i + 1) % FLAGS.eval_every == 0:
                        _valid_loss, _valid_acc = sess.run(
                            [valid_loss, valid_acc],
                            feed_dict=param_feed_valid)
                        print("valid_loss {:g}, valid_accuracy {:g}".format(
                            _valid_loss, _valid_acc))

                try:
                    if FLAGS.online_model == "yes":
                        online_path = export_online(
                            sess=sess,
                            export_path=FLAGS.export_online_path,
                            count=FLAGS.num_epochs,
                            signature=model.signature())
                        print("Save online model {}\n".format(online_path))
                    saver.save(sess=sess,
                               save_path=FLAGS.checkpoint_dir,
                               global_step=FLAGS.num_epochs)
                except tf.errors.OutOfRangeError:
                    # Preserved from the original: an exhausted input queue
                    # during export/save is ignored.
                    pass

            finally:
                try:
                    coord.request_stop()
                    coord.join(threads)
                finally:
                    # Flush pending events and free the session's resources.
                    if train_summary_writer is not None:
                        train_summary_writer.close()
                    sess.close()
예제 #7
0
def main():
    """Train the model on ``N_GPU`` towers with averaged gradients.

    Builds one tower per GPU via ``get_model_loss``, averages the tower
    gradients, restores the ``inception_i3d`` variables from
    ``CHECKPOINT_PATH`` and trains for ``EPOCH_NUM`` epochs. Running train
    statistics go to TensorBoard and the log, a checkpoint is saved every
    ``SAVE_FREQ`` steps, and the train file is reshuffled after each epoch.
    """
    with tf.device('/cpu:0'):
        batch_x, batch_y, _dataset_init_op = get_input(NUM_CLASS)
        train_flag = tf.placeholder(dtype=tf.bool, name='train_flag')
        dropout_flag = tf.placeholder(dtype=tf.float32, name='dropout_flag')

        # Floor division: the value is used as a range() bound and as the
        # tqdm total, so it must be an int on Python 3 as well.
        batch_num = TRAIN_LEN // (BATCH_SIZE * N_GPU)

        # Kept in LOCAL_VARIABLES, so it is not part of tf.global_variables().
        global_step = tf.Variable(GLOBAL_STEP_INIT,
                                  trainable=False,
                                  collections=[tf.GraphKeys.LOCAL_VARIABLES])
        learning_rate = config_learning_rate(global_step, batch_num)
        tf.summary.scalar('learning_rate', learning_rate)

        optimizer = config_optimizer(OPTIMIZER, learning_rate, OPT_EPSILON)

        cross_entropy_list = []
        in_top_K_list = []

        # False for the first tower, True afterwards (passed through to
        # get_model_loss -- presumably variable reuse; confirm there).
        reuse_flag = False
        tower_grads = []
        for i in range(N_GPU):
            with tf.device('/gpu:{}'.format(i)):
                with tf.name_scope('GPU_{}'.format(i)) as scope:
                    current_loss, _current_cross_entropy, current_in_top_K = get_model_loss(
                        batch_x, batch_y, train_flag, dropout_flag, scope,
                        reuse_flag)
                    reuse_flag = True
                    grads = optimizer.compute_gradients(current_loss)
                    tower_grads.append(grads)
                    cross_entropy_list.append(_current_cross_entropy)
                    in_top_K_list.append(current_in_top_K)

                    # retain the bn ops only from the last tower
                    # as suggested by: https://github.com/tensorflow/models/blob/master/research/inception/inception/inception_train.py#L249
                    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                                   scope)

        grads = average_gradients(tower_grads)
        apply_gradient_op = optimizer.apply_gradients(grads,
                                                      global_step=global_step)

        with tf.control_dependencies(update_ops):
            train_op = tf.group(apply_gradient_op, name='train_op')

        # maintain a variable map to restore from the ckpt
        # NOTE(review): tf.global_variables() is read after apply_gradients,
        # so any variables the optimizer created are in the list too. Confirm
        # none of them match the inception_i3d filter, since every entry of
        # the map is looked up in the checkpoint on restore.
        variable_map = {}
        for var in tf.global_variables():
            var_name_split = var.name.split('/')
            # Guard the index: a variable whose name has no '/' would
            # otherwise raise IndexError.
            if (len(var_name_split) > 1
                    and var_name_split[1] == 'inception_i3d'
                    and 'dense' not in var.name):
                # var.name ends with ':0'; strip it to get the ckpt key.
                variable_map[var.name[:-2]] = var

        saver_to_restore = tf.train.Saver(var_list=variable_map, reshape=True)
        saver_to_save = tf.train.Saver(max_to_keep=50)

        # NOTE: optional: check variable names
        # for var in tf.global_variables():
        #     print(var.name)

        # TODO: may var moving average here?

        average_cross_entropy = tf.reduce_mean(cross_entropy_list)
        average_in_top_K = tf.reduce_sum(in_top_K_list)

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer()
            ])
            shuffle_and_overwrite(os.path.join(TXT_DATA_DIR, TRAIN_FILE))
            sess.run(_dataset_init_op)
            saver_to_restore.restore(sess, CHECKPOINT_PATH)
            merged = tf.summary.merge_all()
            train_writer = tf.summary.FileWriter(
                os.path.join(TENSORBOARD_LOG_DIR, 'train'), sess.graph)

            sys.stdout.write('\n----------- start to train -----------\n')
            sys.stdout.flush()

            # [number of top-K hits, summed cross-entropy] since last report
            intermediate_train_info = [0., 0.]
            # Defined up front so the checkpoint name below never reads an
            # unbound name when a save step precedes the first report step.
            intermediate_train_acc = 0.
            for epoch in range(EPOCH_NUM):

                pbar = tqdm(total=batch_num,
                            desc='Epoch {}'.format(epoch),
                            unit=' batch (batch_size: {} * {} GPUs)'.format(
                                BATCH_SIZE, N_GPU))
                epoch_acc, epoch_loss = 0., 0.
                for i in range(batch_num):
                    _, _loss_cross_entropy, _is_in_top_K, summary, _global_step, lr = sess.run(
                        [
                            train_op, average_cross_entropy, average_in_top_K,
                            merged, global_step, learning_rate
                        ],
                        feed_dict={
                            train_flag: True,
                            dropout_flag: DROPOUT_KEEP_PRAM
                        })
                    train_writer.add_summary(summary, global_step=_global_step)

                    intermediate_train_info[0] += _is_in_top_K
                    epoch_acc += _is_in_top_K
                    intermediate_train_info[1] += _loss_cross_entropy
                    epoch_loss += _loss_cross_entropy

                    # intermediate evaluation for the training dataset
                    if _global_step % SHOW_TRAIN_INFO_FREQ == 0:
                        intermediate_train_acc = float(
                            intermediate_train_info[0]) / (
                                SHOW_TRAIN_INFO_FREQ * BATCH_SIZE * N_GPU)
                        intermediate_train_loss = intermediate_train_info[
                            1] / (SHOW_TRAIN_INFO_FREQ)

                        step_log_info = 'global_step:{}, step_train_acc:{:.4f}, step_train_loss:{:4f}, lr:{:.4g}'.format(
                            _global_step, intermediate_train_acc,
                            intermediate_train_loss, lr)
                        sys.stdout.write('\n' + step_log_info + '\n')
                        sys.stdout.flush()
                        logging.info(step_log_info)
                        train_writer.add_summary(make_summary(
                            'accumulated_statistics/train_acc',
                            intermediate_train_acc),
                                                 global_step=_global_step)
                        train_writer.add_summary(make_summary(
                            'accumulated_statistics/train_loss',
                            intermediate_train_loss),
                                                 global_step=_global_step)
                        intermediate_train_info = [0., 0.]

                    # start to save
                    if _global_step % SAVE_FREQ == 0:
                        saver_to_save.save(
                            sess, SAVE_DIR +
                            '/model_step_{}_train_acc_{:.4f}_lr_{:.4g}'.format(
                                _global_step, intermediate_train_acc, lr))

                    pbar.update(1)
                pbar.close()

                epoch_acc = float(epoch_acc) / TRAIN_LEN
                epoch_loss = float(epoch_loss) / batch_num
                log_info = '=====Epoch:{}, whole_train_acc:{:.4f}, whole_train_loss:{:4f}, lr:{:.7g}====='.format(
                    epoch, epoch_acc, epoch_loss, lr)
                logging.info(log_info)
                sys.stdout.write('\n' + log_info + '\n')
                sys.stdout.flush()

                # reshuffle the train file and restart the input pipeline
                shuffle_and_overwrite(os.path.join(TXT_DATA_DIR, TRAIN_FILE))
                sess.run(_dataset_init_op)
            train_writer.close()
예제 #8
0
    def pretrain_loop(self, num_iter=None):
        '''
        Pretrain the causal controller (``self.cc``).

        num_iter : is the number of *additional* iterations to do
        barring one of the quit conditions (the model may already be
        trained for some number of iterations). Defaults to
        cc_config.pretrain_iter.

        The loop exits early once at least ``min_pretrain_iter`` steps have
        been done and the reported TVD is below ``min_tvd``; both thresholds
        come from ``self.cc.config``. The early-exit check, the TVD summary
        and the checkpoint happen every ``10 * log_step`` iterations; loss
        logging happens every ``log_step`` iterations.
        '''
        #TODO: potentially should be moved into CausalController for consistency

        num_iter = num_iter or self.cc.config.pretrain_iter

        if hasattr(self, 'model'):
            model_step = self.sess.run(self.model.step)
            assert model_step == 0, 'if pretraining, model should not be trained already'

        cc_step = self.sess.run(self.cc.step)
        if cc_step > 0:
            print(
                'Resuming training of already optimized CC model at\
                  step:', cc_step)

        label_stats = crosstab(self, report_tvd=True)

        # Most recent cc step known to this loop. Tracked separately from
        # `result` because `result` is never assigned when the loop breaks
        # on its first iteration or does not iterate at all.
        last_cc_step = cc_step

        def break_pretrain(label_stats, counter):
            c1 = counter >= self.cc.config.min_pretrain_iter
            c2 = (label_stats['tvd'] < self.cc.config.min_tvd)
            return (c1 and c2)

        for counter in trange(cc_step, cc_step + num_iter):
            #Check for early exit
            if counter % (10 * self.cc.config.log_step) == 0:
                label_stats = crosstab(self, report_tvd=True)
                print('ptstep:', counter, '  TVD:', label_stats['tvd'])
                if break_pretrain(label_stats, counter):
                    print('Completed Pretrain by TVD Qualification')
                    break

            #Optimize critic
            self.cc.critic_update(self.sess)

            #one iter causal controller
            fetch_dict = {
                "pretrain_op": self.cc.train_op,
                'cc_step': self.cc.step,
                'step': self.step,
            }

            #update what to run
            if counter % self.cc.config.log_step == 0:
                fetch_dict.update({
                    "summary": self.cc.summary_op,
                    "c_loss": self.cc.c_loss,
                    "dcc_loss": self.cc.dcc_loss,
                })
            result = self.sess.run(fetch_dict)
            last_cc_step = result['cc_step']

            #update summaries
            if counter % self.cc.config.log_step == 0:
                if counter % (10 * self.cc.config.log_step) == 0:
                    sum_tvd = make_summary('misc/tvd', label_stats['tvd'])
                    self.summary_writer.add_summary(sum_tvd, result['cc_step'])

                self.summary_writer.add_summary(result['summary'],
                                                result['cc_step'])
                self.summary_writer.flush()

                c_loss = result['c_loss']
                dcc_loss = result['dcc_loss']
                print("[{}/{}] Loss_C: {:.6f} Loss_DCC: {:.6f}".\
                      format(counter, cc_step+ num_iter, c_loss, dcc_loss))

            if counter % (10 * self.cc.config.log_step) == 0:
                self.cc.saver.save(self.sess, self.cc.save_model_name,
                                   result['cc_step'])

        else:
            # Runs only when the loop finished without hitting `break`.
            label_stats = crosstab(self, report_tvd=True)
            self.cc.saver.save(self.sess, self.cc.save_model_name,
                               self.cc.step)
            print('Completed Pretrain by Exhausting all Pretrain Steps!')

        print('step:', last_cc_step, '  TVD:', label_stats['tvd'])
예제 #9
0
 def run(self, start_new_model=False):
     """Build or restore the model graph, then run the training loop.

     Args:
         start_new_model: Forwarded to ``get_meta_filename``. When it is
             true and this task is the master, the training directory is
             removed first via ``remove_training_directory``.

     The loop runs until the supervisor requests a stop or
     ``self.max_steps`` is reached. On the master task, per-step metrics
     are logged and written as summaries, and the model is exported
     whenever ``self.export_model_steps`` steps have elapsed since the
     last export (or no export has been recorded yet). A final export is
     attempted after the loop ends normally; it is skipped when
     ``tf.errors.OutOfRangeError`` terminates the loop.
     """
     if self.is_master and start_new_model:
         self.remove_training_directory(self.train_dir)
     target, device_fn = self.start_server_if_distributed()
     meta_filename = self.get_meta_filename(start_new_model, self.train_dir)
     with tf.Graph().as_default() as graph:
         # Either restore the graph from the meta file or build it fresh
         # under the device function; exactly one branch assigns `saver`.
         if meta_filename:
             saver = self.recover_model(meta_filename)
         with tf.device(device_fn):
             if not meta_filename:
                 saver = self.build_model(self.model, self.reader)
             # The tensors/ops are looked up through graph collections so
             # that the same code serves both the restored and fresh graph.
             global_step = tf.get_collection("global_step")[0]
             loss = tf.get_collection("loss")[0]
             predictions = tf.get_collection("predictions")[0]
             labels = tf.get_collection("labels")[0]
             train_op = tf.get_collection("train_op")[0]
             init_op = tf.global_variables_initializer()
     sv = tf.train.Supervisor(graph,
                              logdir=self.train_dir,
                              init_op=init_op,
                              is_chief=self.is_master,
                              global_step=global_step,
                              save_model_secs=15 * 60,
                              save_summaries_secs=120,
                              saver=saver)
     logging.info(f"{str_task(self.task)}: Starting managed session.")
     # Remains None until at least one training step has completed, so the
     # final export below cannot hit an unbound local when the loop body
     # never runs (supervisor already stopping, or max steps already hit).
     global_step_ = None
     with sv.managed_session(target, config=self.config) as sess:
         try:
             logging.info(f"{str_task(self.task)}: Entering training loop.")
             while (not sv.should_stop()) and (not self.max_steps_reached):
                 batch_start_time = time.time()
                 _, global_step_, loss_, predictions_, labels_ = sess.run(
                     [train_op, global_step, loss, predictions, labels])
                 secs_per_batch = time.time() - batch_start_time
                 if self.max_steps and self.max_steps <= global_step_:
                     self.max_steps_reached = True
                 if self.is_master:
                     examples_per_sec = labels_.shape[0] / secs_per_batch
                     hit_at_one = calculate_hit_at_one(
                         predictions_, labels_)
                     perr = calculate_perr(predictions_, labels_)
                     gap = calculate_gap(predictions_, labels_)
                     logging.info(f"{str_task(self.task)}: "
                                  f"training step {global_step_}| "
                                  f"Hit@1: {hit_at_one:.2f} "
                                  f"PERR: {perr:.2f} "
                                  f"GAP: {gap:.2f} "
                                  f"Loss: {loss_:.5f}")
                     summary_dict = {
                         "model/Training_Hit@1": hit_at_one,
                         "model/Training_Perr": perr,
                         "model/Training_GAP": gap,
                         "global_step/Examples/Second": examples_per_sec,
                     }
                     for key, value in summary_dict.items():
                         sv.summary_writer.add_summary(
                             make_summary(key, value), global_step_)
                     sv.summary_writer.flush()
                     # Export when nothing has been exported yet, or when
                     # enough steps have passed since the last export.
                     steps_since_export = (
                         global_step_ - self.last_model_export_step)
                     time_to_export = (
                         self.last_model_export_step == 0
                         or steps_since_export >= self.export_model_steps)
                     # Already inside the master-only branch, so no second
                     # is_master check is needed here.
                     if time_to_export:
                         self.export_model(global_step_, sv, sess)
                         self.last_model_export_step = global_step_
             # Final export after a normal loop exit; skipped if no step ran.
             if self.is_master and global_step_ is not None:
                 self.export_model(global_step_, sv, sess)
         except tf.errors.OutOfRangeError:
             logging.info(f"{str_task(self.task)}: "
                          "Done training -- epoch limit reached.")
     logging.info(f"{str_task(self.task)}: Exited training loop.")
     sv.Stop()