Example 1
def train_and_evaluate():
    """Train the model with custom training loop, evaluating at given intervals."""

    # Set mixed precision policy
    if FLAGS.mixed_precision:
        policy = mixed_precision.Policy('mixed_float16')
        mixed_precision.set_policy(policy)

    # Get dataset
    dataset = _get_dataset(dataset=FLAGS.dataset,
                           label_mode=FLAGS.label_mode,
                           input_mode=FLAGS.input_mode,
                           input_length=FLAGS.input_length,
                           seq_shift=FLAGS.seq_shift,
                           def_val=DEF_VAL)

    # Define representation
    rep = Representation(blank_index=BLANK_INDEX,
                         def_val=DEF_VAL,
                         loss_mode=FLAGS.loss_mode,
                         num_event_classes=dataset.num_event_classes(),
                         pad_val=PAD_VAL,
                         use_def=FLAGS.use_def,
                         decode_fn=FLAGS.decode_fn,
                         beam_width=FLAGS.beam_width)

    # Get model
    model = _get_model(model=FLAGS.model,
                       dataset=FLAGS.dataset,
                       num_classes=rep.get_num_classes(),
                       input_length=FLAGS.input_length,
                       l2_lambda=L2_LAMBDA)
    seq_length = model.get_seq_length()
    rep.set_seq_length(seq_length)

    # Instantiate learning rate schedule and optimizer
    if FLAGS.lr_decay_fn == "exponential":
        lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=FLAGS.lr_base,
            decay_steps=LR_DECAY_STEPS,
            decay_rate=FLAGS.lr_decay_rate,
            staircase=True)
    elif FLAGS.lr_decay_fn == "piecewise_constant":
        values = np.divide(FLAGS.lr_base, LR_VALUE_DIV)
        lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            boundaries=LR_BOUNDARIES, values=values.tolist())
    elif FLAGS.lr_decay_fn == "constant":
        lr_schedule = ConstantLR(FLAGS.lr_base)
    optimizer = Adam(learning_rate=lr_schedule)
    # Get LossScaleOptimizer
    if FLAGS.mixed_precision:
        optimizer = LossScaleOptimizer(optimizer=optimizer,
                                       loss_scale='dynamic')

    # Get loss function
    train_loss_fn = rep.get_loss_fn(batch_size=FLAGS.batch_size)
    eval_loss_fn = rep.get_loss_fn(batch_size=FLAGS.eval_batch_size)

    # Get train and eval dataset
    collapse_fn = rep.get_loss_collapse_fn()
    train_dataset = dataset(batch_size=FLAGS.batch_size,
                            data_dir=FLAGS.train_dir,
                            is_predicting=False,
                            is_training=True,
                            label_fn=model.get_label_fn(FLAGS.batch_size),
                            collapse_fn=collapse_fn,
                            num_shuffle=FLAGS.num_shuffle)
    eval_dataset = dataset(batch_size=FLAGS.eval_batch_size,
                           data_dir=FLAGS.eval_dir,
                           is_predicting=False,
                           is_training=False,
                           label_fn=model.get_label_fn(FLAGS.eval_batch_size),
                           collapse_fn=collapse_fn,
                           num_shuffle=FLAGS.num_shuffle)

    # Load model
    if FLAGS.model_ckpt is not None:
        logging.info("Loading model from {}".format(FLAGS.model_ckpt))
        load_status = model.load_weights(
            os.path.join(FLAGS.model_dir, "checkpoints", FLAGS.model_ckpt))
        load_status.assert_consumed()

    # Set up log writer and metrics
    train_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.model_dir, "log/train"))
    eval_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.model_dir, "log/eval"))
    train_metrics = TrainMetrics(representation=rep, writer=train_writer)
    eval_metrics = EvalMetrics(representation=rep, writer=eval_writer)

    # Save best checkpoints in terms of f1
    model_saver = ModelSaver(os.path.join(FLAGS.model_dir, "checkpoints"),
                             compare_fn=lambda x, y: x.score > y.score,
                             sort_reverse=True)

    # Keep track of total global step
    global_step = 0

    # Iterate over epochs
    for epoch in range(FLAGS.train_epochs):
        logging.info('Starting epoch %d' % (epoch, ))

        # Iterate over training batches
        for step, (train_features, train_labels, train_labels_c,
                   train_labels_l) in enumerate(train_dataset):
            # Assert sizes
            assert train_labels.shape == [
                FLAGS.batch_size, seq_length
            ], "Labels shape [batch_size, seq_length]"
            # Run the train step
            train_logits, train_loss, train_l2_loss, train_grads = train_step(
                model, train_features, train_labels, train_labels_c,
                train_labels_l, train_loss_fn, optimizer)
            # Assert sizes
            assert train_logits.shape == [
                FLAGS.batch_size, seq_length,
                rep.get_num_classes()
            ], "Logits shape [batch_size, seq_length, num_classes]"
            # Log every FLAGS.log_steps steps.
            if global_step % FLAGS.log_steps == 0:
                logging.info("Memory used: {} GB".format(
                    psutil.virtual_memory().used / 2**30))
                # Decode logits into predictions
                train_predictions_u = None
                if FLAGS.loss_mode == "ctc":
                    train_predictions_u, _ = rep.get_decode_fn(
                        FLAGS.batch_size)(train_logits)
                    train_predictions_u = rep.get_inference_collapse_fn()(
                        train_predictions_u)
                # General logs
                logging.info('Step %s in epoch %s; global step %s' %
                             (step, epoch, global_step))
                logging.info('Seen this epoch: %s samples' %
                             ((step + 1) * FLAGS.batch_size))
                logging.info('Total loss (this step): %s' %
                             float(train_loss + train_l2_loss))
                with train_writer.as_default():
                    tf.summary.scalar("training/global_gradient_norm",
                                      data=tf.linalg.global_norm(train_grads),
                                      step=global_step)
                    tf.summary.scalar('training/loss',
                                      data=train_loss,
                                      step=global_step)
                    tf.summary.scalar('training/l2_loss',
                                      data=train_l2_loss,
                                      step=global_step)
                    tf.summary.scalar('training/total_loss',
                                      data=train_loss + train_l2_loss,
                                      step=global_step)
                    tf.summary.scalar('training/learning_rate',
                                      data=lr_schedule(epoch),
                                      step=global_step)
                # Update metrics
                train_metrics.update(train_labels, train_logits,
                                     train_predictions_u)
                # Log metrics
                train_metrics.log(global_step)
                # Save latest model
                model_saver.save_latest(model=model,
                                        step=global_step,
                                        file="model")
                # Flush TensorBoard
                train_writer.flush()

            # Evaluate every FLAGS.eval_steps steps.
            if global_step % FLAGS.eval_steps == 0:
                logging.info('Evaluating at global step %s' % global_step)
                # Keep track of eval losses
                eval_losses = []
                eval_l2_losses = []
                # Iterate through eval batches
                for i, (eval_features, eval_labels, eval_labels_c,
                        eval_labels_l) in enumerate(eval_dataset):
                    # Assert sizes
                    assert eval_labels.shape == [
                        FLAGS.eval_batch_size, seq_length
                    ], "Labels shape [batch_size, seq_length]"
                    # Run the eval step
                    eval_logits, eval_loss, eval_l2_loss = eval_step(
                        model, eval_features, eval_labels, eval_labels_c,
                        eval_labels_l, eval_loss_fn)
                    eval_losses.append(eval_loss.numpy())
                    eval_l2_losses.append(eval_l2_loss.numpy())
                    # Assert sizes
                    assert eval_logits.shape == [
                        FLAGS.eval_batch_size, seq_length,
                        rep.get_num_classes()
                    ], "Logits shape [batch_size, seq_length, num_classes]"
                    # Decode logits into predictions
                    eval_predictions_u = None
                    if FLAGS.loss_mode == "ctc":
                        eval_predictions_u, _ = rep.get_decode_fn(
                            FLAGS.eval_batch_size)(eval_logits)
                        eval_predictions_u = rep.get_inference_collapse_fn()(
                            eval_predictions_u)
                    # Update metrics for this batch
                    eval_metrics.update_i(eval_labels, eval_logits,
                                          eval_predictions_u)
                # Update mean metrics
                eval_score = eval_metrics.update()
                # General logs
                eval_loss = np.mean(eval_losses)
                eval_l2_loss = np.mean(eval_l2_losses)
                logging.info('Evaluation loss: %s' %
                             float(eval_loss + eval_l2_loss))
                with eval_writer.as_default():
                    tf.summary.scalar('training/loss',
                                      data=eval_loss,
                                      step=global_step)
                    tf.summary.scalar('training/l2_loss',
                                      data=eval_l2_loss,
                                      step=global_step)
                    tf.summary.scalar('training/total_loss',
                                      data=eval_loss + eval_l2_loss,
                                      step=global_step)
                # Log metrics
                eval_metrics.log(global_step)
                # Save best models
                model_saver.save_best(model=model,
                                      score=float(eval_score),
                                      step=global_step,
                                      file="model")
                # Flush TensorBoard
                eval_writer.flush()

            # Clean up memory
            tf.keras.backend.clear_session()
            gc.collect()

            # Increment global step
            global_step += 1

        # Save and keep latest model for every 10th epoch
        if epoch % 10 == 9:
            model_saver.save_keep(model=model, step=global_step, file="model")

        logging.info('Finished epoch %s' % (epoch, ))
        optimizer.finish_epoch()

    # Save final model
    model_saver.save_latest(model=model, step=global_step, file="model")
    # Finished training
    logging.info("Finished training")
Example 2
def predict():
    """Run inference on eval files, exporting logits, labels and predictions."""
    # Set mixed precision policy
    if FLAGS.mixed_precision:
        policy = mixed_precision.Policy('mixed_float16')
        mixed_precision.set_policy(policy)
    # Make target dir
    if not os.path.exists(FLAGS.predict_dir):
        os.makedirs(FLAGS.predict_dir)
    # Get dataset
    dataset = _get_dataset(dataset=FLAGS.dataset,
                           label_mode=FLAGS.label_mode,
                           input_mode=FLAGS.input_mode,
                           input_length=FLAGS.input_length,
                           seq_shift=FLAGS.seq_shift,
                           def_val=DEF_VAL)
    num_event_classes = dataset.num_event_classes()
    # Define representation
    rep = Representation(blank_index=BLANK_INDEX,
                         def_val=DEF_VAL,
                         loss_mode=FLAGS.loss_mode,
                         num_event_classes=num_event_classes,
                         pad_val=PAD_VAL,
                         use_def=FLAGS.use_def,
                         decode_fn=FLAGS.decode_fn,
                         beam_width=FLAGS.beam_width)
    num_classes = rep.get_num_classes()
    # Get model and infer seq_length
    model = _get_model(model=FLAGS.model,
                       dataset=FLAGS.dataset,
                       num_classes=num_classes,
                       input_length=FLAGS.input_length,
                       l2_lambda=L2_LAMBDA)
    seq_length = model.get_seq_length()
    rep.set_seq_length(seq_length)
    # Make sure that seq_shift is set corresponding to model SEQ_POOL
    assert FLAGS.seq_shift == model.get_out_pool(), \
      "seq_shift should be equal to model.get_out_pool() in predict"
    # Load weights
    model.load_weights(
        os.path.join(FLAGS.model_dir, "checkpoints", FLAGS.model_ckpt))
    # Set up metrics
    metrics = PredMetrics(rep)
    # Files for predicting
    filenames = gfile.Glob(os.path.join(FLAGS.eval_dir, "*.tfrecord"))
    # For each filename, export logits
    for filename in filenames:
        # Get video id
        video_id = os.path.splitext(os.path.basename(filename))[0]
        export_csv = os.path.join(FLAGS.predict_dir, str(video_id) + ".csv")
        export_tfrecord = os.path.join(FLAGS.predict_dir, "logits",
                                       str(video_id) + ".tfrecord")
        logging.info("Working on {0}.".format(video_id))
        if os.path.exists(export_csv) and os.path.exists(export_tfrecord):
            logging.info(
                "Export files already exist. Skipping {0}.".format(filename))
            continue
        # Get the dataset
        label_fn = model.get_label_fn(1)
        collapse_fn = rep.get_loss_collapse_fn()
        data = dataset(batch_size=1,
                       data_dir=filename,
                       is_predicting=True,
                       is_training=False,
                       label_fn=label_fn,
                       collapse_fn=collapse_fn)
        # Iterate to get n and v_seq_length
        n = len(list(data))
        v_seq_length = n + seq_length - 1
        # Get the aggregators
        labels_aggregator = aggregation.ConcatAggregator(n=n,
                                                         idx=seq_length - 1)
        if seq_length == 1:
            logits_aggregator = aggregation.ConcatAggregator(
                n=n, idx=seq_length - 1)
        else:
            logits_aggregator = aggregation.AverageAggregator(
                num_classes=num_classes, seq_length=seq_length)
        preds_aggregator = _get_preds_aggregator(
            predict_mode=FLAGS.predict_mode,
            n=n,
            rep=rep,
            v_seq_length=v_seq_length)
        # Write logits and labels to TFRecord for analysis
        if not os.path.exists(os.path.join(FLAGS.predict_dir, "logits")):
            os.makedirs(os.path.join(FLAGS.predict_dir, "logits"))
        # Iterate through batches
        with tf.io.TFRecordWriter(export_tfrecord) as tfrecord_writer:
            for i, (b_features, b_labels) in enumerate(data):
                # Assert sizes
                assert b_labels.shape == [1, seq_length], \
                    "Labels shape [1, seq_length]"
                # Prediction step
                b_logits = pred_step(model, b_features)
                assert b_logits.shape == [
                    1, seq_length, rep.get_num_classes()
                ], "Logits shape [1, seq_length, num_classes]"
                # Aggregation step
                labels_aggregator.step(i, b_labels)
                logits_aggregator.step(i, b_logits)
                if preds_aggregator is not None:
                    preds_aggregator.step(i, b_logits)
                example = tf.train.Example(features=tf.train.Features(
                    feature={
                        'example/logits': _floats_feature(
                            b_logits.numpy().ravel()),
                        'example/labels': _int64_feature(
                            b_labels.numpy().ravel())
                    }))
                tfrecord_writer.write(example.SerializeToString())
        # Get aggregated data
        labels = labels_aggregator.result()
        logits = logits_aggregator.result()
        preds = None
        if preds_aggregator is not None:
            preds = preds_aggregator.result()
        # Collapse on video level
        if preds is not None:
            preds = rep.get_inference_collapse_fn(v_seq_length)(preds)
        # Remove empty batch dimensions
        labels = tf.squeeze(labels, axis=0)
        logits = tf.squeeze(logits, axis=0)
        if preds is not None:
            preds = tf.squeeze(preds, axis=0)
        # Export probs for two stage model
        ids = [video_id] * v_seq_length
        if FLAGS.predict_mode == "probs":
            logging.info("Saving labels and probs")
            probs = tf.nn.softmax(logits, axis=-1)
            save_array = np.column_stack(
                (ids, labels.numpy().tolist(), probs.numpy().tolist()))
            np.savetxt(export_csv, save_array, delimiter=",", fmt='%s')
            continue
        # Update metrics for single stage model
        metrics.update(labels, preds)
        # Save for single stage model
        logging.info("Writing {0} examples to {1}.csv...".format(
            len(ids), video_id))
        save_array = np.column_stack(
            (ids, labels.numpy().tolist(), logits.numpy().tolist(),
             preds.numpy().tolist()))
        np.savetxt(export_csv, save_array, delimiter=",", fmt='%s')
    if FLAGS.predict_mode == "probs":
        # Finish
        exit()
    # Print metrics
    metrics.finish()
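
# Note: pred_step and the two tf.train feature helpers used above are not
# shown in this example. The sketch below assumes pred_step is a plain
# forward pass in inference mode and that the helpers follow the standard
# TFRecord idiom for float and int64 features.
def pred_step(model, features):
    """Forward pass in inference mode."""
    return model(features, training=False)


def _floats_feature(values):
    """Wrap a flat iterable of floats as a tf.train.Feature."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))


def _int64_feature(values):
    """Wrap a flat iterable of ints as a tf.train.Feature."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))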
Example 3
def main(arg=None):
    """Aggregate exported logits at video level and compute prediction metrics."""
    # Make target dir
    export_dir = os.path.join(FLAGS.predict_dir,
                              "beam_width_" + str(FLAGS.beam_width))
    if not os.path.exists(export_dir):
        os.makedirs(export_dir)
    # Get representation and metrics
    seq_length = FLAGS.seq_length
    num_classes = FLAGS.num_classes
    rep = Representation(blank_index=BLANK_INDEX,
                         def_val=DEF_VAL,
                         loss_mode=None,
                         num_event_classes=num_classes - 1,
                         pad_val=PAD_VAL,
                         use_def=False,
                         decode_fn=FLAGS.decode_fn,
                         beam_width=FLAGS.beam_width)
    rep.set_seq_length(seq_length)
    metrics = PredMetrics(rep)
    # Find files
    filenames = sorted(gfile.Glob(os.path.join(FLAGS.logits_dir,
                                               "*.tfrecord")))
    # For each file
    for filename in filenames:
        # Get video id
        video_id = os.path.splitext(os.path.basename(filename))[0]
        export_csv = os.path.join(FLAGS.predict_dir,
                                  "beam_width_" + str(FLAGS.beam_width),
                                  str(video_id) + ".csv")
        logging.info("Working on {0}.".format(video_id))
        # Get data information
        data = tf.data.TFRecordDataset(filename)
        n = len(list(data))
        v_seq_length = n + seq_length - 1
        # Get the aggregators
        labels_aggregator = aggregation.ConcatAggregator(n=n,
                                                         idx=seq_length - 1)
        logits_aggregator = aggregation.AverageAggregator(
            num_classes=num_classes, seq_length=seq_length)
        decode_fn = rep.get_decode_fn(1)
        preds_aggregator = aggregation.BatchLevelVotedPredsAggregator(
            num_classes=num_classes,
            seq_length=seq_length,
            def_val=DEF_VAL,
            decode_fn=decode_fn)
        # Iterate through batches
        for i, batch_data in enumerate(data):
            b_logits, b_labels = parse(batch_data)
            # Aggregation step
            labels_aggregator.step(i, b_labels)
            logits_aggregator.step(i, b_logits)
            preds_aggregator.step(i, b_logits)
        # Get aggregated data
        labels = labels_aggregator.result()
        logits = logits_aggregator.result()
        preds = preds_aggregator.result()
        # Collapse on video level
        preds = rep.get_inference_collapse_fn(v_seq_length)(preds)
        # Remove empty batch dimensions
        labels = tf.squeeze(labels, axis=0)
        logits = tf.squeeze(logits, axis=0)
        preds = tf.squeeze(preds, axis=0)
        # Update metrics for single stage model
        metrics.update(labels, preds)
        # Save
        ids = [video_id] * v_seq_length
        logging.info("Writing {0} examples to {1}.csv...".format(
            len(ids), video_id))
        save_array = np.column_stack(
            (ids, labels.numpy().tolist(), logits.numpy().tolist(),
             preds.numpy().tolist()))
        np.savetxt(export_csv, save_array, delimiter=",", fmt='%s')
    # Print metrics
    metrics.finish()
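
# Note: parse() is used above but not defined in this example. The sketch
# below assumes it reads back the 'example/logits' and 'example/labels'
# features written by the previous example; SEQ_LENGTH and NUM_CLASSES are
# hypothetical module-level constants matching FLAGS.seq_length and
# FLAGS.num_classes.
def parse(serialized_example):
    """Parse one serialized example into (logits, labels) batch tensors."""
    parsed = tf.io.parse_single_example(
        serialized_example,
        features={
            'example/logits': tf.io.VarLenFeature(tf.float32),
            'example/labels': tf.io.VarLenFeature(tf.int64)
        })
    logits = tf.reshape(tf.sparse.to_dense(parsed['example/logits']),
                        [1, SEQ_LENGTH, NUM_CLASSES])
    labels = tf.reshape(tf.sparse.to_dense(parsed['example/labels']),
                        [1, SEQ_LENGTH])
    return logits, labels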