Exemplo n.º 1
0
def train(model, data_loader):
    """Train ``model`` with a TF1 graph/session loop.

    Builds the network on ``data_loader.next_element`` and runs training
    steps ``1 .. cfg.num_steps - 1``. Periodically (driven by ``cfg``):

    * every ``cfg.train_sum_every`` steps: run a fully traced step, write a
      TensorBoard summary, a Chrome-trace timeline under
      ``<results_dir>/timelines`` and the loss / training accuracy;
    * every ``cfg.val_sum_every`` steps: run one full pass over the
      validation iterator, plot activations and record the mean per-batch
      accuracy;
    * every ``cfg.save_ckpt_every`` steps: save a checkpoint of the
      trainable variables (at most 10 are kept by the saver).

    Args:
        model: object providing ``create_network``, ``train`` (returning
            ``(loss, train_ops, summary_ops)``), ``accuracy`` and ``probs``.
        data_loader: callable ``data_loader(batch_size, mode=...)`` returning
            an iterator; it also exposes ``next_element`` and a ``handle``
            feed used to switch between the training and validation
            iterators.
    """
    # Setting up model
    training_iterator = data_loader(cfg.batch_size, mode="train")
    validation_iterator = data_loader(cfg.batch_size, mode="eval")
    inputs = data_loader.next_element["images"]
    labels = data_loader.next_element["labels"]
    model.create_network(inputs, labels)

    loss, train_ops, summary_ops = model.train(cfg.num_gpus)

    # Creating files, saver and summary writer to save training results
    fd = save_to(is_training=True)
    summary_writer = tf.summary.FileWriter(cfg.logdir)
    summary_writer.add_graph(tf.get_default_graph())
    saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=10)

    timeline_dir = os.path.join(cfg.results_dir, "timelines")
    activations_dir = os.path.join(cfg.results_dir, "activations")
    # open() below does not create missing parent directories.
    if not os.path.isdir(timeline_dir):
        os.makedirs(timeline_dir)

    # Setting up training session. The trace options must be handed to
    # sess.run; otherwise run_metadata.step_stats stays empty and every
    # timeline file is an empty trace.
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    def write_metric(key, step, value):
        # Append one "step,value" row and flush so progress survives a crash.
        fd[key].write("{:d},{:.4f}\n".format(step, value))
        fd[key].flush()

    try:
        with tf.Session(config=config) as sess:
            training_handle = sess.run(training_iterator.string_handle())
            validation_handle = sess.run(validation_iterator.string_handle())
            sess.run(tf.global_variables_initializer())
            print("\nNote: all of results will be saved to directory: " +
                  cfg.results_dir)
            # NOTE(review): range(1, cfg.num_steps) runs cfg.num_steps - 1
            # steps; confirm that step cfg.num_steps is meant to be skipped.
            for step in range(1, cfg.num_steps):
                start_time = time.time()
                if step % cfg.train_sum_every == 0:
                    # Traced step: also fills run_metadata for the timeline.
                    _, loss_val, train_acc, summary_str = sess.run(
                        [train_ops, loss, model.accuracy, summary_ops],
                        feed_dict={data_loader.handle: training_handle},
                        options=run_options,
                        run_metadata=run_metadata)
                    tl = timeline.Timeline(run_metadata.step_stats)
                    out_path = os.path.join(timeline_dir,
                                            "timeline_%d.json" % step)
                    with open(out_path, "w") as f:
                        f.write(tl.generate_chrome_trace_format())
                    summary_writer.add_summary(summary_str, step)
                    write_metric("loss", step, loss_val)
                    write_metric("train_acc", step, train_acc)
                else:
                    _, loss_val = sess.run(
                        [train_ops, loss],
                        feed_dict={data_loader.handle: training_handle})

                if step % cfg.val_sum_every == 0:
                    print("evaluating, it will take a while...")
                    sess.run(validation_iterator.initializer)
                    probs = []
                    targets = []
                    total_acc = 0
                    n = 0
                    while True:
                        try:
                            val_acc, prob, label = sess.run(
                                [model.accuracy, model.probs, labels],
                                feed_dict={
                                    data_loader.handle: validation_handle})
                        except tf.errors.OutOfRangeError:
                            break  # validation iterator exhausted
                        probs.append(prob)
                        targets.append(label)
                        total_acc += val_acc
                        n += 1
                    if n == 0:
                        # Nothing to average or concatenate; don't crash
                        # the whole training run over it.
                        print("validation iterator yielded no batches; "
                              "skipping evaluation")
                    else:
                        probs = np.concatenate(probs, axis=0)
                        targets = np.concatenate(targets,
                                                 axis=0).reshape((-1, 1))
                        # Mean of per-batch accuracies (a smaller final
                        # batch, if any, weighs the same as a full one).
                        avg_acc = total_acc / n
                        plot_activation(np.hstack((probs, targets)),
                                        step=step,
                                        save_to=activations_dir)
                        write_metric("val_acc", step, avg_acc)

                if step % cfg.save_ckpt_every == 0:
                    saver.save(sess,
                               save_path=os.path.join(cfg.logdir,
                                                      'model.ckpt'),
                               global_step=step)

                duration = time.time() - start_time
                log_str = ' step: {:d}, loss: {:.3f}, time: {:.3f} sec/step' \
                          .format(step, loss_val, duration)
                print(log_str)
    finally:
        # Flush buffered events to disk even if training aborts.
        summary_writer.close()
