Example #1
        def loop(
            batch_iter: DataLoader,
            num_batches_to_use: Optional[int] = None,
            is_training: bool = True,
        ) -> mx.metric.Loss:
            nonlocal first_forward, time_elapsed, validation_idx

            tic = time.time()
            subtic = 0.0  # time spent in callbacks, excluded from the time budget

            epoch_loss = mx.metric.Loss()
            batch_iter = itertools.islice(batch_iter, num_batches_to_use)

            it = tqdm(batch_iter, total=num_batches_to_use)
            for batch_no, batch in enumerate(it, start=1):
                # `batch` here is expected to be a dictionary whose fields
                # should correspond 1-to-1 with the network inputs
                # see below how `batch.values()` is fed into the network
                if first_forward:
                    tictic = time.time()
                    first_forward = False
                    _ = net(*batch.values())
                    self.callbacks.on_network_initialization_end(net)
                    subtic += time.time() - tictic

                with mx.autograd.record():  # type: ignore
                    # we set the mode explicitly, since MXNet assumes
                    # predict mode by default and dropout layers are
                    # therefore inactive unless the mode is set to
                    # training
                    mode = (autograd.train_mode
                            if is_training else autograd.predict_mode)
                    with mode():
                        output = net(*batch.values())

                    # The network can return several outputs, the first
                    # of which is always the loss. With multiple outputs,
                    # forward returns a list for hybridized networks and
                    # a tuple otherwise; we may wrap network outputs in
                    # the future to avoid this type check.
                    if isinstance(output, (list, tuple)):
                        loss = output[0]
                    else:
                        loss = output

                    batch_size = loss.shape[0]

                # pylint: disable=no-member
                if not np.isfinite(
                        ndarray.sum(loss).asscalar()):  # type: ignore
                    logger.warning(
                        "Batch [%d] gave NaN loss and it will be ignored",
                        batch_no,
                    )
                else:
                    if is_training:
                        loss.backward()
                        trainer.step(batch_size)
                    epoch_loss.update(None, preds=loss)

                if is_training:
                    total_time_elapsed = (time_elapsed + time.time() - tic -
                                          subtic)

                    orig_lr = trainer.learning_rate
                    tictic = time.time()
                    self.callbacks.on_train_batch_end(net, total_time_elapsed)
                    subtic += time.time() - tictic
                    if trainer.learning_rate != orig_lr:
                        logger.info(
                            "Trainer learning rate set to %f",
                            trainer.learning_rate,
                        )

                lv = _loss_value(epoch_loss)
                it.set_postfix(
                    ordered_dict={
                        ("" if is_training else "validation_") + "avg_epoch_loss":
                        lv,
                    },
                    refresh=False,
                )

                # Check whether we should stop: time budget exhausted
                # or a validation milestone reached
                if is_training:
                    if total_time_elapsed > self.training_time:  # type: ignore
                        time_elapsed = total_time_elapsed  # type: ignore
                        break
                    if len(self.validation_milestones) > validation_idx and (
                            total_time_elapsed  # type: ignore
                            > self.validation_milestones[validation_idx]):
                        time_elapsed = total_time_elapsed  # type: ignore
                        validation_idx += 1
                        break
                # If validating, call the callback with the loss
                else:
                    self.callbacks.on_validation_epoch_end(lv)

            # mark epoch end time and log time cost of current epoch
            toc = time.time()
            logger.info("Elapsed time %.3f seconds", toc - tic)
            logger.info(
                "Evaluation metric '%s'=%f",
                ("" if is_training else "validation_") + "epoch_loss",
                lv,  # type: ignore
            )

            return epoch_loss
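
The progress bar above reads its value from a `_loss_value` helper that is not shown in this excerpt. A minimal sketch, assuming the helper simply unwraps the running mean from the `mx.metric.Loss` accumulator (whose `get()` returns a `(name, value)` pair):

    import mxnet as mx

    def _loss_value(loss_metric: mx.metric.Loss) -> float:
        # mx.metric.Loss.get() returns a (name, value) pair; the value is
        # the running mean over everything passed to update() so far.
        return loss_metric.get()[1]
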
Example #2
            def loop(  # TODO: rename to `run_epoch`
                epoch_no,
                batch_iter,
                num_batches_to_use: Optional[int] = None,
                is_training: bool = True,
            ) -> mx.metric.Loss:
                nonlocal first_forward
                tic = time.time()

                epoch_loss = mx.metric.Loss()

                if is_training:
                    # We should not call this method if we haven't compiled the
                    # network yet. Instead, this callback is called after
                    # network initialization.
                    if not first_forward:
                        self.callbacks.on_train_epoch_start(
                            training_network=net)
                else:
                    self.callbacks.on_validation_epoch_start(
                        training_network=net)

                batch_iter = itertools.islice(batch_iter, num_batches_to_use)

                it = tqdm(batch_iter, total=num_batches_to_use)
                for batch_no, batch in enumerate(it, start=1):
                    # `batch` here is expected to be a dictionary whose fields
                    # should correspond 1-to-1 with the network inputs
                    # see below how `batch.values()` is fed into the network
                    if self.halt:
                        break

                    if first_forward:
                        first_forward = False
                        _ = net(*batch.values())

                        self.callbacks.on_network_initializing_end(
                            training_network=net)

                        # Call the epoch start callback now, since the
                        # model had not yet been initialized at epoch start
                        self.callbacks.on_train_epoch_start(
                            training_network=net)

                    with mx.autograd.record():
                        # we set the mode explicitly, since MXNet assumes
                        # predict mode by default and dropout layers are
                        # therefore inactive unless the mode is set to
                        # training
                        mode = (autograd.train_mode
                                if is_training else autograd.predict_mode)
                        with mode():
                            output = net(*batch.values())

                        # The network can return several outputs, the first
                        # of which is always the loss. With multiple outputs,
                        # forward returns a list for hybridized networks and
                        # a tuple otherwise; we may wrap network outputs in
                        # the future to avoid this type check.
                        if isinstance(output, (list, tuple)):
                            loss = output[0]
                        else:
                            loss = output

                        batch_size = loss.shape[0]

                    if not np.isfinite(ndarray.sum(loss).asscalar()):
                        logger.warning(
                            "Batch [%d] of Epoch[%d] gave NaN loss and it will be ignored",
                            batch_no,
                            epoch_no,
                        )
                    else:
                        if is_training:
                            loss.backward()
                            trainer.step(batch_size)

                            self.callbacks.on_train_batch_end(
                                training_network=net)
                        else:
                            self.callbacks.on_validation_batch_end(
                                training_network=net)

                        epoch_loss.update(None, preds=loss)

                    lv = loss_value(epoch_loss)
                    it.set_postfix(
                        ordered_dict={
                            "epoch":
                            f"{epoch_no + 1}/{self.epochs}",
                            ("" if is_training else "validation_") + "avg_epoch_loss":
                            lv,
                        },
                        refresh=False,
                    )
                    # log the number of network parameters on the first pass
                    if batch_no == 1 and epoch_no == 0:
                        net_name = type(net).__name__
                        num_model_param = self.count_model_params(net)
                        logger.info(
                            f"Number of parameters in {net_name}: {num_model_param}"
                        )
                it.close()

                # mark epoch end time and log time cost of current epoch
                toc = time.time()
                logger.info(
                    "Epoch[%d] Elapsed time %.3f seconds",
                    epoch_no,
                    (toc - tic),
                )

                logger.info(
                    "Epoch[%d] Evaluation metric '%s'=%f",
                    epoch_no,
                    ("" if is_training else "validation_") + "epoch_loss",
                    lv,
                )

                return epoch_loss
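
This variant also logs the model size on the first pass via `self.count_model_params`, which is likewise not shown. A plausible implementation for a Gluon network, summing the element counts of all collected parameters (the signature is an assumption, not necessarily the library's):

    import numpy as np
    import mxnet as mx

    def count_model_params(net: mx.gluon.Block) -> int:
        # Sum the number of elements across all parameters of the network.
        params = net.collect_params()
        return int(sum(np.prod(p.shape) for p in params.values()))
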
Example #3
                def loop(epoch_no,
                         batch_iter,
                         is_training: bool = True) -> mx.metric.Loss:
                    nonlocal first_forward
                    tic = time.time()

                    epoch_loss = mx.metric.Loss()

                    # use averaged model for validation
                    if not is_training and isinstance(
                            self.avg_strategy, IterationAveragingStrategy):
                        self.avg_strategy.load_averaged_model(net)

                    with tqdm(batch_iter) as it:
                        for batch_no, data_entry in enumerate(it, start=1):
                            if self.halt:
                                break

                            inputs = [data_entry[k] for k in input_names]

                            if first_forward:
                                first_forward = False
                                _ = net(*inputs)
                                if self.post_initialize_cb:
                                    self.post_initialize_cb(net)

                            with mx.autograd.record():
                                output = net(*inputs)

                                # The network can return several outputs, the first of which is
                                # always the loss. With multiple outputs, forward returns a list
                                # for hybridized networks and a tuple otherwise; we may wrap
                                # network outputs in the future to avoid this type check.
                                if isinstance(output, (list, tuple)):
                                    loss = output[0]
                                else:
                                    loss = output

                            if not np.isfinite(ndarray.sum(loss).asscalar()):
                                logger.warning(
                                    "Batch [%d] of Epoch[%d] gave NaN loss and it will be ignored",
                                    batch_no,
                                    epoch_no,
                                )
                            else:
                                if is_training:
                                    loss.backward()
                                    # `batch_size` comes from the enclosing
                                    # scope in this variant
                                    trainer.step(batch_size)

                                    # iteration averaging in training
                                    if isinstance(
                                            self.avg_strategy,
                                            IterationAveragingStrategy,
                                    ):
                                        self.avg_strategy.apply(net)

                                epoch_loss.update(None, preds=loss)

                            lv = loss_value(epoch_loss)
                            it.set_postfix(
                                ordered_dict={
                                    "epoch":
                                    f"{epoch_no + 1}/{self.epochs}",
                                    ("" if is_training else "validation_") + "avg_epoch_loss":
                                    lv,
                                },
                                refresh=False,
                            )
                            # log the number of network parameters on the first pass
                            if batch_no == 1 and epoch_no == 0:
                                net_name = type(net).__name__
                                num_model_param = self.count_model_params(net)
                                logger.info(
                                    f"Number of parameters in {net_name}: {num_model_param}"
                                )
                    # mark epoch end time and log time cost of current epoch
                    toc = time.time()
                    logger.info(
                        "Epoch[%d] Elapsed time %.3f seconds",
                        epoch_no,
                        (toc - tic),
                    )

                    logger.info(
                        "Epoch[%d] Evaluation metric '%s'=%f",
                        epoch_no,
                        ("" if is_training else "validation_") + "epoch_loss",
                        lv,
                    )

                    if not is_training and isinstance(
                            self.avg_strategy, IterationAveragingStrategy):
                        # bring back the cached model
                        self.avg_strategy.load_cached_model(net)

                    return epoch_loss
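
The three averaging hooks used here (`apply`, `load_averaged_model`, `load_cached_model`) belong to GluonTS's `IterationAveragingStrategy`. Below is a simplified, illustrative stand-in for that contract: it keeps a running average of the weights during training and swaps it in for validation. The class name and internals are assumptions for illustration, not the library's implementation.

    import mxnet as mx

    class SimpleIterationAveraging:
        """Illustrative stand-in for IterationAveragingStrategy."""

        def __init__(self) -> None:
            self.averaged = None  # running average of the weights
            self.cached = None    # training weights saved during validation
            self.n = 0

        def apply(self, net: mx.gluon.Block) -> None:
            # Update the running average after each optimizer step.
            self.n += 1
            current = {
                k: p.data().copy() for k, p in net.collect_params().items()
            }
            if self.averaged is None:
                self.averaged = current
            else:
                for k in self.averaged:
                    self.averaged[k] += (current[k] - self.averaged[k]) / self.n

        def load_averaged_model(self, net: mx.gluon.Block) -> None:
            # Cache the current training weights, then load the averaged ones.
            self.cached = {
                k: p.data().copy() for k, p in net.collect_params().items()
            }
            for k, p in net.collect_params().items():
                p.set_data(self.averaged[k])

        def load_cached_model(self, net: mx.gluon.Block) -> None:
            # Restore the training weights cached before validation started.
            for k, p in net.collect_params().items():
                p.set_data(self.cached[k])
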
Example #4
                def loop(
                    epoch_no,
                    batch_iter,
                    num_batches_to_use: Optional[int] = None,
                    is_training: bool = True,
                ) -> mx.metric.Loss:
                    nonlocal first_forward
                    tic = time.time()

                    epoch_loss = mx.metric.Loss()

                    # use averaged model for validation
                    if not is_training and isinstance(
                        self.avg_strategy, IterationAveragingStrategy
                    ):
                        self.avg_strategy.load_averaged_model(net)

                    batch_iter = itertools.islice(
                        batch_iter, num_batches_to_use
                    )

                    with tqdm(batch_iter, total=num_batches_to_use) as it:
                        for batch_no, batch in enumerate(it, start=1):
                            # `batch` here is expected to be a dictionary whose fields
                            # should correspond 1-to-1 with the network inputs
                            # see below how `batch.values()` is fed into the network

                            if first_forward:
                                first_forward = False
                                _ = net(*batch.values())
                                if self.post_initialize_cb:
                                    self.post_initialize_cb(net)

                            with mx.autograd.record():
                                # we set the mode explicitly, since MXNet assumes predict mode
                                # by default and dropout layers are therefore inactive unless
                                # the mode is set to training
                                mode = (
                                    autograd.train_mode
                                    if is_training
                                    else autograd.predict_mode
                                )
                                with mode():
                                    output = net(*batch.values())

                                # The network can return several outputs, the first of which is
                                # always the loss. With multiple outputs, forward returns a list
                                # for hybridized networks and a tuple otherwise; we may wrap
                                # network outputs in the future to avoid this type check.
                                if isinstance(output, (list, tuple)):
                                    loss = output[0]
                                else:
                                    loss = output

                                batch_size = loss.shape[0]

                            if not np.isfinite(ndarray.sum(loss).asscalar()):
                                logger.warning(
                                    "Batch [%d] of Epoch[%d] gave NaN loss and it will be ignored",
                                    batch_no,
                                    epoch_no,
                                )
                            else:
                                if is_training:
                                    loss.backward()
                                    trainer.step(batch_size)

                                    # iteration averaging in training
                                    if isinstance(
                                        self.avg_strategy,
                                        IterationAveragingStrategy,
                                    ):
                                        self.avg_strategy.apply(net)

                                epoch_loss.update(None, preds=loss)

                            lv = loss_value(epoch_loss)
                            it.set_postfix(
                                ordered_dict={
                                    "epoch": f"{epoch_no + 1}/{self.epochs}",
                                    ("" if is_training else "validation_")
                                    + "avg_epoch_loss": lv,
                                },
                                refresh=False,
                            )
                            # log the number of network parameters on the first pass
                            if batch_no == 1 and epoch_no == 0:
                                net_name = type(net).__name__
                                num_model_param = self.count_model_params(net)
                                logger.info(
                                    f"Number of parameters in {net_name}: {num_model_param}"
                                )
                    # mark epoch end time and log time cost of current epoch
                    toc = time.time()
                    logger.info(
                        "Epoch[%d] Elapsed time %.3f seconds",
                        epoch_no,
                        (toc - tic),
                    )

                    logger.info(
                        "Epoch[%d] Evaluation metric '%s'=%f",
                        epoch_no,
                        ("" if is_training else "validation_") + "epoch_loss",
                        lv,
                    )

                    if not is_training and isinstance(
                        self.avg_strategy, IterationAveragingStrategy
                    ):
                        # bring back the cached model
                        self.avg_strategy.load_cached_model(net)

                    return epoch_loss
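
For context, the enclosing trainer would invoke `loop` once per epoch on the training iterator and, when a validation set is given, once more with `is_training=False`. A sketch of such a driver, assuming `train_iter`/`val_iter` yield the dict-of-arrays batches described above and that `self.epochs` and a `self.num_batches_per_epoch` attribute exist as in the examples:

    for epoch_no in range(self.epochs):
        epoch_loss = loop(
            epoch_no,
            train_iter,
            num_batches_to_use=self.num_batches_per_epoch,
        )
        if val_iter is not None:
            epoch_loss = loop(epoch_no, val_iter, is_training=False)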