    parser.add_argument('--train_batch_size', type=int, default=88888)
    return parser.parse_args()


if __name__ == '__main__':
    from extract import LexicalMap
    import time
    args = parse_config()
    vocabs = dict()
    vocabs['concept'] = Vocab(args.concept_vocab, 5, [CLS])
    vocabs['token'] = Vocab(args.token_vocab, 5, [STR, END])
    vocabs['predictable_token'] = Vocab(args.predictable_token_vocab, 5, [END])
    vocabs['token_char'] = Vocab(args.token_char_vocab, 100, [STR, END])
    vocabs['concept_char'] = Vocab(args.concept_char_vocab, 100, [STR, END])
    vocabs['relation'] = Vocab(args.relation_vocab, 5, [CLS, rCLS, SEL, TL])
    lexical_mapping = LexicalMap()
    train_data = DataLoader(vocabs, lexical_mapping, args.train_data,
                            args.train_batch_size, for_train=True)
    epoch_idx = 0
    batch_idx = 0
    last = 0
    while True:
        st = time.time()
        for d in train_data:
            d = move_to_device(d, torch.device('cpu'))
            batch_idx += 1
            #if d['concept'].size(0) > 5:
            #    continue
            print(epoch_idx, batch_idx, d['concept'].size(), d['token_in'].size())
            c_len, bsz = d['concept'].size()
            t_len, bsz = d['token_in'].size()
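
# `move_to_device` is imported from elsewhere in the repo and its body is not
# shown in this file. The sketch below is a minimal assumed version (named with
# a `_sketch` suffix so it does not shadow the real helper): it recursively
# walks a batch and moves every tensor onto the target device, leaving
# non-tensor entries untouched.
def move_to_device_sketch(batch, device):
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, dict):
        return {k: move_to_device_sketch(v, device) for k, v in batch.items()}
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_to_device_sketch(x, device) for x in batch)
    return batch
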
def main(args, local_rank):
    vocabs = dict()
    vocabs['tok'] = Vocab(args.tok_vocab, 5, [CLS])
    vocabs['lem'] = Vocab(args.lem_vocab, 5, [CLS])
    vocabs['pos'] = Vocab(args.pos_vocab, 5, [CLS])
    vocabs['ner'] = Vocab(args.ner_vocab, 5, [CLS])
    vocabs['predictable_concept'] = Vocab(args.predictable_concept_vocab, 10, [DUM, END])
    vocabs['concept'] = Vocab(args.concept_vocab, 5, [DUM, END])
    vocabs['rel'] = Vocab(args.rel_vocab, 50, [NIL])
    vocabs['word_char'] = Vocab(args.word_char_vocab, 100, [CLS, END])
    vocabs['concept_char'] = Vocab(args.concept_char_vocab, 100, [CLS, END])
    lexical_mapping = LexicalMap(args.lexical_mapping)

    if args.pretrained_word_embed is not None:
        vocab, pretrained_embs = load_pretrained_word_embed(args.pretrained_word_embed)
        vocabs['glove'] = vocab
    else:
        pretrained_embs = None

    for name in vocabs:
        print((name, vocabs[name].size))

    torch.manual_seed(19940117)
    torch.cuda.manual_seed_all(19940117)
    random.seed(19940117)
    device = torch.device('cuda', local_rank)

    model = Parser(vocabs,
                   args.word_char_dim, args.word_dim, args.pos_dim, args.ner_dim,
                   args.concept_char_dim, args.concept_dim,
                   args.cnn_filters, args.char2word_dim, args.char2concept_dim,
                   args.embed_dim, args.ff_embed_dim, args.num_heads, args.dropout,
                   args.snt_layers, args.graph_layers, args.inference_layers,
                   args.rel_dim, pretrained_embs, device=device)

    if args.world_size > 1:
        # Give each worker its own random stream.
        torch.manual_seed(19940117 + dist.get_rank())
        torch.cuda.manual_seed_all(19940117 + dist.get_rank())
        random.seed(19940117 + dist.get_rank())
    model = model.cuda(local_rank)

    train_data = DataLoader(vocabs, lexical_mapping, args.train_data,
                            args.train_batch_size, for_train=True)
    dev_data = DataLoader(vocabs, lexical_mapping, args.dev_data,
                          args.dev_batch_size, for_train=True)
    train_data.set_unk_rate(args.unk_rate)

    # Exclude biases and layer-norm parameters from weight decay.
    weight_decay_params = []
    no_weight_decay_params = []
    for name, param in model.named_parameters():
        if name.endswith('bias') or 'layer_norm' in name:
            no_weight_decay_params.append(param)
        else:
            weight_decay_params.append(param)
    grouped_params = [{'params': weight_decay_params, 'weight_decay': 1e-4},
                      {'params': no_weight_decay_params, 'weight_decay': 0.}]
    optimizer = AdamWeightDecayOptimizer(grouped_params, lr=args.lr,
                                         betas=(0.9, 0.999), eps=1e-6)

    batches_acm, loss_acm, concept_loss_acm, arc_loss_acm, rel_loss_acm = 0, 0, 0, 0, 0
    discarded_batches_acm = 0
    used_batches = 0

    # Produce training batches in a background process.
    queue = mp.Queue(10)
    train_data_generator = mp.Process(target=data_proc, args=(train_data, queue))
    train_data_generator.start()

    if args.resume_ckpt:
        ckpt = torch.load(args.resume_ckpt)
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        batches_acm = ckpt['batches_acm']
        del ckpt

    model.train()
    epoch = 0
    while True:
        batch = queue.get()
        if isinstance(batch, str):
            # The data process pushes a string to mark the end of an epoch.
            epoch += 1
            print('epoch', epoch, 'done', 'batches', batches_acm)
        else:
            batch = move_to_device(batch, model.device)
            concept_loss, arc_loss, rel_loss = model(batch)
            loss = (concept_loss + arc_loss + rel_loss) / args.batches_per_update
            loss_value = loss.item()
            concept_loss_value = concept_loss.item()
            arc_loss_value = arc_loss.item()
            rel_loss_value = rel_loss.item()
            # After warmup, discard batches whose arc loss spikes far above the running average.
            if batches_acm > args.warmup_steps and arc_loss_value > 5. * (arc_loss_acm / batches_acm):
                discarded_batches_acm += 1
                print('abnormal', concept_loss.item(), arc_loss.item(), rel_loss.item())
                continue
            loss_acm += loss_value
            concept_loss_acm += concept_loss_value
            arc_loss_acm += arc_loss_value
            rel_loss_acm += rel_loss_value
            loss.backward()

            used_batches += 1
            # Accumulate gradients over batches_per_update batches before each optimizer step.
            if not (used_batches % args.batches_per_update == -1 % args.batches_per_update):
                continue

            batches_acm += 1
            if args.world_size > 1:
                average_gradients(model)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            update_lr(optimizer, args.embed_dim, batches_acm, args.warmup_steps)
            optimizer.step()
            optimizer.zero_grad()

            if args.world_size == 1 or (dist.get_rank() == 0):
                if batches_acm % args.print_every == -1 % args.print_every:
                    print('Train Epoch %d, Batch %d, Discarded Batch %d, '
                          'conc_loss %.3f, arc_loss %.3f, rel_loss %.3f' % (
                              epoch, batches_acm, discarded_batches_acm,
                              concept_loss_acm / batches_acm,
                              arc_loss_acm / batches_acm,
                              rel_loss_acm / batches_acm))
                    model.train()
                if batches_acm % args.eval_every == -1 % args.eval_every:
                    model.eval()
                    torch.save({'args': args,
                                'model': model.state_dict(),
                                'batches_acm': batches_acm,
                                'optimizer': optimizer.state_dict()},
                               '%s/epoch%d_batch%d' % (args.ckpt, epoch, batches_acm))
                    model.train()
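
# `update_lr` and `average_gradients` are imported from elsewhere in the repo;
# their bodies are not shown in this file. The sketches below are assumptions,
# not the actual implementations. Given the call
# update_lr(optimizer, args.embed_dim, batches_acm, args.warmup_steps),
# `update_lr` presumably applies the inverse-square-root warmup schedule from
# "Attention Is All You Need", and `average_gradients` presumably all-reduces
# and averages gradients across workers for multi-GPU training.
def update_lr_sketch(optimizer, embed_dim, step, warmup_steps):
    # Noam schedule: linear warmup, then decay proportional to step**-0.5.
    lr = embed_dim ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def average_gradients_sketch(model):
    # Sum each parameter's gradient over all ranks, then divide by the world size.
    world_size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= world_size
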
def main(args, local_rank):
    vocabs = dict()
    vocabs['concept'] = Vocab(args.concept_vocab, 5, [CLS])
    vocabs['token'] = Vocab(args.token_vocab, 5, [STR, END])
    vocabs['predictable_token'] = Vocab(args.predictable_token_vocab, 5, [END])
    vocabs['token_char'] = Vocab(args.token_char_vocab, 100, [STR, END])
    vocabs['concept_char'] = Vocab(args.concept_char_vocab, 100, [STR, END])
    vocabs['relation'] = Vocab(args.relation_vocab, 5, [CLS, rCLS, SEL, TL])
    lexical_mapping = LexicalMap()

    for name in vocabs:
        print((name, vocabs[name].size, vocabs[name].coverage))

    torch.manual_seed(19940117)
    torch.cuda.manual_seed_all(19940117)
    random.seed(19940117)
    device = torch.device('cuda', local_rank)

    model = Generator(vocabs,
                      args.token_char_dim, args.token_dim,
                      args.concept_char_dim, args.concept_dim,
                      args.cnn_filters, args.char2word_dim, args.char2concept_dim,
                      args.rel_dim, args.rnn_hidden_size, args.rnn_num_layers,
                      args.embed_dim, args.ff_embed_dim, args.num_heads, args.dropout,
                      args.snt_layers, args.graph_layers, args.inference_layers,
                      args.pretrained_file, device)

    if args.world_size > 1:
        # Give each worker its own random stream.
        torch.manual_seed(19940117 + dist.get_rank())
        torch.cuda.manual_seed_all(19940117 + dist.get_rank())
        random.seed(19940117 + dist.get_rank())
    model = model.to(device)

    train_data = DataLoader(vocabs, lexical_mapping, args.train_data,
                            args.train_batch_size, for_train=True)
    #dev_data = DataLoader(vocabs, lexical_mapping, args.dev_data, args.dev_batch_size, for_train=False)
    train_data.set_unk_rate(args.unk_rate)

    # Exclude biases and layer-norm parameters from weight decay.
    weight_decay_params = []
    no_weight_decay_params = []
    for name, param in model.named_parameters():
        if name.endswith('bias') or 'layer_norm' in name:
            no_weight_decay_params.append(param)
        else:
            weight_decay_params.append(param)
    grouped_params = [{'params': weight_decay_params, 'weight_decay': 1e-4},
                      {'params': no_weight_decay_params, 'weight_decay': 0.}]
    optimizer = AdamWeightDecayOptimizer(grouped_params, lr=args.lr,
                                         betas=(0.9, 0.999), eps=1e-6)

    batches_acm, loss_acm = 0, 0
    discarded_batches_acm = 0

    # Produce training batches in a background process.
    queue = mp.Queue(10)
    train_data_generator = mp.Process(target=data_proc, args=(train_data, queue))
    train_data_generator.start()

    model.train()
    epoch = 0
    while batches_acm < args.total_train_steps:
        batch = queue.get()
        if isinstance(batch, str):
            # The data process pushes a string to mark the end of an epoch.
            epoch += 1
            print('epoch', epoch, 'done', 'batches', batches_acm)
            continue
        batch = move_to_device(batch, device)
        loss = model(batch)
        loss_value = loss.item()
        # After warmup, discard batches whose loss spikes far above the running average.
        if batches_acm > args.warmup_steps and loss_value > 5. * (loss_acm / batches_acm):
            discarded_batches_acm += 1
            print('abnormal', loss_value)
            continue
        loss_acm += loss_value
        batches_acm += 1
        loss.backward()
        if args.world_size > 1:
            average_gradients(model)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        update_lr(optimizer, args.embed_dim, batches_acm, args.warmup_steps)
        optimizer.step()
        optimizer.zero_grad()

        if args.world_size == 1 or (dist.get_rank() == 0):
            if batches_acm % args.print_every == -1 % args.print_every:
                print('Train Epoch %d, Batch %d, Discarded Batch %d, loss %.3f' % (
                    epoch, batches_acm, discarded_batches_acm, loss_acm / batches_acm))
                model.train()
            if batches_acm > args.warmup_steps and batches_acm % args.eval_every == -1 % args.eval_every:
                #model.eval()
                #bleu, chrf = validate(model, dev_data)
                #print("epoch", "batch", "bleu", "chrf")
                #print(epoch, batches_acm, bleu, chrf)
                torch.save({'args': args, 'model': model.state_dict()},
                           '%s/epoch%d_batch%d' % (args.ckpt, epoch, batches_acm))
                model.train()
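
# `data_proc`, the target of the background data-loading process used in both
# training loops above, is defined elsewhere in the repo. The sketch below is a
# minimal assumed version that matches how the loops consume the queue: batches
# are pushed one by one, and a plain string is pushed after each full pass so
# the consumer can detect epoch boundaries via `isinstance(batch, str)`. The
# exact sentinel string is an assumption.
def data_proc_sketch(data, queue):
    while True:
        for batch in data:
            queue.put(batch)      # blocks when the queue (maxsize 10) is full
        queue.put('EPOCHDONE')    # any str works as the epoch-end marker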