Exemplo n.º 1
0
def validation_run(valid, filepath, i, epoch, first_run, opts):
    """Measure mean top-1 validation accuracy and report it.

    Optionally restores ``filepath`` into the validation session, runs
    ``valid.ops`` ``opts["validation_iterations"]`` times, and averages the
    first element of each result. It then prints a summary line and appends a
    CSV row through the project ``logging`` helpers.

    Args:
        valid: Validation graph bundle exposing ``saver``, ``session`` and
            ``ops``.
        filepath: Checkpoint to restore before evaluating; skipped when falsy.
        i: Iteration number written to the report.
        epoch: Epoch value written to the report.
        first_run: Passed through to ``logging.write_to_csv``.
        opts: Options dict; ``validation_iterations``,
            ``validation_batches_per_step`` and ``validation_total_batch_size``
            are read here.

    Raises:
        tf.errors.ResourceExhaustedError: raised in place of any
            ``tf.errors.OpError`` coming out of ``session.run``.
    """
    if filepath:
        valid.saver.restore(valid.session, filepath)

    n_iters = opts["validation_iterations"]
    acc_sum = 0.0
    t_begin = time.time()
    for _ in range(n_iters):
        try:
            outputs = valid.session.run(valid.ops)
        except tf.errors.OpError as err:
            raise tf.errors.ResourceExhaustedError(err.node_def, err.op, err.message)
        acc_sum += outputs[0]
    elapsed = time.time() - t_begin
    mean_acc = acc_sum / n_iters

    # Total images processed during the timed loop.
    images_seen = (n_iters * opts["validation_batches_per_step"] *
                   opts['validation_total_batch_size'])

    report = OrderedDict()
    report['iteration'] = i
    report['epoch'] = epoch
    report['val_acc'] = mean_acc
    report['val_time'] = elapsed
    report['img_per_sec'] = images_seen / elapsed

    summary_template = (
        "Validation top-1 accuracy (iteration: {iteration:6d}, epoch: {epoch:6.2f}, img/sec: {img_per_sec:6.2f},"
        " time: {val_time:8.6f}): {val_acc:6.3f}%")
    logging.print_to_file_and_screen(summary_template.format(**report), opts)
    logging.write_to_csv(report, first_run, False, opts)
