コード例 #1
0
    def finetune():
        """Fine-tune GPT-2 on the selected user's tweets and update the notebook UI.

        Button callback. Disables the input widgets, fine-tunes a pre-trained
        GPT-2 model on ``data_<handle>_train.txt`` with the transformers
        ``Trainer`` while driving a progress bar, and reports the outcome in the
        output widgets. On success it enables ``run_predictions``; on failure it
        prints the error into ``log_debug``, re-enables the fine-tune button and
        displays the debug log.

        Side effects:
            Rebinds the module-level ``trainer`` and ``tokenizer`` names, sets the
            ``WANDB_RUN_ID`` / ``WANDB_NAME`` environment variables and calls
            ``wandb.init``.
        """
        # Declared up front so the assignments below rebind the module-level
        # names (presumably read later by the prediction step).
        global trainer, tokenizer
        import os  # replaces the IPython-only %env magic with plain Python

        if run_finetune.button_style == 'success':
            # Training already succeeded (e.g. a double click registered before the
            # widgets were disabled); do not train again.
            return

        handle_widget.disabled = True
        run_dl_tweets.disabled = True
        run_finetune.disabled = True
        run_finetune.button_style = 'primary'

        # Handle exactly as typed (used for display and the wandb run name).
        display_name = handle_widget.value.strip()
        # Drop one leading '@' and lowercase. startswith() instead of [0] so an
        # empty entry cannot raise IndexError here, outside the try block, which
        # would leave all the widgets above permanently disabled.
        handle = display_name[1:] if display_name.startswith('@') else display_name
        handle = handle.lower()

        log_finetune.clear_output(wait=True)
        clear_output(wait=True)

        success_try = False

        with log_finetune:
            print_html(f'\nTraining Neural Network on {display_name} tweets... This could take up to 2-3 minutes!\n')
            progress = widgets.FloatProgress(value=0.1, min=0.0, max=1.0, bar_style = 'info')
            display(progress)

        with log_debug:
            try:
                # Start a fresh wandb run for each attempt, named after the handle.
                run_id = wandb.util.generate_id()
                os.environ['WANDB_RUN_ID'] = run_id
                os.environ['WANDB_NAME'] = display_name
                wandb.init(config={'version':0.1})

                # Setting up pre-trained neural network
                with log_finetune:
                    print_html('\nSetting up pre-trained neural network...')
                config = AutoConfig.from_pretrained('gpt2')
                tokenizer = AutoTokenizer.from_pretrained('gpt2')
                model = AutoModelWithLMHead.from_pretrained('gpt2', config=config)
                # NOTE(review): tokenizer.max_len, per_gpu_train_batch_size,
                # prediction_loss_only, trainer.epoch and trainer._training_step
                # appear to target an older transformers release; newer releases
                # rename or remove them. Pin the version or port these together.
                block_size = tokenizer.max_len
                train_dataset = TextDataset(tokenizer=tokenizer, file_path=f'data_{handle}_train.txt', block_size=block_size, overwrite_cache=True)
                # mlm=False: plain language-modeling labels rather than masked-LM.
                data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
                epochs = 4  # limit before overfitting
                training_args = TrainingArguments(
                    output_dir=f'output/{handle}',
                    overwrite_output_dir=True,
                    do_train=True,
                    num_train_epochs=epochs,
                    per_gpu_train_batch_size=1,
                    logging_steps=5,
                    # presumably disables intermediate checkpoints — verify for the pinned version
                    save_steps=0,
                    # fresh random seed each run, so repeated runs differ
                    seed=random.randint(0,2**32-1))
                trainer = Trainer(
                    model=model,
                    args=training_args,
                    data_collator=data_collator,
                    train_dataset=train_dataset,
                    prediction_loss_only=True)
                progress.value = 0.4

                # The remaining 0.4 -> 1.0 of the bar tracks the training epoch.
                p_start, p_end = 0.4, 1.
                def progressify(f):
                    "Control progress bar when calling f"
                    def inner(*args, **kwargs):
                        if trainer.epoch is not None:
                            progress.value = p_start + trainer.epoch / epochs * (p_end - p_start)
                        return f(*args, **kwargs)
                    return inner

                # Wrap the per-step hook so each training step refreshes the bar first.
                trainer._training_step = progressify(trainer._training_step)

                # Training neural network
                with log_finetune:
                    print_html('Training neural network...\n')
                    display(wandb.jupyter.Run())
                    print_html('\n')
                    display(progress)
                trainer.train()

                run_finetune.button_style = 'success'
                run_predictions.disabled = False

                progress.value = 1.0
                progress.bar_style = 'success'
                success_try = True

                with log_finetune:
                    print_html('\n🎉 Neural network trained successfully!')
                log_predictions.clear_output(wait=True)
                with log_predictions:
                    print_html('\nEnter the start of a sentence and click "Run predictions"')
                with log_restart:
                    print_html('\n<b>To change user, click on menu "Runtime" → "Restart and run all"</b>\n')

            except Exception as e:
                # Top-level UI boundary: report into the debug log and let the user
                # retry via the fine-tune button.
                print('\nAn error occurred...\n')
                print(e)
                run_finetune.button_style = 'danger'
                run_finetune.disabled = False

        if not success_try:
            # Reveal the captured debug output and mark the bar as failed.
            display(log_debug)
            progress.bar_style = 'danger'