Example #1
def prepare_incremental_input(self, step_seq):
    # Tensorize the concepts generated so far: token indices and character indices.
    conc = ListsToTensor(step_seq, self.vocabs['concept'])
    conc_char = ListsofStringToTensor(step_seq, self.vocabs['concept_char'])
    # Move both tensors onto the model's device.
    conc = move_to_device(conc, self.device)
    conc_char = move_to_device(conc_char, self.device)
    return conc, conc_char
Example #2
def show_progress(model, dev_data):
    model.eval()
    loss_acm = 0.
    for batch in dev_data:
        batch = move_to_device(batch, model.device)
        concept_loss, arc_loss, rel_loss = model(batch)
        # Accumulate the summed loss over the whole dev set.
        loss = concept_loss + arc_loss + rel_loss
        loss_acm += loss.item()
    print('total loss', loss_acm)
    return loss_acm
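
A minimal usage sketch, assuming the show_progress above and a dev_data loader built as in Example #4; wrapping the call in torch.no_grad() avoids building autograd graphs during evaluation:

import torch

def evaluate(model, dev_data):
    # Hypothetical wrapper; show_progress itself only calls model.eval().
    with torch.no_grad():
        return show_progress(model, dev_data)
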
Example #3
def parse_data(model, pp, data, input_file, output_file, beam_size=8, alpha=0.6, max_time_step=100):
    tot = 0
    with open(output_file, 'w') as fo:
        for batch in data:
            batch = move_to_device(batch, model.device)
            # Decode the batch with beam search, then write each post-processed graph.
            res = parse_batch(model, batch, beam_size, alpha, max_time_step)
            for concept, relation, score in zip(res['concept'], res['relation'], res['score']):
                fo.write('# ::conc ' + ' '.join(concept) + '\n')
                fo.write('# ::score %.6f\n' % score)
                fo.write(pp.postprocess(concept, relation) + '\n\n')
                tot += 1
    match(output_file, input_file)
    print('write down %d amrs' % tot)
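
A hedged sketch of calling parse_data at test time; the checkpoint keys follow the torch.save call in Example #4 and the DataLoader/PostProcessor construction mirrors that example, while the paths and the already-built model, vocabs, args and device are assumptions:

# Illustrative only: the checkpoint path and test file are placeholders.
ckpt = torch.load('ckpt/epoch6_batch80000', map_location=device)
model.load_state_dict(ckpt['model'])
model.eval()

test_data = DataLoader(vocabs, lexical_mapping, 'data/test.txt',
                       args.dev_batch_size, for_train=False)
pp = PostProcessor(vocabs['rel'])
parse_data(model, pp, test_data, 'data/test.txt', 'data/test.txt.pred')
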
Example #4
def main(local_rank, args):
    vocabs, lexical_mapping = load_vocabs(args)
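    # Optionally build a BERT encoder whose parameters are kept frozen.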
    bert_encoder = None
    if args.with_bert:
        bert_encoder = BertEncoder.from_pretrained(args.bert_path)
        for p in bert_encoder.parameters():
            p.requires_grad = False

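    # Fix random seeds for reproducibility and bind this process to its GPU.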
    torch.manual_seed(19940117)
    torch.cuda.manual_seed_all(19940117)
    random.seed(19940117)
    torch.cuda.set_device(local_rank)
    device = torch.device('cuda', local_rank)

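    # Build the parser from the configured dimensions and the (optional) frozen BERT encoder.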
    model = Parser(vocabs, args.word_char_dim, args.word_dim, args.pos_dim,
                   args.ner_dim, args.concept_char_dim, args.concept_dim,
                   args.cnn_filters, args.char2word_dim, args.char2concept_dim,
                   args.embed_dim, args.ff_embed_dim, args.num_heads,
                   args.dropout, args.snt_layers, args.graph_layers,
                   args.inference_layers, args.rel_dim, args.pretrained_file,
                   bert_encoder, device)

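    # With multiple processes, offset the seeds by the process rank so workers differ.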
    if args.world_size > 1:
        torch.manual_seed(19940117 + dist.get_rank())
        torch.cuda.manual_seed_all(19940117 + dist.get_rank())
        random.seed(19940117 + dist.get_rank())

    model = model.cuda(local_rank)
    dev_data = DataLoader(vocabs,
                          lexical_mapping,
                          args.dev_data,
                          args.dev_batch_size,
                          for_train=False)
    pp = PostProcessor(vocabs['rel'])

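    # Exclude biases and layer-norm parameters from weight decay.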
    weight_decay_params = []
    no_weight_decay_params = []
    for name, param in model.named_parameters():
        if name.endswith('bias') or 'layer_norm' in name:
            no_weight_decay_params.append(param)
        else:
            weight_decay_params.append(param)
    grouped_params = [{
        'params': weight_decay_params,
        'weight_decay': 1e-4
    }, {
        'params': no_weight_decay_params,
        'weight_decay': 0.
    }]
    optimizer = AdamWeightDecayOptimizer(grouped_params,
                                         1.,
                                         betas=(0.9, 0.999),
                                         eps=1e-6)

    used_batches = 0
    batches_acm = 0
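    # Optionally restore model/optimizer state and the batch counter from a checkpoint.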
    if args.resume_ckpt:
        ckpt = torch.load(args.resume_ckpt)
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        batches_acm = ckpt['batches_acm']
        del ckpt

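    # Training batches are produced by a background process and consumed through a bounded queue.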
    train_data = DataLoader(vocabs,
                            lexical_mapping,
                            args.train_data,
                            args.train_batch_size,
                            for_train=True)
    train_data.set_unk_rate(args.unk_rate)
    queue = mp.Queue(10)
    train_data_generator = mp.Process(target=data_proc,
                                      args=(train_data, queue))

    train_data_generator.start()
    model.train()
    epoch, loss_avg, concept_loss_avg, arc_loss_avg, rel_loss_avg = 0, 0, 0, 0, 0
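    # Training loop: the data process puts a string on the queue to signal the end of an epoch;
    # otherwise each item is a batch. The loop runs until the process is stopped externally.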
    while True:
        batch = queue.get()
        if isinstance(batch, str):
            epoch += 1
            print('epoch', epoch, 'done', 'batches', batches_acm)
        else:
            batch = move_to_device(batch, model.device)
            concept_loss, arc_loss, rel_loss, graph_arc_loss = model(batch)
            loss = (concept_loss + arc_loss +
                    rel_loss) / args.batches_per_update
            loss_value = loss.item()
            concept_loss_value = concept_loss.item()
            arc_loss_value = arc_loss.item()
            rel_loss_value = rel_loss.item()
            loss_avg = loss_avg * args.batches_per_update * 0.8 + 0.2 * loss_value
            concept_loss_avg = concept_loss_avg * 0.8 + 0.2 * concept_loss_value
            arc_loss_avg = arc_loss_avg * 0.8 + 0.2 * arc_loss_value
            rel_loss_avg = rel_loss_avg * 0.8 + 0.2 * rel_loss_value
            loss.backward()
            used_batches += 1
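            # Gradient accumulation: only update the parameters every batches_per_update batches.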
            if not (used_batches % args.batches_per_update
                    == -1 % args.batches_per_update):
                continue

            batches_acm += 1
            if args.world_size > 1:
                average_gradients(model)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
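            # Set the learning rate from the embedding size, current step and warmup steps.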
            lr = update_lr(optimizer, args.lr_scale, args.embed_dim,
                           batches_acm, args.warmup_steps)
            optimizer.step()
            optimizer.zero_grad()
            if args.world_size == 1 or (dist.get_rank() == 0):
                if batches_acm % args.print_every == -1 % args.print_every:
                    print(
                        'Train Epoch %d, Batch %d, LR %.6f, conc_loss %.3f, arc_loss %.3f, rel_loss %.3f'
                        % (epoch, batches_acm, lr, concept_loss_avg,
                           arc_loss_avg, rel_loss_avg))
                    model.train()
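                # Periodically decode the dev set and save a checkpoint (only on the main process).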
                if (
                        batches_acm > 10000 or args.resume_ckpt is not None
                ) and batches_acm % args.eval_every == -1 % args.eval_every:
                    model.eval()
                    parse_data(
                        model, pp, dev_data, args.dev_data,
                        '%s/epoch%d_batch%d_dev_out' %
                        (args.ckpt, epoch, batches_acm))
                    torch.save(
                        {
                            'args': args,
                            'model': model.state_dict(),
                            'batches_acm': batches_acm,
                            'optimizer': optimizer.state_dict()
                        },
                        '%s/epoch%d_batch%d' % (args.ckpt, epoch, batches_acm))
                    model.train()