Example #1
import warnings

import torch

# get_model, run_gpu_model_test and the memory module are project-local
# test helpers (their imports are not shown in the snippet).


def test_multi_gpu_model_dp():
    """Make sure DP works."""
    if not torch.cuda.is_available():
        warnings.warn('test_multi_gpu_model_dp cannot run.'
                      ' Rerun on a GPU node to run this test')
        return
    if torch.cuda.device_count() < 2:
        warnings.warn('test_multi_gpu_model_dp cannot run.'
                      ' Rerun on a node with 2+ GPUs to run this test')
        return
    model, hparams = get_model()
    trainer_options = dict(
        progress_bar=False,
        max_nb_epochs=1,
        train_percent_check=0.1,
        val_percent_check=0.1,
        gpus='-1'
    )

    run_gpu_model_test(trainer_options, model, hparams)

    # test memory helper functions
    memory.get_gpu_memory_map()
Example #2
def test_multi_gpu_model_dp():
    """
    Make sure DP works
    :return:
    """
    if not can_run_gpu_test():
        return

    model, hparams = get_model()
    trainer_options = dict(show_progress_bar=False,
                           max_nb_epochs=1,
                           train_percent_check=0.1,
                           val_percent_check=0.1,
                           gpus='-1')

    run_gpu_model_test(trainer_options, model, hparams)

    # test memory helper functions
    memory.get_gpu_memory_map()
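
Examples #2 and #3 replace Example #1's inline CUDA checks with a can_run_gpu_test() helper whose body is not shown. Below is a minimal sketch of such a helper, reconstructed from Example #1's inline logic; the project's actual test utility may differ.

import warnings

import torch


def can_run_gpu_test():
    """Return True only when the multi-GPU test can actually run.

    Sketch reconstructed from Example #1: skip (with a warning) when CUDA
    is unavailable or fewer than two GPUs are visible.
    """
    if not torch.cuda.is_available():
        warnings.warn('test cannot run. Rerun on a GPU node to run this test')
        return False
    if torch.cuda.device_count() < 2:
        warnings.warn('test cannot run.'
                      ' Rerun on a node with 2+ GPUs to run this test')
        return False
    return True
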
Example #3
def test_multi_gpu_model_dp():
    """
    Make sure DP works
    :return:
    """
    testing_utils.reset_seed()

    if not testing_utils.can_run_gpu_test():
        return

    model, hparams = testing_utils.get_model()
    trainer_options = dict(
        show_progress_bar=False,
        distributed_backend='dp',
        max_nb_epochs=1,
        train_percent_check=0.1,
        val_percent_check=0.1,
        gpus='-1'
    )

    testing_utils.run_gpu_model_test(trainer_options, model, hparams)

    # test memory helper functions
    memory.get_gpu_memory_map()
Example #4
    def __train(self):
        # run all epochs
        for epoch_nb in range(self.current_epoch, self.max_nb_epochs):
            # update the lr scheduler
            if self.lr_schedulers is not None:
                for lr_scheduler in self.lr_schedulers:
                    lr_scheduler.step()

            model = self.__get_model()
            model.current_epoch = epoch_nb

            # hook
            if self.__is_function_implemented('on_epoch_start'):
                model = self.__get_model()
                model.on_epoch_start()

            self.current_epoch = epoch_nb
            self.total_batches = self.nb_tng_batches + self.nb_val_batches
            self.batch_loss_value = 0  # accumulated grads

            # init progbar when requested
            if self.progress_bar:
                self.prog_bar = tqdm.tqdm(range(self.total_batches),
                                          position=self.process_position)

            for batch_nb, data_batch in enumerate(self.tng_dataloader):
                self.batch_nb = batch_nb
                self.global_step += 1

                model = self.__get_model()
                model.global_step = self.global_step

                # stop once we've gone past the number of training batches
                #  requested for this epoch
                self.total_batch_nb += 1
                met_batch_limit = batch_nb > self.nb_tng_batches
                if met_batch_limit:
                    break

                # ---------------
                # RUN TRAIN STEP
                # ---------------
                batch_result = self.__run_tng_batch(data_batch, batch_nb)
                early_stop_epoch = batch_result == -1

                # ---------------
                # RUN VAL STEP
                # ---------------
                is_val_check_batch = (batch_nb + 1) % self.val_check_batch == 0
                if self.fast_dev_run or is_val_check_batch or early_stop_epoch:
                    self.__run_validation()

                # when batch should be saved
                if (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch:
                    if self.proc_rank == 0 and self.experiment is not None:
                        self.experiment.save()

                # when metrics should be logged
                if batch_nb % self.add_log_row_interval == 0 or early_stop_epoch:
                    # count items in memory
                    # nb_params, nb_tensors = count_mem_items()

                    model = self.__get_model()
                    metrics = self.__tng_tqdm_dic

                    # add gpu memory
                    if self.on_gpu:
                        mem_map = get_gpu_memory_map()
                        metrics.update(mem_map)

                    # add norms
                    if self.track_grad_norm > 0:
                        model = self.__get_model()
                        grad_norm_dic = model.grad_norm(self.track_grad_norm)
                        metrics.update(grad_norm_dic)

                    if self.__is_function_implemented('on_tng_metrics'):
                        model.on_tng_metrics(metrics)

                    # log metrics
                    scalar_metrics = self.__metrics_to_scalars(
                        metrics, blacklist=self.__log_vals_blacklist())
                    if self.proc_rank == 0 and self.experiment is not None:
                        self.experiment.log(scalar_metrics, global_step=self.global_step)
                        self.experiment.save()

                # hook
                if self.__is_function_implemented('on_batch_end'):
                    model = self.__get_model()
                    model.on_batch_end()

                # end epoch early
                if early_stop_epoch:
                    break

            # hook
            if self.__is_function_implemented('on_epoch_end'):
                model = self.__get_model()
                model.on_epoch_end()

            # early stopping: only consult the callback once the minimum
            # number of epochs has run
            met_min_epochs = epoch_nb > self.min_nb_epochs
            if self.enable_early_stop and met_min_epochs:
                should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch_nb,
                                                                    logs=self.__tng_tqdm_dic)

                # stop training
                if should_stop:
                    return
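
Example #4 delegates the stopping decision to self.early_stop_callback.on_epoch_end(epoch=..., logs=...), which must return a truthy value when training should stop. Below is a minimal sketch of a callback matching that interface; the monitored 'val_loss' key and the patience logic are assumptions, not taken from the snippet.

class EarlyStopping:
    """Minimal early-stopping callback matching the interface used above.

    Assumes logs is a dict of metrics containing the monitored key
    (key name assumed); on_epoch_end() returns True to stop training.
    """

    def __init__(self, monitor='val_loss', patience=3):
        self.monitor = monitor    # metric key to watch (assumed name)
        self.patience = patience  # epochs without improvement before stopping
        self.best = float('inf')
        self.wait = 0

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {}).get(self.monitor)
        if current is None:
            return False  # metric not logged yet; keep training
        if current < self.best:
            self.best = current
            self.wait = 0
            return False
        self.wait += 1
        return self.wait >= self.patience
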
Example #5
    def run_tng_epoch(self):
        # before epoch hook
        if self.__is_function_implemented('on_epoch_start'):
            model = self.__get_model()
            model.on_epoch_start()

        # run epoch
        for batch_nb, data_batch in enumerate(self.tng_dataloader):
            self.batch_nb = batch_nb
            self.global_step += 1

            model = self.__get_model()
            model.global_step = self.global_step

            # stop once we've gone past the number of training batches
            #  requested for this epoch
            self.total_batch_nb += 1
            met_batch_limit = batch_nb > self.nb_tng_batches
            if met_batch_limit:
                break

            # ---------------
            # RUN TRAIN STEP
            # ---------------
            batch_result = self.__run_tng_batch(data_batch, batch_nb)
            early_stop_epoch = batch_result == -1

            # ---------------
            # RUN VAL STEP
            # ---------------
            is_val_check_batch = (batch_nb + 1) % self.val_check_batch == 0
            can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
            if self.fast_dev_run or is_val_check_batch or early_stop_epoch:
                if can_check_epoch:
                    self.__run_evaluation(test=self.testing)

            # when batch should be saved
            if (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch:
                if self.proc_rank == 0 and self.experiment is not None:
                    self.experiment.save()

            # when metrics should be logged
            if batch_nb % self.add_log_row_interval == 0 or early_stop_epoch:
                # count items in memory
                # nb_params, nb_tensors = count_mem_items()

                model = self.__get_model()
                metrics = self.__tng_tqdm_dic

                # add gpu memory
                if self.on_gpu and self.log_gpu_memory:
                    mem_map = get_gpu_memory_map()
                    metrics.update(mem_map)

                # add norms
                if self.track_grad_norm > 0:
                    model = self.__get_model()
                    grad_norm_dic = model.grad_norm(self.track_grad_norm)
                    metrics.update(grad_norm_dic)

                if self.__is_function_implemented('on_tng_metrics'):
                    model.on_tng_metrics(metrics)

                # log metrics
                scalar_metrics = self.__metrics_to_scalars(
                    metrics, blacklist=self.__log_vals_blacklist())
                if self.proc_rank == 0 and self.experiment is not None:
                    self.experiment.log(scalar_metrics,
                                        global_step=self.global_step)
                    self.experiment.save()

            # end epoch early
            if early_stop_epoch:
                break

        # epoch end hook
        if self.__is_function_implemented('on_epoch_end'):
            model = self.__get_model()
            model.on_epoch_end()
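
Every example above calls get_gpu_memory_map(), and Examples #4 and #5 merge its result into the metrics dict, so the helper must return a mapping of per-GPU keys to memory values. A minimal sketch follows, assuming the numbers come from querying nvidia-smi; the key format ('gpu_0', ...) is an assumption, not confirmed by the snippets.

import subprocess


def get_gpu_memory_map():
    """Return a dict like {'gpu_0': 1234, ...} of used MiB per visible GPU.

    Sketch implementation: shells out to nvidia-smi, which must be on PATH.
    """
    result = subprocess.run(
        ['nvidia-smi', '--query-gpu=memory.used',
         '--format=csv,nounits,noheader'],
        stdout=subprocess.PIPE, check=True, encoding='utf-8')
    used = [int(x) for x in result.stdout.strip().split('\n')]
    return {'gpu_{}'.format(i): mem for i, mem in enumerate(used)}
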