Example #1
    def __init__(self, config: Munch = None, **kwargs):
        if config is None:
            config = Miner.default_config()
        bittensor.config.Config.update_with_kwargs(config.miner, kwargs)
        Miner.check_config(config)
        self.config = config

        # ---- Model ----
        self.model = BertMLMSynapse( self.config )

        # ---- Optimizer ----
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr = self.config.miner.learning_rate, momentum=self.config.miner.momentum)
        self.scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, 50, 300)

        # ---- Model Load/Save tools ----
        self.model_toolbox = ModelToolbox(BertMLMSynapse, torch.optim.SGD)

        # ---- Dataset ----
        # Dataset: AG News headlines.
        self.dataset = load_dataset('ag_news')['train']
        # The collator accepts a list of dicts (e.g. {'input_ids': ...})
        # produced by the tokenizer.
        self.data_collator = DataCollatorForLanguageModeling(
            tokenizer=bittensor.__tokenizer__(), mlm=True, mlm_probability=0.15
        )
        super( Miner, self ).__init__( self.config, **kwargs )
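For reference, a minimal, self-contained sketch of what the schedule constructed above does. The snippets here don't show their imports, so this assumes the pytorch_transformers-era class WarmupCosineWithHardRestartsSchedule(optimizer, warmup_steps, t_total, cycles=1.), which these examples call positionally as (optimizer, 50, 300); the toy linear model is only a stand-in.

import torch
from pytorch_transformers.optimization import WarmupCosineWithHardRestartsSchedule

toy_model = torch.nn.Linear(8, 8)                     # stand-in for a synapse
optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.01, momentum=0.98)
scheduler = WarmupCosineWithHardRestartsSchedule(optimizer, 50, 300)

for epoch in range(300):
    # ... one epoch of training goes here, calling optimizer.step() per batch ...
    scheduler.step()                                  # advance the schedule once per epoch
    print(epoch, optimizer.param_groups[0]['lr'])     # warms up linearly over 50 steps,
                                                      # then follows a cosine decay toward
                                                      # zero at step 300 (one cycle by default)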
Example #2
    def __init__(self, config: Munch):
        self.config = config

        # ---- Neuron ----
        self.neuron = Neuron(self.config)

        # ---- Model ----
        self.model = BertNSPSynapse(self.config)

        # ---- Optimizer ----
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=self.config.session.learning_rate,
                                         momentum=self.config.session.momentum)
        self.scheduler = WarmupCosineWithHardRestartsSchedule(
            self.optimizer, 50, 300)

        # ---- Dataset ----
        # Dataset: 74 million sentences pulled from books.
        self.dataset = load_dataset('bookcorpus')

        # ---- Logging ----
        self.tensorboard = SummaryWriter(log_dir=self.config.session.full_path)
        if self.config.session.record_log:
            logger.add(
                self.config.session.full_path + "/{}_{}.log".format(
                    self.config.session.name, self.config.session.trial_uid),
                format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}")
Example #3
    def __init__(self, config: Munch = None):
        if config is None:
            config = Miner.build_config()
            logger.info(bittensor.config.Config.toString(config))
        self.config = config

        # ---- Neuron ----
        self.neuron = bittensor.neuron.Neuron(self.config)

        # ---- Model ----
        self.model = GPT2LMSynapse( self.config )

        # ---- Optimizer ----
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr = self.config.miner.learning_rate, momentum=self.config.miner.momentum)
        self.scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, 50, 300)

        # ---- Model Load/Save tools ----
        self.model_toolbox = ModelToolbox(GPT2LMSynapse, torch.optim.SGD)

        # ---- Dataset ----
        # Dataset: custom corpora loaded through AdamCorpus.
        # self.dataset = load_dataset('ag_news')['train']
        self.dataset = AdamCorpus(self.config.miner.custom_datasets)

        # ---- Logging ----
        self.tensorboard = SummaryWriter(log_dir = self.config.miner.full_path)
        if self.config.miner.record_log:
            logger.add(self.config.miner.full_path + "/{}_{}.log".format(self.config.miner.name, self.config.miner.trial_uid),format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}")
Example #4
    def __init__(self, config: Munch):
        self.config = config

        # ---- Neuron ----
        self.neuron = Neuron(self.config)

        # ---- Model ----
        self.model = BertMLMSynapse(self.config)

        # ---- Optimizer ----
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=self.config.session.learning_rate,
                                         momentum=self.config.session.momentum)
        self.scheduler = WarmupCosineWithHardRestartsSchedule(
            self.optimizer, 50, 300)

        # ---- Dataset ----
        # Dataset: 74 million sentences pulled from books.
        self.dataset = load_dataset('bookcorpus')['train']
        # The collator accepts a list of dicts (e.g. {'input_ids': ...})
        # produced by the tokenizer.
        self.data_collator = DataCollatorForLanguageModeling(
            tokenizer=bittensor.__tokenizer__(),
            mlm=True,
            mlm_probability=0.15)

        # ---- Logging ----
        self.tensorboard = SummaryWriter(log_dir=self.config.session.full_path)
        if self.config.session.record_log:
            logger.add(
                self.config.session.full_path + "/{}_{}.log".format(
                    self.config.session.name, self.config.session.trial_uid),
                format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}")
Example #5
    def __init__(self, config: Munch = None, **kwargs):
        if config is None:
            config = Miner.default_config()
        bittensor.config.Config.update_with_kwargs(config.miner, kwargs)
        Miner.check_config(config)
        self.config = config

        # ---- Neuron ----
        self.neuron = bittensor.neuron.Neuron(self.config)

        # ---- Model ----
        self.model = GPT2LMSynapse(self.config)

        # ---- Optimizer ----
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=self.config.miner.learning_rate,
                                         momentum=self.config.miner.momentum)
        self.scheduler = WarmupCosineWithHardRestartsSchedule(
            self.optimizer, 50, 300)

        # ---- Model Load/Save tools ----
        self.model_toolbox = ModelToolbox(GPT2LMSynapse, torch.optim.SGD)

        # ---- Dataset ----
        # The Genesis Dataset:
        # The dataset used to train Adam and his first 100 children.
        self.dataset = AdamCorpus(self.config.miner.custom_dataset)

        # ---- Logging ----
        self.tensorboard = SummaryWriter(log_dir=self.config.miner.full_path)
        if self.config.miner.record_log:
            logger.add(
                self.config.miner.full_path + "/{}_{}.log".format(
                    self.config.miner.name, self.config.miner.trial_uid),
                format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}")
Example #6
    def __init__(self, config: Munch = None, **kwargs):
        if config is None:
            config = Miner.default_config()
        bittensor.config.Config.update_with_kwargs(config.miner, kwargs)
        Miner.check_config(config)
        self.config = config

        # ---- Neuron ----
        self.neuron = bittensor.neuron.Neuron(self.config)

        # ---- Model ----
        self.model = BertNSPSynapse( self.config )

        # ---- Optimizer ----
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr = self.config.miner.learning_rate, momentum=self.config.miner.momentum)
        self.scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, 50, 300)

        # ---- Model Load/Save tools ----
        self.model_toolbox = ModelToolbox(BertNSPSynapse, torch.optim.SGD)

        # ---- Dataset ----
        # Dataset: News headlines
        self.dataset = load_dataset('ag_news')['train']


        # ---- Logging ----
        self.tensorboard = SummaryWriter(log_dir = self.config.miner.full_path)
        if self.config.miner.record_log:
            logger.add(self.config.miner.full_path + "/{}_{}.log".format(self.config.miner.name, self.config.miner.trial_uid),format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}")
Example #7
    def __init__(self, config: Munch = None, **kwargs):
        if config is None:
            config = Miner.default_config()
        bittensor.config.Config.update_with_kwargs(config.miner, kwargs)
        Miner.check_config(config)
        self.config = config

        # ---- Neuron ----
        self.neuron = bittensor.neuron.Neuron(self.config)

        # ---- Model ----
        self.model = XLMSynapse( self.config )

        # ---- Optimizer ----
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr = self.config.miner.learning_rate, momentum=self.config.miner.momentum)
        self.scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, 50, 300)

        # ---- Model Load/Save tools ----
        self.model_toolbox = ModelToolbox(XLMSynapse, torch.optim.SGD)

        # ---- Dataset ----
        # Dataset: English Amazon product reviews.
        self.dataset = load_dataset('amazon_reviews_multi', 'en')['train']

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if self.config.synapse.device:
            self.device = torch.device(self.config.synapse.device)

        # ---- Logging ----
        self.tensorboard = SummaryWriter(log_dir = self.config.miner.full_path)
        if self.config.miner.record_log:
            logger.add(self.config.miner.full_path + "/{}_{}.log".format(self.config.miner.name, self.config.miner.trial_uid),format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}")
Example #8
    def test_warmup_cosine_hard_restart_scheduler(self):
        scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10)
        lrs = unwrap_schedule(scheduler, self.num_steps)
        expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0]
        self.assertEqual(len(lrs[0]), 1)
        self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)

        scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10)
        lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
        self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
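The expected_learning_rates above are consistent with the usual warmup-then-cosine-with-hard-restarts rule. The helper below is a reconstruction for illustration only (its name is made up; base lr 10.0, warmup_steps=2, cycles=2 and t_total=10 are taken from the test):

import math

def hard_restart_lr_multiplier(step, warmup_steps=2, cycles=2, t_total=10):
    # Linear warmup, then a cosine that restarts `cycles` times and is
    # clamped to zero once the schedule is exhausted.
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, t_total - warmup_steps)
    if progress >= 1.0:
        return 0.0
    return 0.5 * (1.0 + math.cos(math.pi * ((cycles * progress) % 1.0)))

print([round(10.0 * hard_restart_lr_multiplier(s), 2) for s in range(1, 11)])
# -> [5.0, 10.0, 8.54, 5.0, 1.46, 10.0, 8.54, 5.0, 1.46, 0.0]
#    (matches expected_learning_rates within the test's 1e-2 tolerance)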
Example #9
    def __init__(self, config: Munch = None, **kwargs):
        # ---- Load Config ----
        if config is None:
            config = Miner.default_config()
        config = copy.deepcopy(config)
        bittensor.config.Config.update_with_kwargs(config, kwargs)
        Miner.check_config(config)
        logger.info(bittensor.config.Config.toString(config))
        self.config = config

        # ---- Row Weights ----
        self.row_weights = torch.ones([1])

        # ---- Nucleus ----
        self.synapse = XLMSynapse(self.config)

        # ---- Optimizer ----
        self.optimizer = torch.optim.SGD(self.synapse.parameters(),
                                         lr=self.config.miner.learning_rate,
                                         momentum=self.config.miner.momentum)
        self.scheduler = WarmupCosineWithHardRestartsSchedule(
            self.optimizer, 50, 300)

        # ---- Dataset ----
        self.dataset = GenesisTextDataloader(
            self.config.miner.batch_size_train, 20)
        super(Miner, self).__init__(self.config, **kwargs)
Example #10
    def configure_optimizers(self):
        model = self.model
        optimizer = torch.optim.AdamW(model.parameters(), lr=self.hparams.lr)
        scheduler = WarmupCosineWithHardRestartsSchedule(optimizer=optimizer,
                                                         warmup_steps=1,
                                                         t_total=5)

        return [optimizer], [scheduler]
Example #11
    def __init__(self, config: Munch = None, **kwargs):
        if config is None:
            config = Miner.default_config()
        bittensor.config.Config.update_with_kwargs(config.miner, kwargs)
        Miner.check_config(config)
        self.config = config

        # ---- Neuron ----
        self.neuron = bittensor.neuron.Neuron(self.config)

        # ---- Model ----
        self.model = BertMLMSynapse(self.config)

        # ---- Optimizer ----
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=self.config.miner.learning_rate,
                                         momentum=self.config.miner.momentum)
        self.scheduler = WarmupCosineWithHardRestartsSchedule(
            self.optimizer, 50, 300)

        # ---- Model Load/Save tools ----
        self.model_toolbox = ModelToolbox(BertMLMSynapse, torch.optim.SGD)

        # ---- Dataset ----
        # Dataset: AG News headlines.
        self.dataset = load_dataset('ag_news')['train']
        # The collator accepts a list of dicts (e.g. {'input_ids': ...})
        # produced by the tokenizer.
        self.data_collator = DataCollatorForLanguageModeling(
            tokenizer=bittensor.__tokenizer__(),
            mlm=True,
            mlm_probability=0.15)

        # ---- Logging ----
        self.tensorboard = SummaryWriter(log_dir=self.config.miner.full_path)
        if self.config.miner.record_log:
            filepath = self.config.miner.full_path + "/{}_{}.log".format(
                self.config.miner.name, self.config.miner.trial_uid)
            logger.add(
                filepath,
                format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
                rotation="250 MB",
                retention="10 days")
Example #12
class Miner():
    """
    Initializes, trains, and tests models created inside of 'bittensor/synapses'. 
    During instantiation, this class takes a config as a [Munch](https://github.com/Infinidat/munch) object. 
    """
    def __init__(self, config: Munch = None, **kwargs):
        if config is None:
            config = Miner.default_config()
        bittensor.config.Config.update_with_kwargs(config.miner, kwargs)
        Miner.check_config(config)
        self.config = config

        # ---- Neuron ----
        self.neuron = bittensor.neuron.Neuron(self.config)

        # ---- Model ----
        self.model = XLMSynapse(self.config)

        # ---- Optimizer ----
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=self.config.miner.learning_rate,
                                         momentum=self.config.miner.momentum)
        self.scheduler = WarmupCosineWithHardRestartsSchedule(
            self.optimizer, 50, 300)

        # ---- Model Load/Save tools ----
        self.model_toolbox = ModelToolbox(XLMSynapse, torch.optim.SGD)

        # ---- Dataset ----
        # Dataset: English Amazon product reviews.
        self.dataset = load_dataset('amazon_reviews_multi', 'en')['train']

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        if self.config.synapse.device:
            self.device = torch.device(self.config.synapse.device)

        # ---- Logging ----
        self.tensorboard = SummaryWriter(log_dir=self.config.miner.full_path)
        if self.config.miner.record_log:
            filepath = f"{self.config.miner.full_path}/{self.config.miner.name}_{self.config.miner.trial_uid}.log"
            logger.add(
                filepath,
                format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
                rotation="250 MB",
                retention="10 days")

    @staticmethod
    def default_config() -> Munch:
        parser = argparse.ArgumentParser()
        Miner.add_args(parser)
        config = bittensor.config.Config.to_config(parser)
        return config

    @staticmethod
    def add_args(parser: argparse.ArgumentParser):
        parser.add_argument('--miner.learning_rate',
                            default=0.01,
                            type=float,
                            help='Training initial learning rate.')
        parser.add_argument('--miner.momentum',
                            default=0.98,
                            type=float,
                            help='Training initial momentum for SGD.')
        parser.add_argument('--miner.n_epochs',
                            default=int(sys.maxsize),
                            type=int,
                            help='Number of training epochs.')
        parser.add_argument('--miner.epoch_length',
                            default=500,
                            type=int,
                            help='Iterations of training per epoch')
        parser.add_argument('--miner.batch_size_train',
                            default=1,
                            type=int,
                            help='Training batch size.')
        parser.add_argument(
            '--miner.sync_interval',
            default=100,
            type=int,
            help='Batches before we sync with chain and emit new weights.')
        parser.add_argument('--miner.log_interval',
                            default=10,
                            type=int,
                            help='Batches before we log miner info.')
        parser.add_argument(
            '--miner.accumulation_interval',
            default=1,
            type=int,
            help='Batches before we apply accumulated gradients.')
        parser.add_argument(
            '--miner.apply_remote_gradients',
            default=False,
            type=bool,
            help=
            'If true, neuron applies gradients which accumulate from remote calls.'
        )
        parser.add_argument(
            '--miner.root_dir',
            default='~/.bittensor/miners/',
            type=str,
            help='Root path to load and save data associated with each miner')
        parser.add_argument(
            '--miner.name',
            default='xlm_wiki',
            type=str,
            help='Trials for this miner go in miner.root / miner.name')
        parser.add_argument(
            '--miner.trial_uid',
            default=str(time.time()).split('.')[0],
            type=str,
            help='Saved models go in miner.root_dir / miner.name / miner.uid')
        parser.add_argument('--miner.record_log',
                            default=False,
                            help='Record all logs when running this miner')
        parser.add_argument(
            '--miner.config_file',
            type=str,
            help=
            'config file to run this neuron, if not using cmd line arguments.')
        parser.add_argument('--debug',
                            dest='debug',
                            action='store_true',
                            help='''Turn on bittensor debugging information''')
        parser.set_defaults(debug=False)
        XLMSynapse.add_args(parser)
        bittensor.neuron.Neuron.add_args(parser)

    @staticmethod
    def check_config(config: Munch):
        if config.debug:
            bittensor.__log_level__ = 'TRACE'
            logger.debug('DEBUG is ON')
        else:
            logger.info('DEBUG is OFF')
        assert config.miner.momentum > 0 and config.miner.momentum < 1, "momentum must be a value between 0 and 1"
        assert config.miner.batch_size_train > 0, "batch_size_train must be a positive value"
        assert config.miner.learning_rate > 0, "learning_rate must be a positive value."
        full_path = '{}/{}/{}'.format(config.miner.root_dir, config.miner.name,
                                      config.miner.trial_uid)
        config.miner.full_path = os.path.expanduser(full_path)
        if not os.path.exists(config.miner.full_path):
            os.makedirs(config.miner.full_path)

    # --- Main loop ----
    def run(self):

        # ---- Subscribe ----
        with self.neuron:

            # ---- Weights ----
            self.row = self.neuron.metagraph.row.to(self.model.device)

            # --- Run state ---
            self.global_step = 0
            self.best_train_loss = math.inf

            # --- Loop for epochs ---
            for self.epoch in range(self.config.miner.n_epochs):
                try:
                    # ---- Serve ----
                    self.neuron.axon.serve(self.model)

                    # ---- Train Model ----
                    self.train()
                    self.scheduler.step()

                    # If the model has produced NaN parameters, make sure it doesn't emit weights.
                    # Instead, reload the previously saved version of the model.
                    if torch.any(
                            torch.isnan(
                                torch.cat([
                                    param.view(-1)
                                    for param in self.model.parameters()
                                ]))):
                        self.model, self.optimizer = self.model_toolbox.load_model(
                            self.config)
                        continue

                    # ---- Emitting weights ----
                    self.neuron.metagraph.set_weights(
                        self.row, wait_for_inclusion=True
                    )  # Sets my row-weights on the chain.

                    # ---- Sync metagraph ----
                    self.neuron.metagraph.sync(
                    )  # Pulls the latest metagraph state (with my update.)
                    self.row = self.neuron.metagraph.row.to(self.model.device)

                    # --- Epoch logs ----
                    print(self.neuron.axon.__full_str__())
                    print(self.neuron.dendrite.__full_str__())
                    print(self.neuron.metagraph)

                    # ---- Update Tensorboard ----
                    self.neuron.dendrite.__to_tensorboard__(
                        self.tensorboard, self.global_step)
                    self.neuron.metagraph.__to_tensorboard__(
                        self.tensorboard, self.global_step)
                    self.neuron.axon.__to_tensorboard__(
                        self.tensorboard, self.global_step)

                    # ---- Save best loss and model ----
                    if self.training_loss and self.epoch % 10 == 0 and self.training_loss < self.best_train_loss:
                        self.best_train_loss = self.training_loss / 10  # update best train loss
                        self.model_toolbox.save_model(
                            self.config.miner.full_path, {
                                'epoch':
                                self.epoch,
                                'model_state_dict':
                                self.model.state_dict(),
                                'loss':
                                self.best_train_loss,
                                'optimizer_state_dict':
                                self.optimizer.state_dict(),
                            })
                        self.tensorboard.add_scalar('Neuron/Train_loss',
                                                    self.training_loss,
                                                    self.global_step)

                # --- Catch Errors ----
                except Exception as e:
                    logger.error(
                        'Exception in training script with error: {}, {}', e,
                        traceback.format_exc())
                    logger.info('Continuing to train.')

    # ---- Train Epoch ----
    def train(self):
        self.training_loss = 0.0
        for local_step in range(self.config.miner.epoch_length):
            # ---- Forward pass ----
            inputs = nextbatch(self.dataset,
                               self.config.miner.batch_size_train,
                               bittensor.__tokenizer__())
            output = self.model.remote_forward(
                self.neuron,
                inputs.to(self.model.device),
                training=True,
            )

            # ---- Backward pass ----
            loss = output.local_target_loss + output.distillation_loss + output.remote_target_loss
            loss.backward()  # Accumulates gradients on the model.
            self.optimizer.step()  # Applies accumulated gradients.
            self.optimizer.zero_grad(
            )  # Zeros out gradients for next accummulation

            # ---- Train row weights ----
            batch_weights = torch.mean(output.router.weights, axis=0).to(
                self.model.device)  # Average over batch.
            self.row = (
                1 -
                0.03) * self.row + 0.03 * batch_weights  # Moving avg update.
            self.row = F.normalize(self.row, p=1,
                                   dim=0)  # Ensure normalization.

            # ---- Step logs ----
            logger.info(
                'GS: {} LS: {} Epoch: {}\tLocal Target Loss: {}\tRemote Target Loss: {}\tDistillation Loss: {}\tAxon: {}\tDendrite: {}',
                colored('{}'.format(self.global_step), 'red'),
                colored('{}'.format(local_step), 'blue'),
                colored('{}'.format(self.epoch), 'green'),
                colored('{:.4f}'.format(output.local_target_loss.item()),
                        'green'),
                colored('{:.4f}'.format(output.remote_target_loss.item()),
                        'blue'),
                colored('{:.4f}'.format(output.distillation_loss.item()),
                        'red'), self.neuron.axon, self.neuron.dendrite)
            logger.info('Codes: {}', output.router.return_codes.tolist())

            self.tensorboard.add_scalar('Neuron/Rloss',
                                        output.remote_target_loss.item(),
                                        self.global_step)
            self.tensorboard.add_scalar('Neuron/Lloss',
                                        output.local_target_loss.item(),
                                        self.global_step)
            self.tensorboard.add_scalar('Neuron/Dloss',
                                        output.distillation_loss.item(),
                                        self.global_step)

            # ---- Step increments ----
            self.global_step += 1
            self.training_loss += output.local_target_loss.item()

            # --- Memory clean up ----
            torch.cuda.empty_cache()
            del output
Example #13
class Session():
    def __init__(self, config: Munch):
        self.config = config

        # ---- Neuron ----
        self.neuron = Neuron(self.config)

        # ---- Model ----
        self.model = BertNSPSynapse(self.config)

        # ---- Optimizer ----
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=self.config.session.learning_rate,
                                         momentum=self.config.session.momentum)
        self.scheduler = WarmupCosineWithHardRestartsSchedule(
            self.optimizer, 50, 300)

        # ---- Dataset ----
        # Dataset: 74 million sentences pulled from books.
        self.dataset = load_dataset('bookcorpus')

        # ---- Logging ----
        self.tensorboard = SummaryWriter(log_dir=self.config.session.full_path)
        if self.config.session.record_log:
            logger.add(
                self.config.session.full_path + "/{}_{}.log".format(
                    self.config.session.name, self.config.session.trial_uid),
                format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}")

    @staticmethod
    def add_args(parser: argparse.ArgumentParser):
        parser.add_argument('--session.learning_rate',
                            default=0.01,
                            type=float,
                            help='Training initial learning rate.')
        parser.add_argument('--session.momentum',
                            default=0.98,
                            type=float,
                            help='Training initial momentum for SGD.')
        parser.add_argument('--session.epoch_length',
                            default=10,
                            type=int,
                            help='Iterations of training per epoch')
        parser.add_argument('--session.batch_size_train',
                            default=1,
                            type=int,
                            help='Training batch size.')
        parser.add_argument(
            '--session.sync_interval',
            default=100,
            type=int,
            help='Batches before we sync with chain and emit new weights.')
        parser.add_argument('--session.log_interval',
                            default=10,
                            type=int,
                            help='Batches before we log session info.')
        parser.add_argument(
            '--session.accumulation_interval',
            default=1,
            type=int,
            help='Batches before we apply accumulated gradients.')
        parser.add_argument(
            '--session.apply_remote_gradients',
            default=False,
            type=bool,
            help=
            'If true, neuron applies gradients which accumulate from remote calls.'
        )
        parser.add_argument(
            '--session.root_dir',
            default='~/.bittensor/sessions/',
            type=str,
            help='Root path to load and save data associated with each session'
        )
        parser.add_argument(
            '--session.name',
            default='bert-nsp',
            type=str,
            help='Trials for this session go in session.root / session.name')
        parser.add_argument(
            '--session.trial_uid',
            default=str(time.time()).split('.')[0],
            type=str,
            help=
            'Saved models go in session.root_dir / session.name / session.uid')
        parser.add_argument('--session.record_log',
                            default=True,
                            help='Record all logs when running this session')
        parser.add_argument(
            '--session.config_file',
            type=str,
            help=
            'config file to run this neuron, if not using cmd line arguments.')
        BertNSPSynapse.add_args(parser)
        Neuron.add_args(parser)

    @staticmethod
    def check_config(config: Munch):
        assert config.session.momentum > 0 and config.session.momentum < 1, "momentum must be a value between 0 and 1"
        assert config.session.batch_size_train > 0, "batch_size_train must be a positive value"
        assert config.session.learning_rate > 0, "learning_rate must be a positive value."
        full_path = '{}/{}/{}'.format(config.session.root_dir,
                                      config.session.name,
                                      config.session.trial_uid)
        config.session.full_path = os.path.expanduser(full_path)
        if not os.path.exists(config.session.full_path):
            os.makedirs(config.session.full_path)
        BertNSPSynapse.check_config(config)
        Neuron.check_config(config)

    # --- Main loop ----
    def run(self):

        # ---- Subscribe ----
        with self.neuron:

            # ---- Weights ----
            self.row = self.neuron.metagraph.row

            # --- Run state ---
            self.epoch = -1
            self.global_step = 0
            self.best_train_loss = math.inf

            # --- Loop forever ---
            while True:
                try:
                    self.epoch += 1

                    # ---- Serve ----
                    self.neuron.axon.serve(self.model)

                    # ---- Train Model ----
                    self.train()
                    self.scheduler.step()

                    # ---- Emit row-weights ----
                    self.neuron.metagraph.emit(
                        self.row, wait_for_inclusion=True
                    )  # Sets my row-weights on the chain.

                    # ---- Sync metagraph ----
                    self.neuron.metagraph.sync(
                    )  # Pulls the latest metagraph state (with my update.)
                    self.row = self.neuron.metagraph.row

                    # --- Epoch logs ----
                    print(self.neuron.axon.__full_str__())
                    print(self.neuron.dendrite.__full_str__())
                    print(self.neuron.metagraph)

                    # ---- Update Tensorboard ----
                    self.neuron.dendrite.__to_tensorboard__(
                        self.tensorboard, self.global_step)
                    self.neuron.metagraph.__to_tensorboard__(
                        self.tensorboard, self.global_step)
                    self.neuron.axon.__to_tensorboard__(
                        self.tensorboard, self.global_step)

                    # ---- Save best loss and model ----
                    if self.training_loss and self.epoch % 10 == 0:
                        if self.training_loss < self.best_train_loss:
                            self.best_train_loss = self.training_loss  # update best train loss
                            logger.info(
                                'Saving/Serving model: epoch: {}, loss: {}, path: {}/model.torch'
                                .format(self.epoch, self.best_train_loss,
                                        self.config.session.full_path))
                            torch.save(
                                {
                                    'epoch': self.epoch,
                                    'model': self.model.state_dict(),
                                    'loss': self.best_train_loss
                                }, "{}/model.torch".format(
                                    self.config.session.full_path))
                            self.tensorboard.add_scalar(
                                'Neuron/Train_loss', self.training_loss,
                                self.global_step)

                # --- Catch Errors ----
                except Exception as e:
                    logger.error('Exception in training script with error: {}',
                                 e)
                    logger.info(traceback.print_exc())
                    logger.info('Continuing to train.')
                    time.sleep(1)

    # ---- Train Epoch ----
    def train(self):
        self.training_loss = 0.0
        for local_step in range(self.config.session.epoch_length):
            # ---- Forward pass ----
            inputs, targets = nsp_batch(self.dataset['train'],
                                        self.config.session.batch_size_train,
                                        bittensor.__tokenizer__())
            output = self.model.remote_forward(
                self.neuron,
                inputs=inputs['input_ids'].to(self.model.device),
                attention_mask=inputs['attention_mask'].to(self.model.device),
                targets=targets.to(self.model.device))

            # ---- Backward pass ----
            loss = output.local_target_loss + output.distillation_loss + output.remote_target_loss
            loss.backward()  # Accumulates gradients on the model.
            self.optimizer.step()  # Applies accumulated gradients.
            self.optimizer.zero_grad(
            )  # Zeros out gradients for next accummulation

            # ---- Train row weights ----
            batch_weights = torch.mean(output.dendrite.weights,
                                       axis=0)  # Average over batch.
            self.row = (
                1 -
                0.03) * self.row + 0.03 * batch_weights  # Moving avg update.
            self.row = F.normalize(self.row, p=1,
                                   dim=0)  # Ensure normalization.

            # ---- Step logs ----
            logger.info(
                'GS: {} LS: {} Epoch: {}\tLocal Target Loss: {}\tRemote Target Loss: {}\tDistillation Loss: {}\tAxon: {}\tDendrite: {}',
                colored('{}'.format(self.global_step), 'red'),
                colored('{}'.format(local_step), 'blue'),
                colored('{}'.format(self.epoch), 'green'),
                colored('{:.4f}'.format(output.local_target_loss.item()),
                        'green'),
                colored('{:.4f}'.format(output.remote_target_loss.item()),
                        'blue'),
                colored('{:.4f}'.format(output.distillation_loss.item()),
                        'red'), self.neuron.axon, self.neuron.dendrite)
            logger.info('Codes: {}', output.dendrite.return_codes.tolist())

            self.tensorboard.add_scalar('Neuron/Rloss',
                                        output.remote_target_loss.item(),
                                        self.global_step)
            self.tensorboard.add_scalar('Neuron/Lloss',
                                        output.local_target_loss.item(),
                                        self.global_step)
            self.tensorboard.add_scalar('Neuron/Dloss',
                                        output.distillation_loss.item(),
                                        self.global_step)

            # ---- Step increments ----
            self.global_step += 1
            self.training_loss += output.local_target_loss.item()

            # --- Memory clean up ----
            torch.cuda.empty_cache()
            del output
Example #14
class Miner( bittensor.miner.Miner ):

    def __init__(self, config: Munch = None, **kwargs):
        if config is None:
            config = Miner.default_config()
        bittensor.config.Config.update_with_kwargs(config.miner, kwargs)
        Miner.check_config(config)
        self.config = config

        # ---- Model ----
        self.model = BertMLMSynapse( self.config )

        # ---- Optimizer ----
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr = self.config.miner.learning_rate, momentum=self.config.miner.momentum)
        self.scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, 50, 300)

        # ---- Model Load/Save tools ----
        self.model_toolbox = ModelToolbox(BertMLMSynapse, torch.optim.SGD)

        # ---- Dataset ----
        # Dataset: AG News headlines.
        self.dataset = load_dataset('ag_news')['train']
        # The collator accepts a list of dicts (e.g. {'input_ids': ...})
        # produced by the tokenizer.
        self.data_collator = DataCollatorForLanguageModeling(
            tokenizer=bittensor.__tokenizer__(), mlm=True, mlm_probability=0.15
        )
        super( Miner, self ).__init__( self.config, **kwargs )

    @staticmethod
    def default_config() -> Munch:
        parser = argparse.ArgumentParser()
        Miner.add_args(parser)
        config = bittensor.config.Config.to_config(parser)
        return config

    @staticmethod
    def check_config(config: Munch):
        assert config.miner.momentum > 0 and config.miner.momentum < 1, "momentum must be a value between 0 and 1"
        assert config.miner.batch_size_train > 0, "batch_size_train must be a positive value"
        assert config.miner.learning_rate > 0, "learning_rate must be a positive value."
        BertMLMSynapse.check_config( config )
        bittensor.miner.Miner.check_config( config )

    @staticmethod
    def add_args(parser: argparse.ArgumentParser):
        parser.add_argument('--miner.learning_rate', default=0.01, type=float, help='Training initial learning rate.')
        parser.add_argument('--miner.momentum', default=0.98, type=float, help='Training initial momentum for SGD.')
        parser.add_argument('--miner.clip_gradients', default=0.8, type=float, help='Implement gradient clipping to avoid exploding loss on smaller architectures.')
        parser.add_argument('--miner.n_epochs', default=int(sys.maxsize), type=int, help='Number of training epochs.')
        parser.add_argument('--miner.epoch_length', default=500, type=int, help='Iterations of training per epoch')
        parser.add_argument('--miner.batch_size_train', default=1, type=int, help='Training batch size.')
        parser.add_argument('--miner.name', default='bert_mlm', type=str, help='Trials for this miner go in miner.root / (wallet_cold - wallet_hot) / miner.name ')
        BertMLMSynapse.add_args(parser)
        bittensor.miner.Miner.add_args(parser)

    # --- Main loop ----
    def run (self):

        # ---- Subscribe ----
        with self:

            # ---- Weights ----
            self.row = self.metagraph.row

            # --- Run state ---
            self.global_step = 0
            self.best_train_loss = math.inf

            # --- Loop for epochs ---
            for self.epoch in range(self.config.miner.n_epochs):
                try:
                    # ---- Serve ----
                    self.axon.serve( self.model )

                    # ---- Train Model ----
                    self.train()
                    self.scheduler.step()

                    # If the model has produced NaN parameters, make sure it doesn't emit weights.
                    # Instead, reload the previously saved version of the model.
                    if torch.any(torch.isnan(torch.cat([param.view(-1) for param in self.model.parameters()]))):
                        self.model, self.optimizer = self.model_toolbox.load_model(self.config)    
                        continue

                    # ---- Emitting weights ----
                    self.metagraph.set_weights(self.row, wait_for_inclusion = True) # Sets my row-weights on the chain.

                    # ---- Sync metagraph ----
                    self.metagraph.sync() # Pulls the latest metagraph state (with my update.)
                    self.row = self.metagraph.row
                    logger.info(self.metagraph)

                    # ---- Update Tensorboard ----
                    self.dendrite.__to_tensorboard__(self.tensorboard, self.global_step)
                    self.metagraph.__to_tensorboard__(self.tensorboard, self.global_step)
                    self.axon.__to_tensorboard__(self.tensorboard, self.global_step)
                
                    # ---- Save best loss and model ----
                    if self.training_loss and self.epoch % 10 == 0:
                        if self.training_loss < self.best_train_loss:
                            self.best_train_loss = self.training_loss # update best train loss
                            self.model_toolbox.save_model(
                                self.config.miner.full_path,
                                {
                                    'epoch': self.epoch, 
                                    'model_state_dict': self.model.state_dict(), 
                                    'loss': self.best_train_loss,
                                    'optimizer_state_dict': self.optimizer.state_dict(),
                                }
                            )
                            self.tensorboard.add_scalar('Neuron/Train_loss', self.training_loss, self.global_step)
                    
                # --- Catch Errors ----
                except Exception as e:
                    logger.error('Exception in training script with error: {}', e)
                    logger.info(traceback.print_exc())
                    logger.info('Continuing to train.')
                    time.sleep(1)
    
    # ---- Train Epoch ----
    def train(self):
        self.training_loss = 0.0
        for local_step in range(self.config.miner.epoch_length):
            # ---- Forward pass ----
            inputs, targets = mlm_batch(self.dataset, self.config.miner.batch_size_train, bittensor.__tokenizer__(), self.data_collator)
            output = self.model.remote_forward (
                    self,
                    inputs = inputs.to(self.model.device), 
                    targets = targets.to(self.model.device)
            )

            # ---- Backward pass ----
            loss = output.local_target_loss + output.distillation_loss + output.remote_target_loss
            loss.backward() # Accumulates gradients on the model.
            clip_grad_norm_(self.model.parameters(), self.config.miner.clip_gradients) # clip model gradients
            self.optimizer.step() # Applies accumulated gradients.
            self.optimizer.zero_grad() # Zeros out gradients for next accummulation

            # ---- Train row weights ----
            batch_weights = torch.mean(output.router.weights, axis = 0) # Average over batch.
            self.row = (1 - 0.03) * self.row + 0.03 * batch_weights # Moving avg update.
            self.row = F.normalize(self.row, p = 1, dim = 0) # Ensure normalization.

            # ---- Step logs ----
            logger.info('GS: {} LS: {} Epoch: {}\tLocal Target Loss: {}\tRemote Target Loss: {}\tDistillation Loss: {}\tAxon: {}\tDendrite: {}',
                    colored('{}'.format(self.global_step), 'red'),
                    colored('{}'.format(local_step), 'blue'),
                    colored('{}'.format(self.epoch), 'green'),
                    colored('{:.4f}'.format(output.local_target_loss.item()), 'green'),
                    colored('{:.4f}'.format(output.remote_target_loss.item()), 'blue'),
                    colored('{:.4f}'.format(output.distillation_loss.item()), 'red'),
                    self.axon,
                    self.dendrite)
            logger.info('Codes: {}', output.router.return_codes.tolist())
            
            self.tensorboard.add_scalar('Neuron/Rloss', output.remote_target_loss.item(), self.global_step)
            self.tensorboard.add_scalar('Neuron/Lloss', output.local_target_loss.item(), self.global_step)
            self.tensorboard.add_scalar('Neuron/Dloss', output.distillation_loss.item(), self.global_step)

            # ---- Step increments ----
            self.global_step += 1
            self.training_loss += output.local_target_loss.item()

            # --- Memory clean up ----
            torch.cuda.empty_cache()
            del output
Example #15
def train(args, train_dataset, model, tokenizer):
    tb_writer = SummaryWriter()
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    ## DATALOADER
    train_sampler = SequentialSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)
    graph_train_dataloader_a, graph_train_dataloader_b = load_graph_examples(
        args)
    args.logging_steps = len(train_dataloader)
    args.save_steps = len(train_dataloader)
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_total_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs
    assert len(train_dataset) == len(graph_train_dataloader_a.dataset)
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size = %d",
                args.train_batch_size * args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    graph_optimizer = AdamW(model.graph_encoder.parameters(),
                            lr=args.graph_lr,
                            weight_decay=args.weight_decay)
    linear_optimizer = AdamW(model.classifier.parameters(),
                             lr=args.learning_rate,
                             weight_decay=args.weight_decay)
    linear_c_optimizer = AdamW(model.classifier_c.parameters(),
                               lr=args.learning_rate,
                               weight_decay=args.weight_decay)
    linear_type_optimizer = AdamW(model.classifier_type.parameters(),
                                  lr=args.learning_rate,
                                  weight_decay=args.weight_decay)

    # bert_optimizer_grouped_parameters = get_bert_param_groups(model, args)
    bert_optimizer_grouped_parameters = get_bert_param_groups(
        model.text_encoder, args)
    bert_optimizer = AdamW(bert_optimizer_grouped_parameters,
                           lr=args.learning_rate,
                           eps=args.adam_epsilon,
                           weight_decay=args.weight_decay)
    if args.scheduler == 'linear':
        scheduler = WarmupLinearSchedule(bert_optimizer,
                                         warmup_steps=args.warmup_steps,
                                         t_total=t_total)
        graph_scheduler = WarmupLinearSchedule(graph_optimizer,
                                               warmup_steps=args.warmup_steps,
                                               t_total=t_total)
        linear_scheduler = WarmupLinearSchedule(linear_optimizer,
                                                warmup_steps=args.warmup_steps,
                                                t_total=t_total)
        linear_c_scheduler = WarmupLinearSchedule(
            linear_c_optimizer,
            warmup_steps=args.warmup_steps,
            t_total=t_total)
        linear_type_scheduler = WarmupLinearSchedule(
            linear_type_optimizer,
            warmup_steps=args.warmup_steps,
            t_total=t_total)
    elif args.scheduler == 'cosine':
        scheduler = WarmupCosineWithHardRestartsSchedule(
            bert_optimizer,
            warmup_steps=args.warmup_steps,
            t_total=t_total,
            cycles=2.)
        graph_scheduler = WarmupCosineWithHardRestartsSchedule(
            graph_optimizer,
            warmup_steps=args.warmup_steps,
            t_total=t_total,
            cycles=2.)
        linear_scheduler = WarmupCosineWithHardRestartsSchedule(
            linear_optimizer,
            warmup_steps=args.warmup_steps,
            t_total=t_total,
            cycles=2.)
        linear_c_scheduler = WarmupCosineWithHardRestartsSchedule(
            linear_c_optimizer,
            warmup_steps=args.warmup_steps,
            t_total=t_total,
            cycles=2.)
        linear_type_scheduler = WarmupCosineWithHardRestartsSchedule(
            linear_type_optimizer,
            warmup_steps=args.warmup_steps,
            t_total=t_total,
            cycles=2.)

    ## TRAIN
    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    set_seed(args)
    for _ in trange(int(args.num_train_epochs), desc='Epoch'):
        for batch, data_a, data_b in tqdm(zip(train_dataloader,
                                              graph_train_dataloader_a,
                                              graph_train_dataloader_b),
                                          desc='Iteration',
                                          total=len(train_dataloader)):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            data_a, data_b = data_a.to(args.device), data_b.to(args.device)

            loss_fcts = {
                'mse': F.mse_loss,
                'smooth_l1': F.smooth_l1_loss,
                'l1': F.l1_loss
            }
            loss_fct = loss_fcts[args.loss_fct]
            loss_combined = 0.0
            for mode in ['medsts', 'medsts_c', 'medsts_type']:
                torch.cuda.empty_cache()
                logits = model(batch[0],
                               batch[1],
                               batch[2],
                               data_a,
                               data_b,
                               mode=mode)
                if mode == 'medsts':
                    loss = loss_fct(logits, data_a.label)
                elif mode == 'medsts_c':
                    loss = F.cross_entropy(logits, data_a.label_c)
                elif mode == 'medsts_type':
                    loss = F.cross_entropy(logits, data_a.label_type)
                loss_combined += loss

            loss_combined.backward()
            if args.clip_all:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.text_encoder.parameters(),
                                               args.max_grad_norm)

            tr_loss += loss.item()
            scheduler.step()
            bert_optimizer.step()

            graph_scheduler.step()
            linear_scheduler.step()
            graph_optimizer.step()

            # print('learning rate: {} \t graph optimizer lr: {}'.format(linear_optimizer.param_groups[0]['lr'], graph_optimizer.param_groups[0]['lr']))
            linear_optimizer.step()

            linear_c_scheduler.step()
            linear_type_scheduler.step()
            linear_c_optimizer.step()
            linear_type_optimizer.step()

            model.zero_grad()
            global_step += 1
            args.logging_steps = len(train_dataloader) // 4

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                result = evaluate(args, model, tokenizer)
                tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                     args.logging_steps, global_step)
                logging_loss = tr_loss

            if args.save_steps > 0 and global_step % args.save_steps == 0:
                output_dir = os.path.join(args.output_dir,
                                          'checkpoint-{}'.format(global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                # model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                # model_to_save.save_pretrained(output_dir)
                torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                # logger.info("Saving model checkpoint to %s", output_dir)

        # result = evaluate(args, model, tokenizer)
    tb_writer.close()
    return global_step, tr_loss / global_step
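One detail worth flagging in the loop above: each scheduler is stepped before its optimizer (scheduler.step() comes ahead of bert_optimizer.step(), and likewise for the graph and linear heads). PyTorch 1.1+ expects the opposite order and emits a warning when it detects this. A minimal, self-contained sketch of the expected ordering follows; it uses a stock torch.optim.lr_scheduler.StepLR purely to stay runnable, not the warmup schedules from this function.

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

for _ in range(30):
    loss = model(torch.randn(8, 4)).pow(2).mean()     # toy loss on random inputs
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # clip, as above
    optimizer.step()       # apply gradients first
    scheduler.step()       # then advance the learning-rate schedule
    optimizer.zero_grad()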
Example #16
class Miner():

    def __init__(self, config: Munch = None):
        if config is None:
            config = Miner.build_config()
            logger.info(bittensor.config.Config.toString(config))
        self.config = config

        # ---- Neuron ----
        self.neuron = bittensor.neuron.Neuron(self.config)

        # ---- Model ----
        self.model = BertMLMSynapse( self.config )

        # ---- Optimizer ----
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr = self.config.miner.learning_rate, momentum=self.config.miner.momentum)
        self.scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, 50, 300)

        # ---- Model Load/Save tools ----
        self.model_toolbox = ModelToolbox(BertMLMSynapse, torch.optim.SGD)

        # ---- Dataset ----
        # Dataset: AG News headlines.
        self.dataset = load_dataset('ag_news')['train']
        # The collator accepts a list of dicts (e.g. {'input_ids': ...})
        # produced by the tokenizer.
        self.data_collator = DataCollatorForLanguageModeling(
            tokenizer=bittensor.__tokenizer__(), mlm=True, mlm_probability=0.15
        )

        # ---- Logging ----
        self.tensorboard = SummaryWriter(log_dir = self.config.miner.full_path)
        if self.config.miner.record_log:
            logger.add(self.config.miner.full_path + "/{}_{}.log".format(self.config.miner.name, self.config.miner.trial_uid),format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}")

    @staticmethod
    def build_config() -> Munch:
        parser = argparse.ArgumentParser()
        Miner.add_args(parser)
        config = bittensor.config.Config.to_config(parser)
        Miner.check_config(config)
        return config

    @staticmethod
    def check_config(config: Munch):
        assert config.miner.momentum > 0 and config.miner.momentum < 1, "momentum must be a value between 0 and 1"
        assert config.miner.batch_size_train > 0, "batch_size_train must be a positive value"
        assert config.miner.learning_rate > 0, "learning_rate must be a positive value."
        full_path = '{}/{}/{}'.format(config.miner.root_dir, config.miner.name, config.miner.trial_uid)
        config.miner.full_path = os.path.expanduser(full_path)
        if not os.path.exists(config.miner.full_path):
            os.makedirs(config.miner.full_path)
        BertMLMSynapse.check_config(config)
        bittensor.neuron.Neuron.check_config(config)

    @staticmethod
    def add_args(parser: argparse.ArgumentParser):
        parser.add_argument('--miner.learning_rate', default=0.01, type=float, help='Training initial learning rate.')
        parser.add_argument('--miner.momentum', default=0.98, type=float, help='Training initial momentum for SGD.')
        parser.add_argument('--miner.n_epochs', default=int(sys.maxsize), type=int, help='Number of training epochs.')
        parser.add_argument('--miner.epoch_length', default=500, type=int, help='Iterations of training per epoch')
        parser.add_argument('--miner.batch_size_train', default=1, type=int, help='Training batch size.')
        parser.add_argument('--miner.sync_interval', default=100, type=int, help='Batches before we sync with chain and emit new weights.')
        parser.add_argument('--miner.log_interval', default=10, type=int, help='Batches before we log miner info.')
        parser.add_argument('--miner.accumulation_interval', default=1, type=int, help='Batches before we apply accumulated gradients.')
        parser.add_argument('--miner.apply_remote_gradients', default=False, type=bool, help='If true, neuron applies gradients which accumulate from remotes calls.')
        parser.add_argument('--miner.root_dir', default='~/.bittensor/miners/', type=str,  help='Root path to load and save data associated with each miner')
        parser.add_argument('--miner.name', default='bert-mlm', type=str, help='Trials for this miner go in miner.root_dir / miner.name')
        parser.add_argument('--miner.trial_uid', default=str(time.time()).split('.')[0], type=str, help='Saved models go in miner.root_dir / miner.name / miner.trial_uid')
        parser.add_argument('--miner.record_log', default=True, help='Record all logs when running this miner')
        parser.add_argument('--miner.config_file', type=str, help='config file to run this neuron, if not using cmd line arguments.')
        BertMLMSynapse.add_args(parser)
        bittensor.neuron.Neuron.add_args(parser)
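build_config above relies on bittensor.config.Config.to_config to fold dotted flags such as --miner.learning_rate into nested config sections. A rough sketch of that idea, written as our own helper rather than bittensor's implementation (whose details may differ):

import argparse
from munch import Munch

def to_nested_munch(parser: argparse.ArgumentParser) -> Munch:
    # argparse keeps the dot in the destination name, so --miner.learning_rate
    # parses into the key 'miner.learning_rate'.
    params = vars(parser.parse_args([]))  # [] uses the declared defaults; drop it to read sys.argv
    config = Munch()
    for dotted_name, value in params.items():
        section = config
        *parents, leaf = dotted_name.split('.')
        for key in parents:
            section = section.setdefault(key, Munch())
        section[leaf] = value
    return config  # e.g. config.miner.learning_rate == 0.01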

    # --- Main loop ----
    def run(self):

        # ---- Subscribe ----
        with self.neuron:

            # ---- Weights ----
            self.row = self.neuron.metagraph.row

            # --- Run state ---
            self.global_step = 0
            self.best_train_loss = math.inf

            # --- Loop for epochs ---
            for self.epoch in range(self.config.miner.n_epochs):
                try:
                    # ---- Serve ----
                    self.neuron.axon.serve( self.model )

                    # ---- Train Model ----
                    self.train()
                    self.scheduler.step()

                    # If the model has produced NaN parameters, make sure it doesn't emit weights;
                    # instead, reload the previously saved version of the model.
                    if torch.any(torch.isnan(torch.cat([param.view(-1) for param in self.model.parameters()]))):
                        self.model, self.optimizer = self.model_toolbox.load_model(self.config)    
                        continue

                    # ---- Emitting weights ----
                    self.neuron.metagraph.set_weights(self.row, wait_for_inclusion = True) # Sets my row-weights on the chain.

                    # ---- Sync metagraph ----
                    self.neuron.metagraph.sync() # Pulls the latest metagraph state (with my update.)
                    self.row = self.neuron.metagraph.row

                    # --- Epoch logs ----
                    print(self.neuron.axon.__full_str__())
                    print(self.neuron.dendrite.__full_str__())
                    print(self.neuron.metagraph)

                    # ---- Update Tensorboard ----
                    self.neuron.dendrite.__to_tensorboard__(self.tensorboard, self.global_step)
                    self.neuron.metagraph.__to_tensorboard__(self.tensorboard, self.global_step)
                    self.neuron.axon.__to_tensorboard__(self.tensorboard, self.global_step)
                
                    # ---- Save best loss and model ----
                    if self.training_loss and self.epoch % 10 == 0:
                        if self.training_loss < self.best_train_loss:
                            self.best_train_loss = self.training_loss # update best train loss
                            self.model_toolbox.save_model(
                                self.config.miner.full_path,
                                {
                                    'epoch': self.epoch, 
                                    'model_state_dict': self.model.state_dict(), 
                                    'loss': self.best_train_loss,
                                    'optimizer_state_dict': self.optimizer.state_dict(),
                                }
                            )
                            self.tensorboard.add_scalar('Neuron/Train_loss', self.training_loss, self.global_step)
                    
                # --- Catch Errors ----
                except Exception as e:
                    logger.error('Exception in training script with error: {}', e)
                    logger.info(traceback.format_exc())
                    logger.info('Continuing to train.')
                    time.sleep(1)
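The NaN guard in the loop above flattens every parameter into one vector before testing it. An equivalent check that short-circuits on the first NaN parameter, offered only as a sketch alongside the original:

import torch

def has_nan_params(model: torch.nn.Module) -> bool:
    # Checking tensors one by one avoids concatenating all parameters into a
    # single temporary and stops at the first NaN it finds.
    return any(torch.isnan(p).any().item() for p in model.parameters())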
    
    # ---- Train Epoch ----
    def train(self):
        self.training_loss = 0.0
        for local_step in range(self.config.miner.epoch_length):
            # ---- Forward pass ----
            inputs, targets = mlm_batch(self.dataset, self.config.miner.batch_size_train, bittensor.__tokenizer__(), self.data_collator)
            output = self.model.remote_forward (
                    self.neuron,
                    inputs = inputs.to(self.model.device), 
                    targets = targets.to(self.model.device)
            )

            # ---- Backward pass ----
            loss = output.local_target_loss + output.distillation_loss + output.remote_target_loss
            loss.backward() # Accumulates gradients on the model.
            self.optimizer.step() # Applies accumulated gradients.
            self.optimizer.zero_grad() # Zeros out gradients for the next accumulation.

            # ---- Train row weights ----
            batch_weights = torch.mean(output.router.weights, dim = 0) # Average over batch.
            self.row = (1 - 0.03) * self.row + 0.03 * batch_weights # Moving avg update.
            self.row = F.normalize(self.row, p = 1, dim = 0) # Ensure normalization.

            # ---- Step logs ----
            logger.info('GS: {} LS: {} Epoch: {}\tLocal Target Loss: {}\tRemote Target Loss: {}\tDistillation Loss: {}\tAxon: {}\tDendrite: {}',
                    colored('{}'.format(self.global_step), 'red'),
                    colored('{}'.format(local_step), 'blue'),
                    colored('{}'.format(self.epoch), 'green'),
                    colored('{:.4f}'.format(output.local_target_loss.item()), 'green'),
                    colored('{:.4f}'.format(output.remote_target_loss.item()), 'blue'),
                    colored('{:.4f}'.format(output.distillation_loss.item()), 'red'),
                    self.neuron.axon,
                    self.neuron.dendrite)
            logger.info('Codes: {}', output.router.return_codes.tolist())
            
            self.tensorboard.add_scalar('Neuron/Rloss', output.remote_target_loss.item(), self.global_step)
            self.tensorboard.add_scalar('Neuron/Lloss', output.local_target_loss.item(), self.global_step)
            self.tensorboard.add_scalar('Neuron/Dloss', output.distillation_loss.item(), self.global_step)

            # ---- Step increments ----
            self.global_step += 1
            self.training_loss += output.local_target_loss.item()

            # --- Memory clean up ----
            torch.cuda.empty_cache()
            del output
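The train loop calls mlm_batch, which is not included in this example. A minimal sketch of what such a helper could look like, assuming ag_news rows expose a 'text' field and reusing the DataCollatorForLanguageModeling built in __init__; the names and details are assumptions, not the original implementation:

import random

def mlm_batch(dataset, batch_size, tokenizer, data_collator):
    # Sample a batch of raw sentences from the dataset.
    texts = [dataset[random.randrange(len(dataset))]['text'] for _ in range(batch_size)]
    # Tokenize each sentence into the per-example dict the collator expects.
    encodings = [tokenizer(text, truncation=True, max_length=128) for text in texts]
    # The collator pads, masks ~15% of tokens, and emits 'labels' with -100
    # everywhere except the masked positions.
    collated = data_collator(encodings)
    return collated['input_ids'], collated['labels']

A typical entry point for running the miner would then be along these lines:

if __name__ == "__main__":
    miner = Miner()   # builds config from command-line flags via Miner.build_config()
    miner.run()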