Exemplo n.º 2
0
def train(model, data_loader):
    """Train ``model`` eagerly (TF2) with per-epoch logging and checkpoints.

    Runs epochs ``1 .. cfg.num_epochs - 1``. Within each epoch, ``model`` is
    called once per training batch. Periodically (driven by ``cfg``):

    * every ``cfg.train_sum_every`` epochs: flush the summary writer and
      record the epoch's mean loss / training accuracy;
    * every ``cfg.val_sum_every`` epochs: run one pass over the validation
      data, plot activations and record the mean per-batch accuracy;
    * every ``cfg.save_ckpt_every`` epochs: save the weights to
      ``<logdir>/model_-<epoch:04d>.ckpt``.

    Args:
        model: callable ``model(images, labels, batch_counter)`` returning
            ``(loss, accuracy, _)``, where ``batch_counter`` is
            ``(epoch - 1) * n_train_batches + batch_id``. It also provides
            ``eval(images, labels)`` returning ``(probs, _, accuracy)``, plus
            ``summary`` and ``save_weights``. It is presumably a
            ``tf.keras.Model`` subclass — TODO confirm.
        data_loader: callable ``data_loader(batch_size, mode=...)`` returning
            an iterable of dicts with ``'images'`` and ``'labels'``.

    Raises:
        ValueError: if the training data yields no batches.
    """
    checkpoint_path = os.path.join(cfg.logdir, 'model_-{epoch:04d}.ckpt')

    # Setting up the dataloader
    training_iterator = data_loader(cfg.batch_size, mode="train")
    validation_iterator = data_loader(cfg.batch_size, mode="eval")

    # Creating files and summary writer to save training results
    fd = save_to(is_training=True)
    summary_writer = tf.summary.create_file_writer(cfg.logdir)
    print("\nNote: all of results will be saved to directory: " +
          cfg.results_dir)

    # Batches per pass, counted by one full iteration over each dataset;
    # sizes the progress bars and the model's running batch counter.
    train_ds_len = sum(1 for _ in training_iterator)
    val_ds_len = sum(1 for _ in validation_iterator)
    if train_ds_len == 0:
        raise ValueError("training dataset yields no batches")

    # Init the model on one batch and show the summary. Keras'
    # Model.summary() prints by itself and returns None, so it is not
    # wrapped in print() (which only added a stray "None" line).
    data = next(iter(training_iterator))
    model(data['images'], data['labels'], 1)
    model.summary()

    activations_dir = os.path.join(cfg.results_dir, "activations")

    def write_metric(key, step, value):
        # Append one "step,value" row and flush so progress survives a crash.
        fd[key].write("{:d},{:.4f}\n".format(step, value))
        fd[key].flush()

    # Train the model
    try:
        with summary_writer.as_default():
            # NOTE(review): range(1, cfg.num_epochs) runs cfg.num_epochs - 1
            # epochs; confirm that this is intended.
            for step in range(1, cfg.num_epochs):
                start_time = time.time()

                # Initialize progress bars
                progbar_train = tf.keras.utils.Progbar(train_ds_len)
                progbar_val = tf.keras.utils.Progbar(val_ds_len)

                tf.summary.experimental.set_step(step)

                # Train one epoch
                loss_val_step = []
                train_acc_step = []
                for b_id, data in enumerate(training_iterator):
                    loss_val, train_acc, _ = model(
                        data['images'], data['labels'],
                        (step - 1) * train_ds_len + b_id)
                    loss_val_step.append(loss_val)
                    train_acc_step.append(train_acc)
                    progbar_train.update(b_id + 1,
                                         values=[('loss', loss_val),
                                                 ('accuracy', train_acc)])

                # Epoch means over the batches just seen.
                loss_val = sum(loss_val_step) / len(loss_val_step)
                train_acc = sum(train_acc_step) / len(train_acc_step)

                if step % cfg.train_sum_every == 0:
                    summary_writer.flush()
                    write_metric("loss", step, loss_val)
                    write_metric("train_acc", step, train_acc)

                if step % cfg.val_sum_every == 0:
                    print("evaluating, it will take a while...")
                    probs = []
                    targets = []
                    total_acc = 0
                    n = 0
                    for b_id, data in enumerate(validation_iterator):
                        prob, _, val_acc = model.eval(data['images'],
                                                      data['labels'])
                        probs.append(prob)
                        targets.append(data['labels'])
                        total_acc += val_acc
                        n += 1
                        progbar_val.update(b_id + 1,
                                           values=[('accuracy', val_acc)])

                    if n == 0:
                        # Nothing to average or concatenate; don't crash
                        # the whole training run over it.
                        print("validation dataset yielded no batches; "
                              "skipping evaluation")
                    else:
                        probs = np.concatenate(probs, axis=0)
                        targets = np.concatenate(targets,
                                                 axis=0).reshape((-1, 1))
                        # Mean of per-batch accuracies (a smaller final
                        # batch, if any, weighs the same as a full one).
                        avg_acc = total_acc / n
                        plot_activation(np.hstack((probs, targets)),
                                        step=step,
                                        save_to=activations_dir)
                        write_metric("val_acc", step, avg_acc)

                if step % cfg.save_ckpt_every == 0:
                    # Format with the current epoch so each save gets its
                    # own file instead of overwriting a single checkpoint.
                    model.save_weights(checkpoint_path.format(epoch=step))

                duration = time.time() - start_time
                log_str = ' step: {:d}, loss: {:.3f}, accuracy: {:.3f}, time: {:.3f} sec/step'.format(
                    step, loss_val, train_acc, duration)
                print(log_str)
    finally:
        # Flush and release the event file even if training aborts.
        summary_writer.close()