def train():
    """Set up and train a BERT-encoder / small-transformer-decoder model on task->plan data.

    Reads run options from `set_up_args()` (uses `args.gpu` and `args.lr`),
    loads the train and valid_seen task/plan JSON splits, builds an
    EncoderDecoderModel whose encoder is pretrained bert-base-uncased and whose
    decoder is a 3-layer / 3-head BERT-style decoder over a custom vocabulary,
    then prepares the cross-entropy loss and an Adam optimizer over the
    decoder parameters.
    """
    args = set_up_args()

    # Load training and validation (seen) task/plan pairs.
    task2plan_train = load_task_and_plan_json(args, "train")
    task2plan_valid_seen = load_task_and_plan_json(args, "valid_seen")
    tpd_valid_seen = TaskPlanDataset(task2plan_valid_seen)
    valid_seen_dataloader = DataLoader(tpd_valid_seen, batch_size=1, shuffle=True)

    # Tokenizers: pretrained BERT vocabulary for the encoder input, a custom
    # vocabulary file for the decoder output side.
    encoder_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    decoder_tokenizer = BasicTokenizer("yz/vocab.txt")

    # Initializing a BERT bert-base-uncased style configuration.
    # BUG FIX: the original created and updated these configs twice; the first
    # update referenced `decoder_tokenizer` before it was defined (NameError)
    # and set 6 layers/heads only to be overwritten by the 3/3 settings kept
    # here. The duplicate has been removed.
    config_encoder = BertConfig()
    config_decoder = BertConfig()
    config_decoder.update({
        "vocab_size": len(decoder_tokenizer.vocab),
        "num_hidden_layers": 3,
        "num_attention_heads": 3,
    })
    config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)

    # Load the BERT model: build the seq2seq model from the combined config,
    # then replace the randomly-initialized encoder with pretrained BERT.
    model = EncoderDecoderModel(config=config)
    model.encoder = BertModel.from_pretrained('bert-base-uncased')
    if args.gpu and torch.cuda.is_available():
        model = model.cuda()

    loss_fun = nn.CrossEntropyLoss()
    # loss_fun = nn.CrossEntropyLoss(ignore_index=0)  # alternative: ignore padding index in the loss
    # Only the decoder is trained; the pretrained encoder is left frozen out of the optimizer.
    optimizer = torch.optim.Adam(model.decoder.parameters(), lr=args.lr)
    # NOTE(review): the training loop appears to follow in the original file
    # beyond this chunk (task2plan_train / valid_seen_dataloader / loss_fun /
    # optimizer are set up but not yet consumed here).