Example #1
        "overwrite_output_dir": True,
        "max_seq_length": args.max_seq_length,
        "train_batch_size": args.train_batch_size,
        "eval_batch_size": args.eval_batch_size,
        "evaluate_during_training": False,  # Disabled for FedAvg.
        "evaluate_during_training_steps": args.evaluate_during_training_steps,
        "fp16": args.fp16,
        "data_file_path": args.data_file_path,
        "partition_file_path": args.partition_file_path,
        "partition_method": args.partition_method,
        "dataset": args.dataset,
        "output_dir": args.output_dir,
        "is_debug_mode": args.is_debug_mode
    })

    # build config, model, and tokenizer for the span-extraction (QA) formulation
    model_config, client_model, tokenizer = create_model(
        model_args, formulation="span_extraction")

    client_trainer = SpanExtractionTrainer(model_args, device, client_model,
                                           None, None, tokenizer)
    fed_trainer = FedTransformerTrainer(client_trainer, client_model)

    # data loading and management
    preprocessor = TLMPreprocessor(args=model_args, tokenizer=tokenizer)
    num_workers = args.client_num_per_round
    dm = SpanExtractionDataManager(args, model_args, preprocessor, process_id,
                                   num_workers)
    # unpack global loaders, per-client loaders, and per-client sample counts
    (train_data_num, train_data_global, test_data_global,
     train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
     num_clients) = dm.load_federated_data(process_id=process_id)
    if process_id == 0:
        client_trainer.test_dl = test_data_global
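
Only the coordinating process (rank 0) is handed the global test loader, since under FedAvg the server evaluates the aggregated model. To make the return value of load_federated_data concrete, here is a small hypothetical sanity check of the partition; it assumes, from the naming alone, that train_data_local_num_dict maps a client index to that client's local sample count:

import logging

# Illustrative only: report how the training data is partitioned across clients.
logging.basicConfig(level=logging.INFO)
logging.info("%d training samples across %d clients", train_data_num, num_clients)
for client_idx, sample_num in sorted(train_data_local_num_dict.items()):
    share = 100.0 * sample_num / max(train_data_num, 1)
    logging.info("client %d: %d samples (%.1f%%)", client_idx, sample_num, share)
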
Example #2
                                 "overwrite_output_dir": True,
                                 "max_seq_length": args.max_seq_length,
                                 "train_batch_size": args.train_batch_size,
                                 "eval_batch_size": args.eval_batch_size,
                                 "evaluate_during_training": False,  # Disabled for FedAvg.
                                 "evaluate_during_training_steps": args.evaluate_during_training_steps,
                                 "fp16": args.fp16,
                                 "data_file_path": args.data_file_path,
                                 "partition_file_path": args.partition_file_path,
                                 "partition_method": args.partition_method,
                                 "dataset": args.dataset,
                                 "output_dir": args.output_dir,
                                 "is_debug_mode": args.is_debug_mode
                                 })
    # size the classification head, then build config, model, and tokenizer
    model_args.config["num_labels"] = num_labels
    model_config, client_model, tokenizer = create_model(
        model_args, formulation="classification")

    # trainer
    client_trainer = TextClassificationTrainer(
        model_args, device, client_model, None, None)
    fed_trainer = FedTransformerTrainer(client_trainer, client_model)

    # data manager
    preprocessor = TLMPreprocessor(
        args=model_args, label_vocab=attributes["label_vocab"],
        tokenizer=tokenizer)
    dm = TextClassificationDataManager(args, model_args, preprocessor,
                                       process_id, args.client_num_per_round)
    # unpack the global and per-client splits returned by the data manager
    (train_data_num, train_data_global, test_data_global,
     train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
     num_clients) = dm.load_federated_data(process_id=process_id)

    # start FedAvg algorithm
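
The configuration dict at the top of this example is absorbed into model_args in a single call. Below is a minimal sketch of how an update_from_dict helper of this kind typically works; the ModelArgs class here is a hypothetical stand-in for illustration, not the actual FedNLP implementation:

from dataclasses import dataclass, asdict

@dataclass
class ModelArgs:
    # a few representative fields; the real args object carries many more
    overwrite_output_dir: bool = False
    max_seq_length: int = 128
    train_batch_size: int = 8
    fp16: bool = True

    def update_from_dict(self, new_values):
        # copy each key in the dict onto the matching attribute
        if not isinstance(new_values, dict):
            raise TypeError("update_from_dict expects a dict")
        for key, value in new_values.items():
            setattr(self, key, value)

args_obj = ModelArgs()
args_obj.update_from_dict({"max_seq_length": 256, "train_batch_size": 32})
print(asdict(args_obj))

Routing every CLI flag through one dict keeps the Transformer settings in a single object that the trainer and the data manager can share.
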
Example #3
        "reprocess_input_data": True,
        "overwrite_output_dir": True,
        "max_seq_length": args.max_seq_length,
        "train_batch_size": args.train_batch_size,
        "eval_batch_size": args.eval_batch_size,
        "evaluate_during_training": False,  # Disabled for FedAvg.
        "evaluate_during_training_steps": args.evaluate_during_training_steps,
        "fp16": args.fp16,
        "data_file_path": args.data_file_path,
        "partition_file_path": args.partition_file_path,
        "partition_method": args.partition_method,
        "dataset": args.dataset,
        "output_dir": args.output_dir,
        "is_debug_mode": args.is_debug_mode
    })
    # build config, model, and tokenizer for the seq2seq formulation
    model_config, client_model, tokenizer = create_model(model_args,
                                                         formulation="seq2seq")

    # trainer
    client_trainer = Seq2SeqTrainer(model_args, device, client_model, None,
                                    None, tokenizer)
    fed_trainer = FedTransformerTrainer(client_trainer, client_model)

    # data manager
    preprocessor = TLMPreprocessor(args=model_args, tokenizer=tokenizer)
    dm = Seq2SeqDataManager(args, model_args, preprocessor, process_id,
                            args.client_num_per_round)
    # global and per-client splits returned by the data manager
    (train_data_num, train_data_global, test_data_global,
     train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
     num_clients) = dm.load_federated_data(process_id=process_id)

    # start FedAvg algorithm
    # for the distributed algorithm, train_data_global and test_data_global are required
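
The excerpt ends at the comment above, so the actual hand-off is not shown. The sketch below is one plausible continuation, assuming the FedML 0.x entry point FedML_FedAvg_distributed and the MPI handles (comm, process_id, worker_number) produced by FedML_init() earlier in the script; these names and the argument order are assumptions, not part of the excerpt:

# Assumed continuation (FedML 0.x API): rank 0 becomes the aggregation server,
# every other rank trains client_model locally and sends its weights back for averaging.
from fedml_api.distributed.fedavg.FedAvgAPI import FedML_FedAvg_distributed

args.client_num_in_total = num_clients
FedML_FedAvg_distributed(process_id, worker_number, device, comm,
                         client_model, train_data_num, train_data_global,
                         test_data_global, train_data_local_num_dict,
                         train_data_local_dict, test_data_local_dict,
                         args, fed_trainer)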