import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# BERT, BERTDataSet, ScheduledOptim, WordCrossEntropy, INIT_RANGE and the
# TensorBoard helpers (tf, tf_summary_writer, add_summary_value) are defined
# in the project's other modules.


class Classifier(nn.Module):
    """Sentence-level classification head on top of a pretrained BERT encoder."""

    def __init__(self, lsz, args):
        super().__init__()
        self.bert = BERT(args)
        self.sent_predict = nn.Linear(args.d_model, lsz)
        # normal_() takes (mean, std): initialize with std=INIT_RANGE,
        # not mean=INIT_RANGE as the original call did.
        self.sent_predict.weight.data.normal_(mean=0.0, std=INIT_RANGE)
        self.sent_predict.bias.data.zero_()

    def get_trainable_parameters(self):
        return self.bert.get_trainable_parameters()

    def forward(self, inp, pos, segment_label):
        # Use the pooled sentence encoding; the per-token outputs are discarded.
        _, sent_encode = self.bert(inp, pos, segment_label)
        return F.log_softmax(self.sent_predict(sent_encode), dim=-1)

    def load_model(self, path="model.pt"):
        data = torch.load(path)
        self.bert.load_model(data["weights"])
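# A minimal fine-tuning sketch for the Classifier above. The tensors, the
# optimizer and the helper name are hypothetical placeholders (the real inputs
# come from the project's data pipeline); F.nll_loss is the natural pairing
# for the log_softmax output:
def finetune_step(clf, optimizer, inp, pos, segment_label, labels):
    optimizer.zero_grad()
    log_probs = clf(inp, pos, segment_label)  # (batch, lsz) log-probabilities
    loss = F.nll_loss(log_probs, labels)      # expects log-probabilities
    loss.backward()
    optimizer.step()
    return loss.item()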
def main(args):
    assert torch.cuda.is_available(), "need to use GPUs"
    cuda_devices = list(map(int, args.cuda_devices.split(",")))
    is_multigpu = len(cuda_devices) > 1
    device = "cuda"

    # Seed every RNG (CPU included) so runs are reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    if is_multigpu:  # was `is_multigpu > 1`, which is never true for a bool
        torch.cuda.manual_seed_all(args.seed)

    data = torch.load(args.data)
    dataset = BERTDataSet(data["word"], data["max_len"], data["dict"],
                          args.batch_size * args.steps)
    training_data = DataLoader(dataset,
                               batch_size=args.batch_size,
                               num_workers=args.num_cpus)

    model = BERT(dataset.word_size, data["max_len"], args.n_stack_layers,
                 args.d_model, args.d_ff, args.n_head, args.dropout)
    print(f"BERT has {sum(x.numel() for x in model.parameters())} parameters in total")

    # DataParallel wraps modules, not optimizers; the original code wrapped
    # the Adam optimizer in DataParallel, which would fail at runtime.
    optimizer = ScheduledOptim(
        torch.optim.Adam(model.get_trainable_parameters(),
                         lr=args.lr,
                         betas=(0.9, 0.999),
                         eps=1e-09,
                         weight_decay=0.01),
        args.d_model, args.n_warmup_steps)

    w_criterion = WordCrossEntropy().to(device)
    s_criterion = torch.nn.CrossEntropyLoss()

    model = model.to(device)
    model = torch.nn.DataParallel(model, device_ids=cuda_devices)
    model.train()

    for step, batch in enumerate(training_data):
        inp, pos, sent_label, word_label, segment_label = [x.to(device) for x in batch]
        sent_label = sent_label.view(-1)

        optimizer.zero_grad()
        word, sent = model(inp, pos, segment_label)
        w_loss, w_corrects, tgt_sum = w_criterion(word, word_label)
        s_loss = s_criterion(sent, sent_label)
        if is_multigpu:
            # DataParallel returns one loss per GPU; reduce before backward.
            w_loss, s_loss = w_loss.mean(), s_loss.mean()
        loss = w_loss + s_loss
        loss.backward()
        optimizer.step()

        s_corrects = (torch.max(sent, 1)[1].data == sent_label.data).sum()
        print(f"[Step {step+1}/{args.steps}] "
              f"[word_loss: {w_loss:.5f}, sent_loss: {s_loss:.5f}, loss: {loss:.5f}, "
              f"w_pre: {w_corrects/tgt_sum*100:.2f}% {w_corrects}/{tgt_sum}, "
              f"s_pre: {float(s_corrects)/args.batch_size*100:.2f}% {s_corrects}/{args.batch_size}]")

        if tf is not None:
            add_summary_value("Word loss", w_loss, step)
            add_summary_value("Sent loss", s_loss, step)
            add_summary_value("Loss", loss, step)
            add_summary_value("Word predict", w_corrects / tgt_sum, step)
            add_summary_value("Sent predict", float(s_corrects) / args.batch_size, step)
            tf_summary_writer.flush()
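# ScheduledOptim is referenced above but not shown. A minimal sketch of the
# warmup schedule usually carried by that (optimizer, d_model, n_warmup_steps)
# signature -- the "Noam" schedule from "Attention Is All You Need" -- is
# given below as an assumption about its behavior, not the project's actual
# implementation:
class NoamScheduledOptim:
    """Scales the wrapped optimizer's learning rate as
    d_model^-0.5 * min(step^-0.5, step * n_warmup_steps^-1.5)."""

    def __init__(self, optimizer, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.init_lr = d_model ** -0.5
        self.n_warmup_steps = n_warmup_steps
        self.n_steps = 0

    def step(self):
        # Bump the step count, recompute the learning rate, then apply gradients.
        self.n_steps += 1
        lr = self.init_lr * min(self.n_steps ** -0.5,
                                self.n_steps * self.n_warmup_steps ** -1.5)
        for group in self._optimizer.param_groups:
            group["lr"] = lr
        self._optimizer.step()

    def zero_grad(self):
        self._optimizer.zero_grad()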