# Note: helpers such as ListsToTensor, ListsofStringToTensor, move_to_device,
# DataLoader, parse_batch, etc. are defined elsewhere in the project.

def prepare_incremental_input(self, step_seq):
    # Turn the concepts generated so far into index and character tensors and
    # move them onto the model's device for the next decoding step.
    conc = ListsToTensor(step_seq, self.vocabs['concept'])
    conc_char = ListsofStringToTensor(step_seq, self.vocabs['concept_char'])
    conc, conc_char = move_to_device(conc, self.device), move_to_device(conc_char, self.device)
    return conc, conc_char
def show_progress(model, dev_data):
    # Report the summed (concept + arc + relation) loss over the dev set.
    model.eval()
    loss_acm = 0.
    for batch in dev_data:
        batch = move_to_device(batch, model.device)
        concept_loss, arc_loss, rel_loss = model(batch)
        loss = concept_loss + arc_loss + rel_loss
        loss_acm += loss.item()
    print('total loss', loss_acm)
    return loss_acm
def parse_data(model, pp, data, input_file, output_file, beam_size=8, alpha=0.6, max_time_step=100):
    # Beam-search decode every batch, post-process the predicted concepts and
    # relations into AMR graphs, and write them to output_file.
    tot = 0
    with open(output_file, 'w') as fo:
        for batch in data:
            batch = move_to_device(batch, model.device)
            res = parse_batch(model, batch, beam_size, alpha, max_time_step)
            for concept, relation, score in zip(res['concept'], res['relation'], res['score']):
                fo.write('# ::conc ' + ' '.join(concept) + '\n')
                fo.write('# ::score %.6f\n' % score)
                fo.write(pp.postprocess(concept, relation) + '\n\n')
                tot += 1
    # Pair the decoder output with the corresponding entries in input_file.
    match(output_file, input_file)
    print('wrote %d AMRs' % tot)
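# The training loop in main() below reads batches from a queue filled by a
# background producer process (data_proc), which is not part of this excerpt.
# A minimal sketch follows, assuming the DataLoader can be iterated repeatedly
# and that a plain string pushed into the queue marks the end of an epoch
# (main() only checks isinstance(batch, str)); the real implementation may differ.

def data_proc(data, queue):
    # Hypothetical producer: cycle over the data forever, pushing one batch at
    # a time, plus a string marker after each full pass so the consumer can
    # count epochs.
    while True:
        for batch in data:
            queue.put(batch)
        queue.put('EPOCHDONE')  # any str works as the epoch-end signal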
def main(local_rank, args):
    vocabs, lexical_mapping = load_vocabs(args)

    # Optionally load a frozen BERT encoder for input word representations.
    bert_encoder = None
    if args.with_bert:
        bert_encoder = BertEncoder.from_pretrained(args.bert_path)
        for p in bert_encoder.parameters():
            p.requires_grad = False

    torch.manual_seed(19940117)
    torch.cuda.manual_seed_all(19940117)
    random.seed(19940117)
    torch.cuda.set_device(local_rank)
    device = torch.device('cuda', local_rank)

    model = Parser(vocabs,
                   args.word_char_dim, args.word_dim, args.pos_dim, args.ner_dim,
                   args.concept_char_dim, args.concept_dim,
                   args.cnn_filters, args.char2word_dim, args.char2concept_dim,
                   args.embed_dim, args.ff_embed_dim, args.num_heads, args.dropout,
                   args.snt_layers, args.graph_layers, args.inference_layers,
                   args.rel_dim, args.pretrained_file, bert_encoder, device)

    if args.world_size > 1:
        # Give every worker a different seed so they see data in different orders.
        torch.manual_seed(19940117 + dist.get_rank())
        torch.cuda.manual_seed_all(19940117 + dist.get_rank())
        random.seed(19940117 + dist.get_rank())

    model = model.cuda(local_rank)
    dev_data = DataLoader(vocabs, lexical_mapping, args.dev_data, args.dev_batch_size, for_train=False)
    pp = PostProcessor(vocabs['rel'])

    # Apply weight decay to everything except biases and layer-norm parameters.
    weight_decay_params = []
    no_weight_decay_params = []
    for name, param in model.named_parameters():
        if name.endswith('bias') or 'layer_norm' in name:
            no_weight_decay_params.append(param)
        else:
            weight_decay_params.append(param)
    grouped_params = [{'params': weight_decay_params, 'weight_decay': 1e-4},
                      {'params': no_weight_decay_params, 'weight_decay': 0.}]
    # The learning rate passed here is a placeholder; update_lr sets it each step.
    optimizer = AdamWeightDecayOptimizer(grouped_params, 1., betas=(0.9, 0.999), eps=1e-6)

    used_batches = 0
    batches_acm = 0
    if args.resume_ckpt:
        ckpt = torch.load(args.resume_ckpt)
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        batches_acm = ckpt['batches_acm']
        del ckpt

    train_data = DataLoader(vocabs, lexical_mapping, args.train_data, args.train_batch_size, for_train=True)
    train_data.set_unk_rate(args.unk_rate)

    # Produce training batches in a background process and consume them via a queue.
    queue = mp.Queue(10)
    train_data_generator = mp.Process(target=data_proc, args=(train_data, queue))
    train_data_generator.start()

    model.train()
    epoch, loss_avg, concept_loss_avg, arc_loss_avg, rel_loss_avg = 0, 0, 0, 0, 0
    while True:
        batch = queue.get()
        if isinstance(batch, str):
            # The producer pushes a string marker after each full pass over the data.
            epoch += 1
            print('epoch', epoch, 'done', 'batches', batches_acm)
        else:
            batch = move_to_device(batch, model.device)
            concept_loss, arc_loss, rel_loss, graph_arc_loss = model(batch)
            # Scale the loss so gradients accumulated over batches_per_update
            # batches correspond to one averaged update.
            loss = (concept_loss + arc_loss + rel_loss) / args.batches_per_update
            loss_value = loss.item()
            concept_loss_value = concept_loss.item()
            arc_loss_value = arc_loss.item()
            rel_loss_value = rel_loss.item()
            # Exponential moving averages of the per-batch losses for logging
            # (loss_value is scaled back up to its unscaled magnitude).
            loss_avg = loss_avg * 0.8 + 0.2 * loss_value * args.batches_per_update
            concept_loss_avg = concept_loss_avg * 0.8 + 0.2 * concept_loss_value
            arc_loss_avg = arc_loss_avg * 0.8 + 0.2 * arc_loss_value
            rel_loss_avg = rel_loss_avg * 0.8 + 0.2 * rel_loss_value
            loss.backward()

            used_batches += 1
            # Gradient accumulation: only take an optimizer step every
            # batches_per_update batches.
            if not (used_batches % args.batches_per_update == -1 % args.batches_per_update):
                continue

            batches_acm += 1
            if args.world_size > 1:
                average_gradients(model)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            lr = update_lr(optimizer, args.lr_scale, args.embed_dim, batches_acm, args.warmup_steps)
            optimizer.step()
            optimizer.zero_grad()

            # Only the main process logs, evaluates, and saves checkpoints.
            if args.world_size == 1 or dist.get_rank() == 0:
                if batches_acm % args.print_every == -1 % args.print_every:
                    print('Train Epoch %d, Batch %d, LR %.6f, conc_loss %.3f, arc_loss %.3f, rel_loss %.3f' % (
                        epoch, batches_acm, lr, concept_loss_avg, arc_loss_avg, rel_loss_avg))
                    model.train()
                if ((batches_acm > 10000 or args.resume_ckpt is not None)
                        and batches_acm % args.eval_every == -1 % args.eval_every):
                    # Decode the dev set and save a checkpoint.
                    model.eval()
                    parse_data(model, pp, dev_data, args.dev_data,
                               '%s/epoch%d_batch%d_dev_out' % (args.ckpt, epoch, batches_acm))
                    torch.save({'args': args,
                                'model': model.state_dict(),
                                'batches_acm': batches_acm,
                                'optimizer': optimizer.state_dict()},
                               '%s/epoch%d_batch%d' % (args.ckpt, epoch, batches_acm))
                    model.train()
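# The multi-GPU and learning-rate helpers used above (average_gradients,
# update_lr) are likewise defined elsewhere in the project. The sketches below
# are assumptions, not the project's actual code: the usual all-reduce gradient
# averaging, and a Transformer-style inverse-square-root warmup schedule
# suggested by the (lr_scale, embed_dim, step, warmup_steps) call signature.

def average_gradients(model):
    # Sketch: sum each parameter's gradient across workers, then divide by the
    # world size so every replica steps with the averaged gradient.
    world_size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= world_size

def update_lr(optimizer, lr_scale, embed_dim, steps, warmup_steps):
    # Sketch: linear warmup for warmup_steps updates, then decay ~ 1/sqrt(steps),
    # scaled by lr_scale * embed_dim ** -0.5.
    lr = lr_scale * embed_dim ** -0.5 * min(steps ** -0.5, steps * warmup_steps ** -1.5)
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr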