示例#1
0
 def create_dataset(args):
     """Create the BERT dataset for the run described by ``args``.

     Mirrors the set-up in the main bert.py, but only as far as dataset
     creation: a default-configured model is built on a fresh popart
     builder, its input tensors are added, and the shape of every input
     tensor is looked up so the dataset can be built to match.

     Args:
         args: Parsed options. This function reads ``input_files``,
             ``output_dir``, ``sequence_length``, ``vocab_file``,
             ``vocab_length``, ``batch_size``, ``batches_per_step``,
             ``generated_data``, ``shuffle`` and ``mpi_size``; ``args`` is
             also handed to ``bert_add_inputs``.

     Returns:
         Whatever ``get_bert_dataset`` returns for these settings.
     """
     model = Bert(BertConfig(), builder=popart.Builder())
     indices, positions, segments, masks, labels = bert_add_inputs(
         args, model)
     embedding_dict, positional_dict = model.get_model_embeddings()
     lookup_shape = model.builder.getTensorShape

     # The first three inputs are used as tensor ids directly, while masks
     # and labels are iterated so that each of their members contributes
     # its own (id, shape) entry.
     tensor_ids = chain((indices, positions, segments), masks, labels)
     tensor_shapes = [(tid, lookup_shape(tid)) for tid in tensor_ids]

     # Evaluation settings: training is off and the remainder is kept.
     dataset_options = {
         "input_file": args.input_files,
         "output_dir": args.output_dir,
         "sequence_length": args.sequence_length,
         "vocab_file": args.vocab_file,
         "vocab_length": args.vocab_length,
         "batch_size": args.batch_size,
         "batches_per_step": args.batches_per_step,
         "embedding_dict": embedding_dict,
         "positional_dict": positional_dict,
         "generated_data": args.generated_data,
         "is_training": False,
         "no_drop_remainder": True,
         "shuffle": args.shuffle,
         "mpi_size": args.mpi_size,
         # Distributed mode is implied by having more than one MPI process.
         "is_distributed": (args.mpi_size > 1),
     }
     return get_bert_dataset(tensor_shapes, **dataset_options)
示例#2
0
 def create_dataset(args):
     """Create the BERT dataset for the run described by ``args``.

     Mirrors the set-up in the main bert.py, but only as far as dataset
     creation: a default-configured model is built on a fresh popart
     builder, its input tensors are added, and the shape of every input
     tensor is looked up before delegating to ``get_bert_dataset``.

     Args:
         args: Parsed options, passed to both ``bert_add_inputs`` and
             ``get_bert_dataset``.

     Returns:
         Whatever ``get_bert_dataset(args, tensor_shapes)`` returns.
     """
     model = Bert(BertConfig(), builder=popart.Builder())
     indices, positions, segments, masks, labels = bert_add_inputs(
         args, model)
     # The embedding dictionaries are not used below; the call and its
     # two-way unpacking are kept exactly as the original performs them.
     _embedding_dict, _positional_dict = model.get_model_embeddings()
     lookup_shape = model.builder.getTensorShape

     # The first three inputs are used as tensor ids directly, while masks
     # and labels are iterated so that each of their members contributes
     # its own (id, shape) entry.
     tensor_ids = chain((indices, positions, segments), masks, labels)
     tensor_shapes = [(tid, lookup_shape(tid)) for tid in tensor_ids]
     return get_bert_dataset(args, tensor_shapes)