Exemplo n.º 2
0
def train_process(model, LR_Class, opts):
    """Train ``model`` with periodic checkpoints, then run deferred validation.

    Device iterations are grouped into steps of ``iterations_per_step``.
    Logging, checkpointing and validation are scheduled on:

    * the step that crosses a log, checkpoint or epoch boundary,
    * the first step, and
    * steps within two step-lengths of the end.

    Validation is not run inline. Each scheduled point is recorded, and
    ``validation.validation_run`` is called for all of them after training.
    Each call restores the checkpoint recorded with its point.

    Args:
        model: Model definition forwarded to ``training_graph`` and
            ``validation.initialise_validation``.
        LR_Class: Learning-rate schedule class, constructed as
            ``LR_Class(opts, iterations)``. Its ``feed_dict_lr(i)`` is fed to
            every training step.
        opts: Options dict.

    Raises:
        tf.errors.ResourceExhaustedError: raised in place of any
            ``tf.errors.OpError`` from a training step.
        FileNotFoundError: if ``opts['restoring']`` is set but
            ``opts['logs_path']`` contains no ``*ckpt-<N>.index`` files.
    """
    # --------------- OPTIONS ---------------------
    epochs = opts["epochs"]
    num_images = DATASET_CONSTANTS[opts['dataset']]['NUM_IMAGES']
    iterations_per_epoch = num_images // opts["total_batch_size"]
    if not opts['iterations']:
        iterations = epochs * iterations_per_epoch
        log_freq = iterations_per_epoch // opts['logs_per_epoch']
    else:
        iterations = opts['iterations']
        log_freq = opts['log_freq']

    if log_freq < opts['batches_per_step']:
        iterations_per_step = log_freq
    else:
        iterations_per_step = log_freq // int(
            round(log_freq / opts['batches_per_step']))

    iterations_per_valid = iterations_per_epoch
    # ckpts_per_epoch == 0 disables boundary checkpoints (np.inf never crosses).
    iterations_per_ckpt = ((iterations_per_epoch // opts['ckpts_per_epoch'])
                           if opts['ckpts_per_epoch'] else np.inf)

    LR = LR_Class(opts, iterations)

    # A zero-length deque would make the np.mean() calls below average an
    # empty sequence (NaN), so keep at least one slot.
    queue_len = max(1, iterations_per_epoch // iterations_per_step)
    batch_accs = deque(maxlen=queue_len)
    batch_losses = deque(maxlen=queue_len)
    batch_times = deque(maxlen=queue_len)

    # -------------- BUILD TRAINING GRAPH ----------------

    train = training_graph(
        model, opts, iterations_per_step * opts["gradients_to_accumulate"])
    train.session.run(train.init)
    train.session.run(train.iterator.initializer)

    # -------------- BUILD VALIDATION GRAPH ----------------

    if opts['validation']:
        valid = validation.initialise_validation(model, opts)

    # -------------- SAVE AND RESTORE --------------

    # Path of the most recently saved (or restored) checkpoint; recorded with
    # every validation point.
    filepath = None
    if opts['ckpts_per_epoch']:
        filepath = train.saver.save(train.session,
                                    opts["checkpoint_path"],
                                    global_step=0)
        print("Saved checkpoint to {}".format(filepath))

    if opts.get('restoring'):
        ckpt_pattern = re.compile(".*ckpt-([0-9]+)$")
        index_suffix = ".index"
        filenames = sorted(
            [
                os.path.join(opts['logs_path'], f[:-len(index_suffix)])
                for f in os.listdir(opts['logs_path'])
                if f.endswith(index_suffix)
                and ckpt_pattern.match(f[:-len(index_suffix)])
            ],
            key=lambda x: int(ckpt_pattern.match(x).groups()[0]))
        if not filenames:
            raise FileNotFoundError(
                "No checkpoints matching '*ckpt-<N>.index' found in {}".format(
                    opts['logs_path']))
        latest_checkpoint = filenames[-1]
        logging.print_to_file_and_screen(
            "Restoring training from latest checkpoint: {}".format(
                latest_checkpoint), opts)
        # Checkpoints in this function are saved with global_step equal to the
        # number of iterations already completed: 0 before training, and
        # i + iterations_per_step after each step. That number is therefore
        # the next iteration to run; adding 1 would skip an iteration.
        i = int(ckpt_pattern.match(latest_checkpoint).groups()[0])
        train.saver.restore(train.session, latest_checkpoint)
        filepath = latest_checkpoint
    else:
        i = 0

    # ------------- TRAINING LOOP ----------------

    print_format = (
        "step: {step:6d}, iteration: {iteration:6d}, epoch: {epoch:6.2f}, lr: {lr:6.4g}, loss: {loss_avg:6.3f}, accuracy: {train_acc_avg:6.3f}%"
        ", img/sec: {img_per_sec:6.2f}, time: {it_time:8.6f}, total_time: {total_time:8.1f}"
    )

    # (iteration, epoch, first_run, checkpoint path) tuples, consumed after
    # training. Created up front so an empty loop still leaves a valid list.
    validation_points = []

    step = 0
    start_all = time.time()
    while i < iterations:
        step += opts["gradients_to_accumulate"]
        next_i = i + iterations_per_step
        # Steps within two step-lengths of the end always log/ckpt/validate.
        near_end = next_i + iterations_per_step >= iterations
        log_this_step = (i // log_freq < next_i // log_freq or i == 0
                         or near_end)
        ckpt_this_step = (i // iterations_per_ckpt
                          < next_i // iterations_per_ckpt
                          or i == 0 or near_end)
        valid_this_step = opts['validation'] and (
            i // iterations_per_valid < next_i // iterations_per_valid
            or i == 0 or near_end)

        # Epoch reached once this step completes. It is computed every step,
        # not only on logging steps, so validation points never record a
        # stale epoch left over from an earlier log.
        epoch = float(opts["total_batch_size"] * next_i) / num_images

        # Run Training
        try:
            batch_loss, batch_acc, batch_time, current_lr, scaled_lr = training_step(
                train, i + 1, LR.feed_dict_lr(i))
        except tf.errors.OpError as e:
            raise tf.errors.ResourceExhaustedError(e.node_def, e.op,
                                                   e.message) from e

        batch_time /= iterations_per_step

        # Calculate Stats
        batch_accs.append([batch_acc])
        batch_losses.append([batch_loss])

        # The very first step is left out of the timing average
        # (presumably because it includes one-off compilation).
        if i != 0:
            batch_times.append([batch_time])

        # Print loss
        if log_this_step:
            train_acc = np.mean(batch_accs)
            train_loss = np.mean(batch_losses)
            avg_batch_time = (np.mean(batch_times)
                              if len(batch_times) != 0 else batch_time)

            # flush times every time it is reported
            batch_times.clear()

            total_time = time.time() - start_all

            stats = OrderedDict([
                ('step', step),
                ('iteration', next_i),
                ('epoch', epoch),
                ('lr', current_lr),
                ('scaled_lr', scaled_lr),
                ('loss_batch', batch_loss),
                ('loss_avg', train_loss),
                ('train_acc_batch', batch_acc),
                ('train_acc_avg', train_acc),
                ('it_time', avg_batch_time),
                ('img_per_sec', opts['total_batch_size'] / avg_batch_time),
                ('total_time', total_time),
            ])

            logging.print_to_file_and_screen(print_format.format(**stats),
                                             opts)
            logging.write_to_csv(stats, i == 0, True, opts)

        if ckpt_this_step:
            filepath = train.saver.save(train.session,
                                        opts["checkpoint_path"],
                                        global_step=next_i)
            print("Saved checkpoint to {}".format(filepath))

        # Eval (deferred until after training)
        if valid_this_step:
            # NOTE(review): on a validation step that is not also a checkpoint
            # step, this records the most recent earlier checkpoint.
            validation_points.append((next_i, epoch, i == 0, filepath))

        i = next_i

    # ------------ RUN VALIDATION ------------
    if opts['validation']:
        for iteration, point_epoch, first_run, ckpt_path in validation_points:
            validation.validation_run(valid, ckpt_path, iteration, point_epoch,
                                      first_run, opts)

    # --------------- CLEANUP ----------------
    train.session.close()
Exemplo n.º 3
0
def inference_run(exec_filename, ckpt_name, iteration, epoch, first_run, opts):
    """Run inference for multiple iterations and collect latency values.

    Starts the embedded runtime for ``exec_filename``. It then runs
    ``opts['iterations']`` inference calls on each of
    ``opts['num_inference_thread']`` threads that share one session.

    Per-call wall-clock latencies are collected, and for real data so are
    per-batch accuracies. These are summarised, printed, written to CSV,
    optionally logged to Weights & Biases, and reported through the mlperf
    logging hooks.

    Args:
        exec_filename: Compiled executable passed to
            ``embedded_runtime.embedded_runtime_start``.
        ckpt_name: Label reported as ``name`` in the statistics.
        iteration: Training iteration reported with the result.
        epoch: Training epoch reported with the result (rounded for mlperf).
        first_run: Passed through to ``logging.write_to_csv``.
        opts: Options dict.

    Returns:
        OrderedDict of the reported statistics. ``val_acc`` is -1 for
        generated data.

    Raises:
        Exception: the first exception raised inside any inference thread,
            re-raised in the calling thread once all threads have joined.
        ValueError: if no inference iterations were run at all.
    """
    logging.mlperf_logging(key="EVAL_START",
                           log_type="start",
                           metadata={"epoch_num": round(epoch)})
    engine_name = "my_engine"
    ctx = embedded_runtime.embedded_runtime_start(exec_filename, [],
                                                  engine_name,
                                                  timeout=1000)

    batch_size = opts['micro_batch_size']
    image_shape = (batch_size, opts['image_size'], opts['image_size'], 3)
    input_placeholder = tf.placeholder(tf.uint8, image_shape)

    num_iters = opts['iterations']
    if opts['generated_data']:
        placeholders = [input_placeholder]
        # Uniform random bytes. The previous normal(...).astype(np.uint8)
        # relied on NumPy's undefined cast of negative floats to uint8.
        images = np.random.randint(0, 256, size=image_shape, dtype=np.uint8)
        labels = None
    else:
        label_placeholder = tf.placeholder(tf.int32, (batch_size))
        placeholders = [input_placeholder, label_placeholder]

        with tf.Graph().as_default():
            inference_dataset = dataset.data(
                opts, is_training=False).map(lambda x: {'data_dict': x})
            images, labels = dataset_to_list(inference_dataset,
                                             num_iters * batch_size)

    call_result = embedded_runtime.embedded_runtime_call(placeholders, ctx)

    ipu.config.reset_ipu_configuration()
    gc.collect()

    thread_queue = Queue()
    # An exception escaping a Thread target is only reported by the threading
    # module, so workers record failures here for the caller to re-raise.
    thread_errors = []
    with tf.Session() as session:
        # do not include time of the first iteration in stats
        initial_feed_dict = prepare_feed_dict(placeholders, images, labels,
                                              batch_size,
                                              opts['generated_data'], 0)
        session.run(call_result, initial_feed_dict)

        def runner(session, thread_idx):
            """Run ``num_iters`` timed inference calls and queue the results."""
            try:
                thread_channel = pvti.createTraceChannel(
                    f"Thread {thread_idx}")
                latencies = []
                accuracies = []
                for iter_idx in range(num_iters):
                    feed_dict = prepare_feed_dict(placeholders, images, labels,
                                                  batch_size,
                                                  opts['generated_data'],
                                                  iter_idx)
                    with pvti.Tracepoint(thread_channel,
                                         f"Iteration {iter_idx}"):
                        start_iter = time.time()
                        _, predictions = session.run(call_result, feed_dict)
                        end_iter = time.time()
                    latencies.append(end_iter - start_iter)
                    if not opts['generated_data']:
                        expected = feed_dict[label_placeholder]
                        accuracy = np.mean(
                            np.equal(predictions, expected).astype(np.float32))
                        accuracies.append(accuracy)
                thread_queue.put((latencies, accuracies), timeout=10)
            except Exception as err:
                thread_errors.append(err)

        thp = [
            Thread(target=runner, args=(session, thread_idx))
            for thread_idx in range(opts['num_inference_thread'])
        ]
        inference_start = time.time()
        for idx, _thread in enumerate(thp):
            _thread.start()
            print(f"Thread {idx} started")

        for idx, _thread in enumerate(thp):
            _thread.join()
            print(f"Thread {idx} joined")
        val_time = time.time() - inference_start

    if thread_errors:
        # Otherwise the statistics below would silently cover only the
        # threads that succeeded.
        raise thread_errors[0]

    latencies, accuracies = [], []
    while not thread_queue.empty():
        thread_latencies, thread_accuracies = thread_queue.get()
        latencies.extend(thread_latencies)
        accuracies.extend(thread_accuracies)

    if not latencies:
        raise ValueError(
            "No inference iterations were run; check opts['iterations'] and "
            "opts['num_inference_thread']")

    if opts['generated_data']:
        total_accuracy = -1
    else:
        total_accuracy = sum(accuracies) / len(accuracies)
        total_accuracy *= 100

    # convert latencies to milliseconds
    latencies = [1000 * latency_s for latency_s in latencies]

    max_latency = max(latencies)
    mean_latency = np.mean(latencies)
    perc_99 = np.percentile(latencies, 99)
    perc_99_9 = np.percentile(latencies, 99.9)

    print(
        f"Latencies - avg: {mean_latency:8.4f}, 99th percentile: {perc_99:8.4f}, "
        f"99.9th percentile: {perc_99_9:8.4f}, max: {max_latency:8.4f}")

    valid_format = (
        "Validation top-1 accuracy [{name}] (iteration: {iteration:6d}, epoch: {epoch:6.2f}, "
        "img/sec: {img_per_sec:6.2f}, time: {val_time:8.6f}, "
        "latency (ms): {latency:8.4f}: {val_acc:6.3f}%")

    val_size = (num_iters * opts['num_inference_thread'] *
                opts['validation_total_batch_size'])

    stats = OrderedDict([
        ('name', ckpt_name),
        ('iteration', iteration),
        ('epoch', epoch),
        ('val_acc', total_accuracy),
        ('val_time', val_time),
        ('val_size', val_size),
        ('img_per_sec', val_size / val_time),
        ('latency', mean_latency),
    ])
    logging.print_to_file_and_screen(valid_format.format(**stats), opts)
    logging.write_to_csv(stats, first_run, False, opts)
    if opts['wandb'] and opts['distributed_worker_index'] == 0:
        logging.log_to_wandb(stats)
    logging.mlperf_logging(key="EVAL_STOP",
                           log_type="stop",
                           metadata={"epoch_num": round(epoch)})
    logging.mlperf_logging(key="EVAL_ACCURACY",
                           value=float(stats['val_acc']) / 100,
                           metadata={"epoch_num": round(epoch)})
    return stats
Exemplo n.º 4
0
def evaluate(opts):
    """Evaluate a pretrained BERT checkpoint and log loss/accuracy/throughput.

    Builds the graph with ``is_training=False`` on the ``"trainfeed"`` feed and
    restores a checkpoint. It then runs ``eval_step`` for ``opts['steps']``
    steps, or ``epochs * steps_per_epoch`` if that is unset. Every
    ``opts['steps_per_logs']`` steps it logs MLM/NSP loss and accuracy plus
    throughput.

    Checkpoint selection: ``opts['restore_path']`` is used when set. A
    directory resolves to its latest checkpoint, as in ``train``. Otherwise the
    path that used to be hard-coded here is used, for backward compatibility.

    NOTE(review): ``bert_config`` is not a parameter; it is read from the
    enclosing (module) scope and must be defined before calling.

    Raises:
        tf.errors.ResourceExhaustedError: raised in place of any
            ``tf.errors.OpError`` from ``eval_step``.
        FileNotFoundError: if ``opts['restore_path']`` is a directory with no
            checkpoint.
    """
    import os

    epochs = opts["epochs"]
    total_samples = dataset.get_dataset_files_count(opts, is_training=True)
    logger.info("[evaluation] Total samples with duplications {}".format(
        total_samples))
    total_independent_samples = total_samples // opts['duplication_factor']
    logger.info("[evaluation] Total samples without duplications {}".format(
        total_independent_samples))
    steps_per_epoch = total_independent_samples // (opts['batches_per_step'] *
                                                    opts["total_batch_size"])
    iterations_per_epoch = total_independent_samples // (
        opts["total_batch_size"])

    # total iterations
    if opts['steps']:
        logger.warning(
            "[evaluation] Ignoring the epoch flag and using the steps one")
        steps = opts['steps']
    else:
        steps = epochs * steps_per_epoch
    logger.info(
        "[evaluation] Total training steps equal to {}, total number of samples being analyzed equal to {}"
        .format(steps,
                steps * opts['batches_per_step'] * opts['total_batch_size']))
    iterations_per_step = opts['batches_per_step']

    logger.info(
        "################################################################################"
    )
    logger.info("Start evaluation......")
    print_format = (
        "[evaluation] step: {step:6d}, iteration: {iteration:6d}, epoch: {epoch:6.3f}, lr: {lr:10.3g}, mlm_loss: {mlm_loss:6.3f}, nsp_loss: {nsp_loss:6.3f}, "
        "samples/sec: {samples_per_sec:6.2f}, time: {iter_time:8.6f}, total_time: {total_time:8.1f}, mlm_acc: {mlm_acc:8.5f}, nsp_acc: {nsp_acc:8.5f}"
    )

    # avoid nan issue caused by queue length is zero.
    queue_len = max(1, iterations_per_epoch // iterations_per_step)
    batch_times = deque(maxlen=queue_len)

    evals = build_graph(bert_config,
                        opts,
                        iterations_per_step,
                        is_training=False,
                        feed_name="trainfeed")
    evals.session.run(evals.init)
    evals.session.run(evals.iterator.initializer)
    evals_saver = evals.saver["train_saver"]

    # Fallback kept only so existing invocations without a restore path
    # behave exactly as before.
    legacy_ckpt_path = (
        "/localdata/yongxiy/Desktop/examples-ipu/applications/tensorflow/bert/checkpoint/phase1/BERT_pretraining_2021-03-15 08:49:29.404/"
        + f'ckpt_last-{100}')
    ckpt_path = opts.get('restore_path') or legacy_ckpt_path
    if os.path.isdir(ckpt_path):
        latest_ckpt = tf.train.latest_checkpoint(ckpt_path)
        if latest_ckpt is None:
            raise FileNotFoundError(
                f"[evaluation] No checkpoint found in directory {ckpt_path}")
        ckpt_path = latest_ckpt
    logger.info(f"[evaluation] Restoring checkpoint: {ckpt_path}")
    evals_saver.restore(evals.session, ckpt_path)

    step = 0
    i = 0
    start_all = time.time()

    while step < steps:

        try:
            batch_time, mlm_loss, nsp_loss, mlm_acc, nsp_acc = eval_step(
                evals, 1.0)
        except tf.errors.OpError as e:
            raise tf.errors.ResourceExhaustedError(e.node_def, e.op,
                                                   e.message) from e

        epoch = float(opts["total_batch_size"] * i) / total_independent_samples

        batch_time /= iterations_per_step

        # Step 0 is left out of the timing average; the step == 1 check below
        # treats the time up to that point as compile time.
        if step != 0:
            batch_times.append([batch_time])

        if step == 1:
            poplar_compile_time = time.time() - start_all
            logger.info(
                f"[evaluation] the poplar compile time {poplar_compile_time}")

        # Print loss
        if step % opts['steps_per_logs'] == 0:
            if len(batch_times) != 0:
                avg_batch_time = np.mean(batch_times)
            else:
                avg_batch_time = batch_time

            samples_per_sec = opts['total_batch_size'] / avg_batch_time

            # flush times every time it is reported
            batch_times.clear()

            total_time = time.time() - start_all

            stats = OrderedDict([
                ('step', step),
                ('iteration', i),
                ('epoch', epoch),
                ('lr', 1.0),
                ('mlm_loss', mlm_loss),
                ('nsp_loss', nsp_loss),
                ('mlm_acc', mlm_acc),
                ('nsp_acc', nsp_acc),
                ('iter_time', avg_batch_time),
                ('samples_per_sec', samples_per_sec),
                ('total_time', total_time),
            ])

            logger.info(print_format.format(**stats))
            bert_logging.write_to_csv(stats, i == 0, True, opts['logs_path'])

            logger.info(
                f"[evaluation] throughput samples per second: {samples_per_sec}"
            )
            logger.info(f"[evaluation] average batch time: {avg_batch_time}")

        i += iterations_per_step
        step += 1

    # --------------- CLEANUP ----------------
    evals.session.close()
Exemplo n.º 5
0
def train(bert_config, opts):
    """Pretrain BERT with checkpointing and TensorBoard logging.

    Runs ``opts['steps']`` training steps (or ``epochs * steps_per_epoch`` if
    unset), each covering ``opts['batches_per_step']`` iterations. It can
    resume from ``opts['restore_path']`` and/or initialise weights from
    ``opts['start_from_ckpt']``.

    Checkpoints are written in three ways, all with ``global_step=step``:

    * periodically, every ``opts['steps_per_ckpts']`` steps (0 disables this);
    * to ``<checkpoint_path>_best`` when MLM+NSP loss improves (gated by
      ``opts['best_ckpt_min_steps']``);
    * once at the end to ``<checkpoint_path>_last``.

    Nothing is trained unless ``opts['do_train']`` is set.

    Args:
        bert_config: Model configuration forwarded to ``build_graph``.
        opts: Options dict.

    Raises:
        tf.errors.ResourceExhaustedError: raised in place of any
            ``tf.errors.OpError`` from a training step.
        FileNotFoundError: if ``opts['restore_path']`` is a directory with no
            checkpoint.
        ValueError: if the restore checkpoint path has no ``-<step>`` suffix.
    """
    # --------------- OPTIONS ---------------------
    epochs = opts["epochs"]
    total_samples = dataset.get_dataset_files_count(opts, is_training=True)
    logger.info("Total samples with duplications {}".format(total_samples))
    total_independent_samples = total_samples // opts['duplication_factor']
    logger.info("Total samples without duplications {}".format(
        total_independent_samples))
    steps_per_epoch = total_independent_samples // (opts['batches_per_step'] *
                                                    opts["total_batch_size"])
    iterations_per_epoch = total_independent_samples // (
        opts["total_batch_size"])

    # total iterations
    if opts['steps']:
        logger.warning("Ignoring the epoch flag and using the steps one")
        steps = opts['steps']
    else:
        steps = epochs * steps_per_epoch
    logger.info(
        "Total training steps equal to {}, total number of samples being analyzed equal to {}"
        .format(steps,
                steps * opts['batches_per_step'] * opts['total_batch_size']))
    iterations_per_step = opts['batches_per_step']
    ckpt_per_step = opts['steps_per_ckpts']

    # avoid nan issue caused by queue length is zero.
    queue_len = max(1, iterations_per_epoch // iterations_per_step)
    batch_times = deque(maxlen=queue_len)

    # learning rate strategy
    lr_schedule_name = opts['lr_schedule']
    logger.info(f"Using learning rate schedule {lr_schedule_name}")
    LR = make_lr_schedule(lr_schedule_name, opts, steps)

    if opts['do_train']:
        # -------------- BUILD TRAINING GRAPH ----------------

        train = build_graph(bert_config,
                            opts,
                            iterations_per_step,
                            is_training=True,
                            feed_name="trainfeed")
        train.session.run(train.init)
        train.session.run(train.iterator.initializer)

        step = 0
        i = 0

        if opts['restore_path'] is not None:
            if os.path.isdir(opts['restore_path']):
                ckpt_file_path = tf.train.latest_checkpoint(
                    opts['restore_path'])
                if ckpt_file_path is None:
                    raise FileNotFoundError(
                        "No checkpoint found in directory {}".format(
                            opts['restore_path']))
                logger.info("Restoring training from latest checkpoint")
            else:
                # Not a directory: assume it is a checkpoint path prefix.
                ckpt_file_path = opts['restore_path']

            logger.info(
                f"Restoring training from checkpoint: {ckpt_file_path}")
            train.restore.restore(train.session, ckpt_file_path)

            # Every checkpoint written by this function uses global_step=step,
            # so the numeric suffix is a step count, not an iteration count.
            # Matching any "-<N>" suffix also accepts the "_last" and "_best"
            # checkpoints saved below. Periodic checkpoints are written after
            # their numbered step has run, so resuming from one repeats that
            # single step.
            ckpt_match = re.search(r"-([0-9]+)$", ckpt_file_path)
            if ckpt_match is None:
                raise ValueError(
                    "Cannot infer the training step from checkpoint path "
                    "{}".format(ckpt_file_path))
            step = int(ckpt_match.group(1))
            i = step * iterations_per_step

        if opts['start_from_ckpt']:
            # We use a checkpoint to initialise our model
            train.restore.restore(train.session, opts['start_from_ckpt'])
            logger.info("Starting the training from the checkpoint {}".format(
                opts['start_from_ckpt']))

        # Initialise Weights & Biases if available
        if opts['wandb']:
            import wandb
            wandb.init(project="tf-bert", sync_tensorboard=True)
            wandb.config.update(opts)

        # Tensorboard logs path
        log_path = os.path.join(opts["logs_path"], 'event')
        logger.info("Tensorboard event file path {}".format(log_path))
        summary_writer = tf.summary.FileWriter(log_path,
                                               train.graph,
                                               session=train.session)

        # ------------- TRAINING LOOP ----------------
        logger.info(
            "################################################################################"
        )
        logger.info("Start training......")
        print_format = (
            "step: {step:6d}, iteration: {iteration:6d}, epoch: {epoch:6.3f}, lr: {lr:10.3g}, mlm_loss: {mlm_loss:6.3f}, nsp_loss: {nsp_loss:6.3f}, "
            "samples/sec: {samples_per_sec:6.2f}, time: {iter_time:8.6f}, total_time: {total_time:8.1f}, mlm_acc: {mlm_acc:8.5f}, nsp_acc: {nsp_acc:8.5f}"
        )

        start_all = time.time()

        train_saver = train.saver["train_saver"]
        best_saver = train.saver["best_saver"]
        # We initialize the best loss to a super large value
        best_total_loss = 1e10
        best_step = 0

        while step < steps:
            # Run Training
            learning_rate = LR.feed_dict_lr(step)
            try:
                batch_time, mlm_loss, nsp_loss, mlm_acc, nsp_acc = training_step(
                    train, learning_rate)
            except tf.errors.OpError as e:
                raise tf.errors.ResourceExhaustedError(e.node_def, e.op,
                                                       e.message) from e

            epoch = float(
                opts["total_batch_size"] * i) / total_independent_samples

            batch_time /= iterations_per_step

            # Step 0 is left out of the timing average; the time up to step 1
            # is reported separately as poplar compile time.
            if step != 0:
                batch_times.append([batch_time])

            if step == 1:
                poplar_compile_time = time.time() - start_all
                poplar_summary = tf.Summary()
                poplar_summary.value.add(tag='poplar/compile_time',
                                         simple_value=poplar_compile_time)
                summary_writer.add_summary(poplar_summary)

            # Print loss
            if step % opts['steps_per_logs'] == 0:
                if len(batch_times) != 0:
                    avg_batch_time = np.mean(batch_times)
                else:
                    avg_batch_time = batch_time

                samples_per_sec = opts['total_batch_size'] / avg_batch_time

                # flush times every time it is reported
                batch_times.clear()

                total_time = time.time() - start_all

                stats = OrderedDict([
                    ('step', step),
                    ('iteration', i),
                    ('epoch', epoch),
                    ('lr', learning_rate),
                    ('mlm_loss', mlm_loss),
                    ('nsp_loss', nsp_loss),
                    ('mlm_acc', mlm_acc),
                    ('nsp_acc', nsp_acc),
                    ('iter_time', avg_batch_time),
                    ('samples_per_sec', samples_per_sec),
                    ('total_time', total_time),
                ])

                logger.info(print_format.format(**stats))
                bert_logging.write_to_csv(stats, i == 0, True,
                                          opts['logs_path'])

                sys_summary = tf.Summary()
                sys_summary.value.add(tag='perf/throughput_samples_per_second',
                                      simple_value=samples_per_sec)
                sys_summary.value.add(tag='perf/average_batch_time',
                                      simple_value=avg_batch_time)
                summary_writer.add_summary(sys_summary, step)

            # Log training statistics
            train_summary = tf.Summary()
            train_summary.value.add(tag='epoch', simple_value=epoch)
            train_summary.value.add(tag='loss/MLM', simple_value=mlm_loss)
            train_summary.value.add(tag='loss/NSP', simple_value=nsp_loss)
            train_summary.value.add(tag='accuracy/MLM', simple_value=mlm_acc)
            train_summary.value.add(tag='accuracy/NSP', simple_value=nsp_acc)
            train_summary.value.add(tag='defaultLearningRate',
                                    simple_value=learning_rate)
            train_summary.value.add(tag='samples',
                                    simple_value=step *
                                    opts['batches_per_step'] *
                                    opts['total_batch_size'])
            summary_writer.add_summary(train_summary, step)
            summary_writer.flush()

            # steps_per_ckpts == 0 disables periodic checkpoints instead of
            # raising ZeroDivisionError.
            if ckpt_per_step and step and step % ckpt_per_step == 0:
                filepath = train_saver.save(train.session,
                                            save_path=opts["checkpoint_path"],
                                            global_step=step)
                logger.info("Saved checkpoint to {}".format(filepath))

                if not opts['wandb']:
                    bert_logging.save_model_statistics(filepath,
                                                       summary_writer, step)

            # Mechanism to checkpoint the best model.
            # set opts["best_ckpt_min_steps"] to 0 to disable
            if best_total_loss > mlm_loss + nsp_loss and step - best_step > opts[
                    "best_ckpt_min_steps"] and opts["best_ckpt_min_steps"]:
                best_total_loss = mlm_loss + nsp_loss
                best_step = step
                filepath = best_saver.save(train.session,
                                           save_path=opts["checkpoint_path"] +
                                           '_best',
                                           global_step=step)
                logger.info("Saved Best checkpoint to {}".format(filepath))

            i += iterations_per_step
            step += 1

        # --------------- LAST CHECKPOINT ----------------
        filepath = train_saver.save(train.session,
                                    save_path=opts["checkpoint_path"] +
                                    '_last',
                                    global_step=step)
        logger.info("Final model saved to {}".format(filepath))

        # --------------- CLEANUP ----------------
        summary_writer.close()
        train.session.close()
Exemplo n.º 6
0
def validation_run(valid, filepath, i, epoch, first_run, opts, latency_thread):
    """Evaluate a checkpoint on the validation set and report accuracy/latency.

    If ``filepath`` is given, the checkpoint is restored into
    ``valid.session``. The run is skipped if ``<logs_path>/validation.csv``
    already has a row whose ``name`` equals the checkpoint's basename.
    Otherwise the function:

    * runs ``valid.ops['accuracy']`` ``opts["validation_iterations"]`` times
      while ``latency_thread`` runs;
    * corrects the averaged accuracy when the evaluated size exceeds the
      dataset's ``NUM_VALIDATION_IMAGES``;
    * logs the result to screen, CSV, optionally Weights & Biases, and mlperf.

    Args:
        valid: Validation graph bundle exposing ``saver``, ``session`` and
            ``ops``.
        filepath: Checkpoint to restore; falsy to use the current weights.
        i: Iteration reported with the result.
        epoch: Epoch reported with the result (rounded for mlperf).
        first_run: Passed through to ``logging.write_to_csv``.
        opts: Options dict.
        latency_thread: Thread-like object; ``start()``, ``join()`` and
            ``get_latency()`` are called on it. ``get_latency()`` is
            multiplied by 1000 for the ``latency`` stat, which is labelled
            "latency (ms)".

    Returns:
        The ``OrderedDict`` of logged statistics, or ``None`` when the run was
        skipped because the checkpoint already appears in ``validation.csv``.

    Raises:
        tf.errors.ResourceExhaustedError: raised in place of a
            ``tf.errors.OpError`` from ``session.run`` (unless running in
            compile-only mode, which exits the process instead).
    """
    run = True
    if filepath:
        valid.saver.restore(valid.session, filepath)
        name = filepath.split('/')[-1]

        csv_path = os.path.join(opts['logs_path'], 'validation.csv')
        if os.path.exists(csv_path):
            # The former 'rU' mode was removed in Python 3.11; the csv module
            # documents newline='' for files handed to csv readers.
            with open(csv_path, newline='') as infile:
                # read the file as a dictionary for each row ({header : value})
                reader = csv.DictReader(infile)
                for row in reader:
                    # .get: tolerate a CSV that has no 'name' column.
                    if row.get('name') == name:
                        run = False
                        print(
                            'Skipping validation run on checkpoint: {}'.format(
                                name))
                        break
    else:
        name = None

    if not run:
        return None

    if opts['use_popdist']:
        # synchronise the model weights across all instances
        valid.session.run(valid.ops['broadcast_weights'])

    logging.mlperf_logging(key="EVAL_START",
                           log_type="start",
                           metadata={"epoch_num": round(epoch)})
    # Gather accuracy statistics
    accuracy = 0.0

    # start latency thread
    latency_thread.start()

    start = relative_timer.now()
    for __ in range(opts["validation_iterations"]):
        try:
            a = valid.session.run(valid.ops['accuracy'])
        except tf.errors.OpError as e:
            if opts['compile_only'] and 'compilation only' in e.message:
                print("Validation graph successfully compiled")
                print("Exiting...")
                sys.exit(0)
            raise tf.errors.ResourceExhaustedError(e.node_def, e.op,
                                                   e.message) from e

        accuracy += a
    val_time = relative_timer.now() - start
    accuracy /= opts["validation_iterations"]

    # wait for all dequeues and latency computation
    latency_thread.join()
    latency = latency_thread.get_latency()

    valid_format = (
        "Validation top-1 accuracy [{name}] (iteration: {iteration:6d}, epoch: {epoch:6.2f}, img/sec: {img_per_sec:6.2f},"
        " time: {val_time:8.6f}, latency (ms): {latency:8.4f}): {val_acc:6.3f}%"
    )

    val_size = (opts["validation_iterations"] *
                opts["validation_batches_per_step"] *
                opts["validation_global_batch_size"])

    count = int(
        DATASET_CONSTANTS[opts['dataset']]['NUM_VALIDATION_IMAGES'])

    # When more samples are evaluated than the dataset contains, rescale the
    # mean by val_size / count. NOTE(review): presumably this corrects for
    # padded samples counted as misses — confirm against the dataset code.
    raw_accuracy = accuracy
    if count < val_size:
        accuracy = accuracy * val_size / count

    stats = OrderedDict([
        ('name', name),
        ('iteration', i),
        ('epoch', epoch),
        ('val_acc', accuracy),
        ('raw_acc', raw_accuracy),
        ('val_time', val_time),
        ('val_size', val_size),
        ('img_per_sec', val_size / val_time),
        ('latency', latency * 1000),
    ])
    logging.print_to_file_and_screen(valid_format.format(**stats), opts)
    logging.write_to_csv(stats, first_run, False, opts)
    if opts["wandb"] and opts["distributed_worker_index"] == 0:
        logging.log_to_wandb(stats)
    logging.mlperf_logging(key="EVAL_STOP",
                           log_type="stop",
                           metadata={"epoch_num": round(epoch)})
    logging.mlperf_logging(key="EVAL_ACCURACY",
                           value=float(stats["val_acc"]) / 100,
                           metadata={"epoch_num": round(epoch)})
    return stats