Example #1
def _get_val_loader(train_config):
    """
    Returns the validation data loader and the raw validation input planes (x_val).
    """
    _, x_val, y_val_value, y_val_policy, plys_to_end, _ = load_pgn_dataset(dataset_type="val",
                                                                           part_id=0,
                                                                           normalize=train_config.normalize,
                                                                           verbose=False,
                                                                           q_value_ratio=train_config.q_value_ratio)
    y_val_policy = prepare_policy(y_val_policy, train_config.select_policy_from_plane,
                                  train_config.sparse_policy_label, train_config.is_policy_from_plane_data)
    if train_config.framework == 'gluon':
        if train_config.use_wdl and train_config.use_plys_to_end:
            val_dataset = gluon.data.ArrayDataset(nd.array(x_val), nd.array(y_val_value), nd.array(y_val_policy),
                                                  nd.array(value_to_wdl_label(y_val_value)),
                                                  nd.array(prepare_plys_label(plys_to_end)))
        else:
            val_dataset = gluon.data.ArrayDataset(nd.array(x_val), nd.array(y_val_value), nd.array(y_val_policy))
        val_data = gluon.data.DataLoader(val_dataset, train_config.batch_size, shuffle=False,
                                         num_workers=train_config.cpu_count)
    elif train_config.framework == 'pytorch':
        if train_config.use_wdl and train_config.use_plys_to_end:
            val_dataset = TensorDataset(torch.Tensor(x_val), torch.Tensor(y_val_value),
                                        torch.Tensor(y_val_policy),
                                        torch.Tensor(value_to_wdl_label(y_val_value)),
                                        torch.Tensor(prepare_plys_label(plys_to_end)))
        else:
            val_dataset = TensorDataset(torch.Tensor(x_val), torch.Tensor(y_val_value),
                                        torch.Tensor(y_val_policy))
        # the validation set is evaluated in full, so shuffling is not needed
        val_data = DataLoader(val_dataset, shuffle=False, batch_size=train_config.batch_size,
                              num_workers=train_config.cpu_count)
    return val_data, x_val
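
The helper above only wraps the NumPy arrays from load_pgn_dataset into a framework-specific dataset and loader. A minimal, self-contained sketch of the PyTorch branch with dummy data (the array shapes are hypothetical stand-ins for the real input planes and policy vectors):

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

# dummy stand-ins for the arrays normally returned by load_pgn_dataset()
x_val = np.random.rand(256, 34, 8, 8).astype(np.float32)        # input planes (hypothetical shape)
y_val_value = np.random.uniform(-1, 1, 256).astype(np.float32)  # game outcome targets in [-1, 1]
y_val_policy = np.random.rand(256, 2272).astype(np.float32)     # policy targets (hypothetical size)

val_dataset = TensorDataset(torch.Tensor(x_val),
                            torch.Tensor(y_val_value),
                            torch.Tensor(y_val_policy))
# the validation set is evaluated in full, so shuffling is unnecessary
val_data = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=0)

for data, value_label, policy_label in val_data:
    print(data.shape, value_label.shape, policy_label.shape)
    break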
Example #2
    def _get_train_loader(self, part_id):
        # load one chunk of the dataset into memory
        _, self.x_train, self.yv_train, self.yp_train, self.plys_to_end, _ = load_pgn_dataset(
            dataset_type="train",
            part_id=part_id,
            normalize=self.tc.normalize,
            verbose=False,
            q_value_ratio=self.tc.q_value_ratio)
        self.yp_train = prepare_policy(
            y_policy=self.yp_train,
            select_policy_from_plane=self.tc.select_policy_from_plane,
            sparse_policy_label=self.tc.sparse_policy_label,
            is_policy_from_plane_data=self.tc.is_policy_from_plane_data)

        # update the train_data object
        if self.tc.use_wdl and self.tc.use_plys_to_end:
            train_dataset = TensorDataset(
                torch.Tensor(self.x_train), torch.Tensor(self.yv_train),
                torch.Tensor(self.yp_train),
                torch.Tensor(value_to_wdl_label(self.yv_train)),
                torch.Tensor(prepare_plys_label(self.plys_to_end)))
        else:
            train_dataset = TensorDataset(torch.Tensor(self.x_train),
                                          torch.Tensor(self.yv_train),
                                          torch.Tensor(self.yp_train))
        train_loader = DataLoader(train_dataset,
                                  shuffle=True,
                                  batch_size=self.tc.batch_size,
                                  num_workers=self.tc.cpu_count)
        return train_loader
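
value_to_wdl_label and prepare_plys_label are project helpers whose definitions are not shown here. A plausible minimal sketch of the WDL conversion, assuming scalar value targets in [-1, 1] are rounded to the three classes loss/draw/win (an illustrative guess, not the project's actual implementation):

import numpy as np

def value_to_wdl_label_sketch(y_value: np.ndarray) -> np.ndarray:
    # ASSUMPTION: -1 -> 0 (loss), 0 -> 1 (draw), +1 -> 2 (win)
    return (np.round(y_value) + 1).astype(np.int64)

print(value_to_wdl_label_sketch(np.array([-1.0, -0.2, 0.0, 0.9])))  # [0 1 1 2]

prepare_plys_label would analogously turn the raw plys_to_end counts into a regression target, e.g. by clipping and normalizing them.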
Example #3
def update_network(queue, nn_update_idx, k_steps_initial, max_lr,
                   symbol_filename, params_filename, cwd, convert_to_onnx):
    """
    Updates the neural network with the newly acquired games from the replay memory and creates a new NN checkpoint
     in the model contender directory after training, using the game files stored in the training directory.
    :param queue: Queue object used to return items
    :param nn_update_idx: Defines how many updates of the NN have already been done. This index should be incremented
    after every update.
    :param k_steps_initial: Initial amount of steps of the NN update
    :param max_lr: Maximum learning rate used for the learning rate schedule
    :param symbol_filename: Architecture definition file
    :param params_filename: Weight file which will be loaded before training
    :param cwd: Current working directory (must end with "/")
    :param convert_to_onnx: Boolean indicating if the network shall be exported to ONNX to allow TensorRT inference
    :return: k_steps_final
    """

    # set the context on CPU, switch to GPU if there is one available (strongly recommended for training)
    ctx = mx.gpu(train_config["device_id"]) if train_config["context"] == "gpu" else mx.cpu()
    # count how many .zip part files are available for training
    nb_parts = len(glob.glob(main_config["planes_train_dir"] + '**/*.zip'))
    logging.info("number parts: %d" % nb_parts)

    if nb_parts <= 0:
        raise Exception(
            'No .zip files for training available. Check the path in main_config["planes_train_dir"]:'
            ' %s' % main_config["planes_train_dir"])

    _, x_val, y_val_value, y_val_policy, _, _ = load_pgn_dataset(
        dataset_type="val",
        part_id=0,
        normalize=train_config["normalize"],
        verbose=False,
        q_value_ratio=train_config["q_value_ratio"])

    y_val_policy = prepare_policy(y_val_policy,
                                  train_config["select_policy_from_plane"],
                                  train_config["sparse_policy_label"])

    symbol = mx.sym.load(symbol_filename)
    if not train_config["sparse_policy_label"]:
        symbol = add_non_sparse_cross_entropy(
            symbol, train_config["val_loss_factor"],
            train_config["value_output"] + "_output",
            train_config["policy_output"] + "_output")

    # calculate how many iterations per epoch exist
    nb_it_per_epoch = (len(x_val) * nb_parts) // train_config["batch_size"]
    # one iteration is defined by passing 1 batch and doing backprop
    total_it = int(nb_it_per_epoch * train_config["nb_epochs"])

    lr_schedule = CosineAnnealingSchedule(train_config["min_lr"], max_lr,
                                          max(total_it * .7, 1))
    lr_schedule = LinearWarmUp(lr_schedule,
                               start_lr=train_config["min_lr"],
                               length=max(total_it * .25, 1))
    momentum_schedule = MomentumSchedule(lr_schedule, train_config["min_lr"],
                                         max_lr, train_config["min_momentum"],
                                         train_config["max_momentum"])

    if train_config["select_policy_from_plane"]:
        val_iter = mx.io.NDArrayIter({'data': x_val}, {
            'value_label': y_val_value,
            'policy_label': y_val_policy
        }, train_config["batch_size"])
    else:
        val_iter = mx.io.NDArrayIter({'data': x_val}, {
            'value_label': y_val_value,
            'policy_label': y_val_policy
        }, train_config["batch_size"])

    input_shape = x_val[0].shape
    model = mx.mod.Module(symbol=symbol,
                          context=ctx,
                          label_names=['value_label', 'policy_label'])
    # mx.viz.print_summary(
    #     symbol,
    #     shape={'data': (1, input_shape[0], input_shape[1], input_shape[2])},
    # )
    model.bind(for_training=True,
               data_shapes=[('data',
                             (train_config["batch_size"], input_shape[0],
                              input_shape[1], input_shape[2]))],
               label_shapes=val_iter.provide_label)
    model.load_params(params_filename)

    metrics = [
        mx.metric.MSE(name='value_loss',
                      output_names=['value_output'],
                      label_names=['value_label']),
        mx.metric.create(acc_sign,
                         name='value_acc_sign',
                         output_names=['value_output'],
                         label_names=['value_label']),
    ]

    if train_config["sparse_policy_label"]:
        print("train with sparse labels")
        # the default cross entropy only supports sparse labels
        metrics.append(
            mx.metric.Accuracy(axis=1,
                               name='policy_acc',
                               output_names=['policy_output'],
                               label_names=['policy_label']))
        metrics.append(
            mx.metric.CrossEntropy(name='policy_loss',
                                   output_names=['policy_output'],
                                   label_names=['policy_label']))
    else:
        metrics.append(
            mx.metric.create(acc_distribution,
                             name='policy_acc',
                             output_names=['policy_output'],
                             label_names=['policy_label']))
        metrics.append(
            mx.metric.create(cross_entropy,
                             name='policy_loss',
                             output_names=['policy_output'],
                             label_names=['policy_label']))

    logging.info("Performance pre training")
    logging.info(model.score(val_iter, metrics))

    train_agent = TrainerAgentMXNET(
        model,
        symbol,
        val_iter,
        nb_parts,
        lr_schedule,
        momentum_schedule,
        total_it,
        train_config["optimizer_name"],
        wd=train_config["wd"],
        batch_steps=train_config["batch_steps"],
        k_steps_initial=k_steps_initial,
        cpu_count=train_config["cpu_count"],
        batch_size=train_config["batch_size"],
        normalize=train_config["normalize"],
        export_weights=train_config["export_weights"],
        export_grad_histograms=train_config["export_grad_histograms"],
        log_metrics_to_tensorboard=train_config["log_metrics_to_tensorboard"],
        ctx=ctx,
        metrics=metrics,
        use_spike_recovery=train_config["use_spike_recovery"],
        max_spikes=train_config["max_spikes"],
        spike_thresh=train_config["spike_thresh"],
        seed=None,
        val_loss_factor=train_config["val_loss_factor"],
        policy_loss_factor=train_config["policy_loss_factor"],
        select_policy_from_plane=train_config["select_policy_from_plane"],
        discount=train_config["discount"],
        sparse_policy_label=train_config["sparse_policy_label"],
        q_value_ratio=train_config["q_value_ratio"],
        cwd=cwd)
    # iteration counter used for the momentum and learning rate schedule
    cur_it = train_config["k_steps_initial"] * train_config["batch_steps"]
    (k_steps_final, val_value_loss_final, val_policy_loss_final,
     val_value_acc_sign_final,
     val_policy_acc_final), _ = train_agent.train(cur_it)

    if not train_config["sparse_policy_label"]:
        symbol = remove_no_sparse_cross_entropy(
            symbol, train_config["val_loss_factor"],
            train_config["value_output"] + "_output",
            train_config["policy_output"] + "_output")
    prefix = cwd + "model_contender/model-%.5f-%.5f-%.3f-%.3f" % (
        val_value_loss_final, val_policy_loss_final, val_value_acc_sign_final,
        val_policy_acc_final)

    sym_file = prefix + "-symbol.json"
    params_file = prefix + "-" + "%04d.params" % nn_update_idx
    symbol.save(sym_file)
    model.save_params(params_file)

    if convert_to_onnx:
        convert_mxnet_model_to_onnx(sym_file, params_file,
                                    ["value_out_output", "policy_out_output"],
                                    input_shape, [1, 8, 16], False)

    logging.info("k_steps_final %d" % k_steps_final)
    queue.put(k_steps_final)
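
update_network reports its result through the queue argument instead of a plain return value, which suggests it is meant to run in a separate process; a common reason for this pattern is that framework and GPU resources are reliably released when the worker exits. A minimal, self-contained sketch of that pattern (the worker body is a placeholder for the actual training):

import multiprocessing as mp

def update_network_stub(queue, nn_update_idx):
    k_steps_final = 42  # placeholder for the value produced by training
    queue.put(k_steps_final)

if __name__ == "__main__":
    queue = mp.Queue()
    proc = mp.Process(target=update_network_stub, args=(queue, 0))
    proc.start()
    k_steps_final = queue.get()  # blocks until the worker puts its result
    proc.join()
    print("k_steps_final:", k_steps_final)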
Example #4
def update_network(queue, nn_update_idx, symbol_filename, params_filename,
                   convert_to_onnx, main_config, train_config: TrainConfig,
                   model_contender_dir):
    """
    Updates the neural network with the newly acquired games from the replay memory and creates a new NN checkpoint
     in the model contender directory after training, using the game files stored in the training directory.
    :param queue: Queue object used to return items
    :param nn_update_idx: Defines how many updates of the NN have already been done. This index should be incremented
    after every update.
    :param symbol_filename: Architecture definition file
    :param params_filename: Weight file which will be loaded before training
    :param convert_to_onnx: Boolean indicating if the network shall be exported to ONNX to allow TensorRT inference
    :param main_config: Dict of the main config (imported from main_config.py)
    :param train_config: TrainConfig instance (imported from train_config.py)
    :param model_contender_dir: String of the contender directory path
    :return: k_steps_final
    """

    # set the context on CPU, switch to GPU if there is one available (strongly recommended for training)
    ctx = mx.gpu(
        train_config.device_id) if train_config.context == "gpu" else mx.cpu()
    # count how many .zip part files are available for training
    train_config.nb_parts = len(
        glob.glob(main_config["planes_train_dir"] + '**/*.zip'))
    logging.info("number parts for training: %d" % train_config.nb_parts)
    train_objects = TrainObjects()

    if train_config.nb_parts <= 0:
        raise Exception(
            'No .zip files for training available. Check the path in main_config["planes_train_dir"]:'
            ' %s' % main_config["planes_train_dir"])

    _, x_val, y_val_value, y_val_policy, _, _ = load_pgn_dataset(
        dataset_type="val",
        part_id=0,
        normalize=train_config.normalize,
        verbose=False,
        q_value_ratio=train_config.q_value_ratio)
    y_val_policy = prepare_policy(y_val_policy,
                                  train_config.select_policy_from_plane,
                                  train_config.sparse_policy_label,
                                  train_config.is_policy_from_plane_data)
    val_dataset = gluon.data.ArrayDataset(nd.array(x_val),
                                          nd.array(y_val_value),
                                          nd.array(y_val_policy))
    val_data = gluon.data.DataLoader(val_dataset,
                                     train_config.batch_size,
                                     shuffle=False,
                                     num_workers=train_config.cpu_count)

    symbol = mx.sym.load(symbol_filename)

    # calculate how many iterations per epoch exist
    nb_it_per_epoch = (len(x_val) *
                       train_config.nb_parts) // train_config.batch_size
    # one iteration is defined by passing 1 batch and doing backprop
    train_config.total_it = int(nb_it_per_epoch *
                                train_config.nb_training_epochs)

    train_objects.lr_schedule = CosineAnnealingSchedule(
        train_config.min_lr, train_config.max_lr,
        max(train_config.total_it * .7, 1))
    train_objects.lr_schedule = LinearWarmUp(train_objects.lr_schedule,
                                             start_lr=train_config.min_lr,
                                             length=max(
                                                 train_config.total_it * .25,
                                                 1))
    train_objects.momentum_schedule = MomentumSchedule(
        train_objects.lr_schedule, train_config.min_lr, train_config.max_lr,
        train_config.min_momentum, train_config.max_momentum)

    input_shape = x_val[0].shape
    inputs = mx.sym.var('data', dtype='float32')
    value_out = symbol.get_internals()[main_config['value_output'] + '_output']
    policy_out = symbol.get_internals()[main_config['policy_output'] +
                                        '_output']
    sym = mx.symbol.Group([value_out, policy_out])
    net = mx.gluon.SymbolBlock(sym, inputs)
    net.collect_params().load(params_filename, ctx)

    metrics_gluon = {
        'value_loss':
        metric.MSE(name='value_loss', output_names=['value_output']),
        'value_acc_sign':
        metric.create(acc_sign,
                      name='value_acc_sign',
                      output_names=['value_output'],
                      label_names=['value_label']),
    }

    if train_config.sparse_policy_label:
        print("train with sparse labels")
        # the default cross entropy only supports sparse labels
        # note: no trailing comma here, otherwise the metric is wrapped in a tuple
        metrics_gluon['policy_loss'] = metric.CrossEntropy(
            name='policy_loss',
            output_names=['policy_output'],
            label_names=['policy_label'])
        metrics_gluon['policy_acc'] = metric.Accuracy(
            axis=1,
            name='policy_acc',
            output_names=['policy_output'],
            label_names=['policy_label'])
    else:
        metrics_gluon['policy_loss'] = metric.create(
            cross_entropy,
            name='policy_loss',
            output_names=['policy_output'],
            label_names=['policy_label'])
        metrics_gluon['policy_acc'] = metric.create(
            acc_distribution,
            name='policy_acc',
            output_names=['policy_output'],
            label_names=['policy_label'])

    train_objects.metrics = metrics_gluon

    train_config.export_weights = False  # don't save intermediate weights
    train_agent = TrainerAgent(net,
                               val_data,
                               train_config,
                               train_objects,
                               use_rtpt=False)

    # iteration counter used for the momentum and learning rate schedule
    cur_it = train_config.k_steps_initial * train_config.batch_steps
    (k_steps_final, val_value_loss_final, val_policy_loss_final,
     val_value_acc_sign_final,
     val_policy_acc_final), _ = train_agent.train(cur_it)

    prefix = "%smodel-%.5f-%.5f-%.3f-%.3f" % (
        model_contender_dir, val_value_loss_final, val_policy_loss_final,
        val_value_acc_sign_final, val_policy_acc_final)

    sym_file = prefix + "-symbol.json"
    params_file = prefix + "-" + "%04d.params" % nn_update_idx

    # the export function saves both the architecture and the weights
    net.export(prefix, epoch=nn_update_idx)
    print()
    logging.info("Saved checkpoint to %s-%04d.params", prefix, nn_update_idx)

    if convert_to_onnx:
        convert_mxnet_model_to_onnx(sym_file, params_file,
                                    ["value_out_output", "policy_out_output"],
                                    input_shape, [1, 8, 16], False)

    logging.info("k_steps_final %d" % k_steps_final)
    queue.put(k_steps_final)
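
Both update_network variants build the learning-rate schedule the same way: cosine annealing over 70% of the total iterations, wrapped in a linear warm-up over the first 25%. A self-contained sketch of that composition (plain functions standing in for the project's CosineAnnealingSchedule and LinearWarmUp classes):

import math

def cosine_annealing(t, cycle_length, min_lr, max_lr):
    # cosine decay from max_lr down to min_lr over cycle_length iterations
    t = min(t, cycle_length)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t / cycle_length))

def linear_warm_up(t, length, start_lr, schedule):
    # ramp linearly from start_lr to the wrapped schedule's first value, then defer to it
    if t < length:
        return start_lr + (schedule(0) - start_lr) * t / length
    return schedule(t - length)

total_it, min_lr, max_lr = 10000, 1e-5, 0.1
base = lambda t: cosine_annealing(t, max(total_it * 0.7, 1), min_lr, max_lr)
lr = lambda t: linear_warm_up(t, max(total_it * 0.25, 1), min_lr, base)

for t in (0, 1250, 2500, 6000, 10000):
    print(t, round(lr(t), 6))

MomentumSchedule then appears to derive the momentum from the same curve, moving it between min_momentum and max_momentum inversely to the learning rate, as in one-cycle training.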
Example #5
    def train(self, cur_it=None):  # Probably needs refactoring
        """
        Training model
        :param cur_it: Current iteration which is used for the learning rate and momentum schedule.
         If set to None, it will be initialized from k_steps_initial.
        """
        # Too many local variables (44/15) - Too many branches (18/12) - Too many statements (108/50)
        # set a custom seed for reproducibility
        random.seed(self.tc.seed)
        # define and initialize the variables which will be used
        t_s = time()
        # predefine the local variables that will be used in the training loop
        val_loss_best = val_p_acc_best = k_steps_best = val_metric_values_best = old_label = value_out = None
        patience_cnt = epoch = batch_proc_tmp = 0  # track on how many batches have been processed in this epoch
        k_steps = self.tc.k_steps_initial  # counter for thousands steps
        # calculate how many log states will be processed
        k_steps_end = round(self.tc.total_it / self.tc.batch_steps)
        # we use k-steps instead of epochs here
        if k_steps_end == 0:
            k_steps_end = 1

        if self.use_rtpt:
            self.rtpt = RTPT(name_initials=self.tc.name_initials,
                             experiment_name='crazyara',
                             max_iterations=k_steps_end -
                             self.tc.k_steps_initial)
        if cur_it is None:
            cur_it = self.tc.k_steps_initial * 1000
        nb_spikes = 0  # count the number of spikes that have been detected
        # initialize the loss to compare with, with a very high value
        old_val_loss = np.inf
        graph_exported = False  # create a state variable to check if the net architecture has been reported yet

        if not self.ordering:  # safety check to prevent eternal loop
            raise Exception(
                "You must have at least one part file in your planes-dataset directory!"
            )

        if self.use_rtpt:
            # Start the RTPT tracking
            self.rtpt.start()

        while True:  # Too many nested blocks (7/5)
            # reshuffle the ordering of the training game batches (shuffle works in place)
            random.shuffle(self.ordering)

            epoch += 1
            logging.info("EPOCH %d", epoch)
            logging.info("=========================")
            t_s_steps = time()

            for part_id in tqdm_notebook(self.ordering):
                # load one chunk of the dataset into memory
                _, x_train, yv_train, yp_train, _, _ = load_pgn_dataset(
                    dataset_type="train",
                    part_id=part_id,
                    normalize=self.tc.normalize,
                    verbose=False,
                    q_value_ratio=self.tc.q_value_ratio)

                yp_train = prepare_policy(
                    y_policy=yp_train,
                    select_policy_from_plane=self.tc.select_policy_from_plane,
                    sparse_policy_label=self.tc.sparse_policy_label,
                    is_policy_from_plane_data=self.tc.is_policy_from_plane_data
                )

                # update the train_data object
                train_dataset = gluon.data.ArrayDataset(
                    nd.array(x_train), nd.array(yv_train), nd.array(yp_train))
                train_data = gluon.data.DataLoader(
                    train_dataset,
                    batch_size=self.tc.batch_size,
                    shuffle=True,
                    num_workers=self.tc.cpu_count)

                for _, (data, value_label,
                        policy_label) in enumerate(train_data):
                    data = data.as_in_context(self._ctx)
                    value_label = value_label.as_in_context(self._ctx)
                    policy_label = policy_label.as_in_context(self._ctx)

                    # update a dummy metric to see a proper progress bar
                    #  (the metrics will get evaluated at the end of 100k steps)
                    if batch_proc_tmp > 0:
                        self.to.metrics["value_loss"].update(
                            old_label, value_out)

                    old_label = value_label
                    with autograd.record():
                        [value_out, policy_out] = self._net(data)
                        value_loss = self._l2_loss(value_out, value_label)
                        policy_loss = self._softmax_cross_entropy(
                            policy_out, policy_label)
                        # weight the components of the combined loss
                        combined_loss = (
                            self.tc.val_loss_factor * value_loss +
                            self.tc.policy_loss_factor * policy_loss)
                        # update a dummy metric to see a proper progress bar
                        # self._metrics['value_loss'].update(preds=value_out, labels=value_label)

                    combined_loss.backward()
                    learning_rate = self.to.lr_schedule(
                        cur_it)  # update the learning rate
                    self._trainer.set_learning_rate(learning_rate)
                    momentum = self.to.momentum_schedule(
                        cur_it)  # update the momentum
                    self._trainer._optimizer.momentum = momentum
                    self._trainer.step(data.shape[0])
                    cur_it += 1
                    batch_proc_tmp += 1
                    # add the graph representation of the network to the tensorboard log file
                    if not graph_exported and self.tc.log_metrics_to_tensorboard:
                        self.sum_writer.add_graph(self._net)
                        graph_exported = True

                    if batch_proc_tmp >= self.tc.batch_steps:  # show metrics every batch_steps iterations (one k-step)
                        # log the current learning rate
                        # update batch_proc_tmp counter by subtracting the batch_steps
                        batch_proc_tmp = batch_proc_tmp - self.tc.batch_steps
                        ms_step = (
                            (time() - t_s_steps) /
                            self.tc.batch_steps) * 1000  # measure elapsed time
                        # update the counters
                        k_steps += 1
                        patience_cnt += 1
                        logging.info("Step %dK/%dK - %dms/step", k_steps,
                                     k_steps_end, ms_step)
                        logging.info("-------------------------")
                        logging.debug("Iteration %d/%d", cur_it,
                                      self.tc.total_it)
                        logging.debug("lr: %.7f - momentum: %.7f",
                                      learning_rate, momentum)
                        train_metric_values = evaluate_metrics(
                            self.to.metrics,
                            train_data,
                            self._net,
                            nb_batches=10,  #25,
                            ctx=self._ctx,
                            sparse_policy_label=self.tc.sparse_policy_label,
                            apply_select_policy_from_plane=self.tc.select_policy_from_plane
                            and not self.tc.is_policy_from_plane_data)
                        val_metric_values = evaluate_metrics(
                            self.to.metrics,
                            self._val_data,
                            self._net,
                            nb_batches=None,
                            ctx=self._ctx,
                            sparse_policy_label=self.tc.sparse_policy_label,
                            apply_select_policy_from_plane=self.tc.select_policy_from_plane
                            and not self.tc.is_policy_from_plane_data)
                        if self.use_rtpt:
                            # update process title according to loss
                            self.rtpt.step(
                                subtitle=f"loss={val_metric_values['loss']:2.2f}")
                        if self.tc.use_spike_recovery and (
                                old_val_loss * self.tc.spike_thresh <
                                val_metric_values["loss"]
                                or np.isnan(val_metric_values["loss"])
                        ):  # check for spikes
                            nb_spikes += 1
                            logging.warning(
                                "Spike %d/%d occurred - val_loss: %.3f",
                                nb_spikes,
                                self.tc.max_spikes,
                                val_metric_values["loss"],
                            )
                            if nb_spikes >= self.tc.max_spikes:
                                val_loss = val_metric_values["loss"]
                                val_p_acc = val_metric_values["policy_acc"]
                                logging.debug(
                                    "The maximum number of spikes has been reached. Stop training."
                                )
                                # stop training because the maximum number of spikes has been reached
                                print()
                                print("Elapsed time for training(hh:mm:ss): " +
                                      str(
                                          datetime.timedelta(
                                              seconds=round(time() - t_s))))

                                if self.tc.log_metrics_to_tensorboard:
                                    self.sum_writer.close()
                                return return_metrics_and_stop_training(
                                    k_steps, val_metric_values, k_steps_best,
                                    val_metric_values_best)

                            logging.debug("Recover to latest checkpoint")
                            model_path = self.tc.export_dir + "weights/model-%.5f-%.3f-%04d.params" % (
                                val_loss_best,
                                val_p_acc_best,
                                k_steps_best,
                            )  # Load the best model once again
                            logging.debug("load current best model:%s",
                                          model_path)
                            self._net.load_parameters(model_path,
                                                      ctx=self._ctx)
                            k_steps = k_steps_best
                            logging.debug("k_step is back at %d", k_steps_best)
                            # print the elapsed time
                            t_delta = time() - t_s_steps
                            print(" - %.ds" % t_delta)
                            t_s_steps = time()
                        else:
                            # update the val_loss_value to compare with using spike recovery
                            old_val_loss = val_metric_values["loss"]
                            # log the metric values to tensorboard
                            self._log_metrics(train_metric_values,
                                              global_step=k_steps,
                                              prefix="train_")
                            self._log_metrics(val_metric_values,
                                              global_step=k_steps,
                                              prefix="val_")

                            if self.tc.export_grad_histograms:
                                grads = []
                                # logging the gradients of parameters for checking convergence
                                for _, name in enumerate(self._param_names):
                                    if "bn" not in name and "batch" not in name and name != "policy_flat_plane_idx":
                                        grads.append(self._params[name].grad())
                                        self.sum_writer.add_histogram(
                                            tag=name,
                                            values=grads[-1],
                                            global_step=k_steps,
                                            bins=20)

                            # check if a new checkpoint shall be created
                            if val_loss_best is None or val_metric_values[
                                    "loss"] < val_loss_best:
                                # update val_loss_best
                                val_loss_best = val_metric_values["loss"]
                                val_p_acc_best = val_metric_values[
                                    "policy_acc"]
                                val_metric_values_best = val_metric_values
                                k_steps_best = k_steps

                                if self.tc.export_weights:
                                    prefix = self.tc.export_dir + "weights/model-%.5f-%.3f" \
                                             % (val_loss_best, val_p_acc_best)
                                    # the export function saves both the architecture and the weights
                                    self._net.export(prefix,
                                                     epoch=k_steps_best)
                                    print()
                                    logging.info(
                                        "Saved checkpoint to %s-%04d.params",
                                        prefix, k_steps_best)

                                patience_cnt = 0  # reset the patience counter
                            # print the elapsed time
                            t_delta = time() - t_s_steps
                            print(" - %.ds" % t_delta)
                            t_s_steps = time()

                            # log the samples per second metric to tensorboard
                            self.sum_writer.add_scalar(
                                tag="samples_per_second",
                                value={
                                    "hybrid_sync":
                                    data.shape[0] * self.tc.batch_steps /
                                    t_delta
                                },
                                global_step=k_steps,
                            )

                            # log the current learning rate
                            self.sum_writer.add_scalar(
                                tag="lr",
                                value=self.to.lr_schedule(cur_it),
                                global_step=k_steps)
                            # log the current momentum value
                            self.sum_writer.add_scalar(
                                tag="momentum",
                                value=self.to.momentum_schedule(cur_it),
                                global_step=k_steps)

                            if cur_it >= self.tc.total_it:

                                val_loss = val_metric_values["loss"]
                                val_p_acc = val_metric_values["policy_acc"]
                                logging.debug(
                                    "The number of given iterations has been reached"
                                )
                                # stop training because the given number of iterations has been reached
                                print()
                                print("Elapsed time for training(hh:mm:ss): " +
                                      str(
                                          datetime.timedelta(
                                              seconds=round(time() - t_s))))

                                if self.tc.log_metrics_to_tensorboard:
                                    self.sum_writer.close()

                                return return_metrics_and_stop_training(
                                    k_steps, val_metric_values, k_steps_best,
                                    val_metric_values_best)
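
The spike-recovery branch above flags a spike whenever the new validation loss exceeds old_val_loss * spike_thresh (or is NaN), reloads the best checkpoint, and aborts once max_spikes occurrences are reached. A stripped-down, self-contained sketch of just that control flow (the loss values are made up):

import math

def is_spike(old_val_loss, new_val_loss, spike_thresh):
    # a spike is a val loss above the previous loss times the threshold factor, or NaN
    return old_val_loss * spike_thresh < new_val_loss or math.isnan(new_val_loss)

old_val_loss, spike_thresh, max_spikes = 1.20, 1.5, 3
nb_spikes = 0
for new_val_loss in (1.25, 2.10, 1.22, float("nan")):
    if is_spike(old_val_loss, new_val_loss, spike_thresh):
        nb_spikes += 1
        print("spike %d/%d - val_loss: %.3f -> reload best checkpoint" % (nb_spikes, max_spikes, new_val_loss))
        if nb_spikes >= max_spikes:
            print("maximum number of spikes reached - stop training")
            break
    else:
        old_val_loss = new_val_loss  # accept the step and move the baseline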