Example #1
def prepare_incremental_input(self, step_seq):
    conc = ListsToTensor(step_seq, self.vocabs['concept'])
    conc_char = ListsofStringToTensor(step_seq, self.vocabs['concept_char'])
    conc = move_to_device(conc, self.device)
    conc_char = move_to_device(conc_char, self.device)
    return conc, conc_char
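
Every example on this page funnels its inputs through a move_to_device helper. The exact signature varies between the source repositories (Examples #2 and #14 pass the batch and targets together), but a minimal sketch of the usual recursive variant, assuming it walks tensors and standard containers, could look like this:

import torch

def move_to_device(obj, device):
    # Recursively move every tensor inside obj to the target device,
    # leaving non-tensor leaves (ints, strings, ...) untouched.
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to_device(v, device) for v in obj)
    return obj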
Example #2
def train_loop(model, dl, batch_size: int, epoch, epochs, optimizer, verbose,
               max_iter_count, device):
    running_metrics = []
    print(f"running epoch [{epoch+1}/{epochs}]")
    model.train()
    for iter_count, (batch, targets) in enumerate(dl):
        batch, targets = move_to_device(batch, targets, device=device)

        optimizer.zero_grad()
        metrics = model.training_step(batch, targets)
        metrics['loss'].backward()
        optimizer.step()

        metrics['loss'] = metrics['loss'].item()

        running_metrics.append(metrics)
        if (iter_count + 1) % verbose == 0:
            means = caclulate_means(running_metrics)
            running_metrics = []
            log = []
            for k, v in means.items():
                log.append(f"{k}: {v:.04f}")
            log = "\t".join(log)
            log += f"\titer[{iter_count}/{max_iter_count}]"
            print(log)
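
train_loop averages its running metrics with caclulate_means (spelling as in the source repository). The helper is not shown on this page; a plausible sketch, assuming each entry is a dict of scalar metrics, is:

def caclulate_means(metrics_list):
    # Average a list of {metric name: scalar} dicts key by key.
    sums = {}
    for metrics in metrics_list:
        for k, v in metrics.items():
            sums[k] = sums.get(k, 0.0) + float(v)
    return {k: v / len(metrics_list) for k, v in sums.items()}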
Example #3
def record_batch(model, batch, data):
    batch = move_to_device(batch, model.device)
    attn = model.encoder_attn(batch)
    # nlayers x tgt_len x src_len x bsz x num_heads

    for i, x in enumerate(data):
        L = len(x['concept']) + 1
        x['attn'] = attn[:, :L, :L, i, :].cpu()
    return data
Example #4
def generate_batch(model, batch, beam_size, alpha, max_time_step):
    batch = move_to_device(batch, model.device)
    res = dict()
    token_batch, score_batch = [], []
    beams = model.work(batch, beam_size, max_time_step)
    for beam in beams:
        best_hyp = beam.get_k_best(1, alpha)[0]
        predicted_token = list(best_hyp.seq[1:-1])
        token_batch.append(predicted_token)
        score_batch.append(best_hyp.score)
    res['token'] = token_batch
    res['score'] = score_batch
    return res
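
The alpha passed to get_k_best is presumably a length-normalization exponent used when ranking beam hypotheses; the snippet does not show get_k_best itself. A sketch of the GNMT-style penalty such rankers commonly apply (an assumption, not confirmed by this code):

def length_normalized_score(logprob, length, alpha=0.6):
    # The penalty grows with length; dividing a (negative) log-prob by it
    # favors longer hypotheses more strongly as alpha increases.
    penalty = ((5.0 + length) / 6.0) ** alpha
    return logprob / penalty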
Example #5
    def work(self, inp, allow_hit):
        src_tokens = inp['src_tokens']
        src_feat, src, src_mask = self.model(src_tokens, return_src=True)
        num_heads, bsz, dim = src_feat.size()
        assert num_heads == self.num_heads
        topk = self.topk
        vecsq = src_feat.reshape(num_heads * bsz, -1).detach().cpu().numpy() 
        #retrieval_start = time.time()
        vecsq = augment_query(vecsq)
        D, I = self.mips.search(vecsq, topk + 1)
        D = l2_to_ip(D, vecsq, self.mips_max_norm) / (self.mips_max_norm * self.mips_max_norm)
        # I, D: (bsz * num_heads x (topk + 1) )
        indices = torch.zeros(topk, num_heads, bsz, dtype=torch.long)
        for i, (Ii, Di) in enumerate(zip(I, D)):
            bid, hid = i % bsz, i // bsz
            tmp_list = []
            for pred, _ in zip(Ii, Di):
                if allow_hit or self.mem_pool[pred] != inp['tgt_raw_sents'][bid]:
                    tmp_list.append(pred)
            tmp_list = tmp_list[:topk]
            assert len(tmp_list) == topk
            indices[:, hid, bid] = torch.tensor(tmp_list)
        #retrieval_cost = time.time() - retrieval_start
        #print ('retrieval_cost', retrieval_cost)
        # convert to tensors:
        # all_mem_tokens -> seq_len x ( topk * num_heads * bsz )
        # all_mem_feats -> topk * num_heads * bsz x dim
        all_mem_tokens = []
        for idx in indices.view(-1).tolist():
            #TODO self.mem_pool[idx] +[EOS]
            all_mem_tokens.append([BOS] + self.mem_pool[idx])
        all_mem_tokens = ListsToTensor(all_mem_tokens, self.vocabs['tgt'])
        
        # to avoid GPU OOM issue, truncate the mem to the max. length of 1.5 x src_tokens
        max_mem_len = int(1.5 * src_tokens.shape[0])
        all_mem_tokens = move_to_device(all_mem_tokens[:max_mem_len,:], inp['src_tokens'].device)
       
        if torch.is_tensor(self.mem_feat_or_feat_maker):
            all_mem_feats = self.mem_feat_or_feat_maker[indices].to(src_feat.device)
        else:
            all_mem_feats = self.mem_feat_or_feat_maker(all_mem_tokens).view(topk, num_heads, bsz, dim)

        # all_mem_scores -> topk x num_heads x bsz
        all_mem_scores = torch.sum(src_feat.unsqueeze(0) * all_mem_feats, dim=-1) / (self.mips_max_norm ** 2)

        mem_ret = {}
        indices = indices.view(-1, bsz).transpose(0, 1).tolist()
        mem_ret['retrieval_raw_sents'] = [ [self.mem_pool[idx] for idx in ind] for ind in indices]
        mem_ret['all_mem_tokens'] = all_mem_tokens
        mem_ret['all_mem_scores'] = all_mem_scores
        return src, src_mask, mem_ret
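
Example #5 searches a faiss-style L2 index but needs inner products, converting back with augment_query and l2_to_ip. These match the standard MIPS-to-L2 reduction: data vectors are padded with sqrt(max_norm^2 - ||x||^2) and queries with a zero, so the squared L2 distance becomes ||q||^2 + max_norm^2 - 2<q, x>. A sketch under that assumption:

import numpy as np

def augment_query(vecs):
    # Append a zero column so queries live in the augmented space.
    zeros = np.zeros((vecs.shape[0], 1), dtype=vecs.dtype)
    return np.hstack([vecs, zeros])

def l2_to_ip(l2_d2, queries, max_norm):
    # Invert d^2 = ||q||^2 + max_norm^2 - 2<q, x> to recover <q, x>.
    q_norm2 = (queries ** 2).sum(axis=-1, keepdims=True)
    return (q_norm2 + max_norm * max_norm - l2_d2) / 2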
Example #6
def get_features(batch_size, norm_th, vocab, model, used_data, used_ids, max_norm=None, max_norm_cf=1.0):
    vecs, ids = [], []
    model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
    model.eval()
    data_loader = DataLoader(used_data, vocab, batch_size)
    cur, tot = 0, len(used_data)
    for batch in asynchronous_load(data_loader):
        batch = move_to_device(batch, torch.device('cuda', 0)).t()
        bsz = batch.size(0)
        cur_vecs = model(batch, batch_first=True).detach().cpu().numpy()
        valid = np.linalg.norm(cur_vecs, axis=1) <= norm_th
        vecs.append(cur_vecs[valid])
        ids.append(used_ids[cur:cur+bsz][valid])  # slice by bsz: the last batch may be smaller than batch_size
        cur += bsz
        logger.info("%d / %d", cur, tot)
    vecs = np.concatenate(vecs, 0)
    ids = np.concatenate(ids, 0)
    out, max_norm = augment_data(vecs, max_norm, max_norm_cf)
    return out, ids, max_norm
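
augment_data here is presumably the data-side counterpart of augment_query: it pads each vector with the extra sqrt(max_norm^2 - ||x||^2) coordinate and returns the max norm it used. A sketch consistent with the call signature above:

import numpy as np

def augment_data(vecs, max_norm=None, max_norm_cf=1.0):
    # Pad x -> [x, sqrt(max_norm^2 - ||x||^2)] for the MIPS-to-L2 reduction.
    norms = np.linalg.norm(vecs, axis=1)
    if max_norm is None:
        max_norm = max_norm_cf * norms.max()
    extra = np.sqrt(np.maximum(max_norm ** 2 - norms ** 2, 0.0))
    return np.hstack([vecs, extra[:, None]]).astype(vecs.dtype), max_norm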
Example #7
def validate(model, dev_data, device):
    model.eval()
    q_list = []
    r_list = []
    with torch.no_grad():  # no gradients needed while encoding the dev set
        for batch in dev_data:
            batch = move_to_device(batch, device)
            q = model.query_encoder(batch['src_tokens'])
            r = model.response_encoder(batch['tgt_tokens'])
            q_list.append(q)
            r_list.append(r)
    q = torch.cat(q_list, dim=0)
    r = torch.cat(r_list, dim=0)

    bsz = q.size(0)
    scores = torch.mm(q, r.t())  # bsz x bsz
    gold = torch.arange(bsz, device=scores.device)
    _, pred = torch.max(scores, -1)
    acc = torch.sum(torch.eq(gold, pred).float()) / bsz
    return acc
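
This validate scores every dev query against every dev response and takes the diagonal as gold, i.e. full-set retrieval accuracy. A tiny self-contained illustration of the metric with random vectors (shapes only; the encoders are assumed):

import torch

q = torch.randn(8, 128)                 # 8 query vectors
r = torch.randn(8, 128)                 # responses aligned with the queries
scores = torch.mm(q, r.t())             # 8 x 8 similarity matrix
gold = torch.arange(8)
acc = (scores.argmax(-1) == gold).float().mean()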
Example #8
def validate(device,
             model,
             test_data,
             beam_size=5,
             alpha=0.6,
             max_time_step=100,
             dump_path=None):
    """For Development Only"""

    ref_stream = []
    sys_stream = []
    topk_sys_retr_stream = []
    for batch in test_data:
        batch = move_to_device(batch, device)
        res, _ = generate_batch(model, batch, beam_size, alpha, max_time_step)
        sys_stream.extend(res)
        ref_stream.extend(batch['tgt_raw_sents'])
        sys_retr = batch.get('retrieval_raw_sents', None)
        if sys_retr:
            topk_sys_retr_stream.extend(sys_retr)

    assert len(sys_stream) == len(ref_stream)

    sys_stream = [
        re.sub(r'(@@ )|(@@ ?$)', '', ' '.join(o)) for o in sys_stream
    ]
    ref_stream = [
        re.sub(r'(@@ )|(@@ ?$)', '', ' '.join(o)) for o in ref_stream
    ]
    ref_streams = [ref_stream]

    bleu = sacrebleu.corpus_bleu(sys_stream,
                                 ref_streams,
                                 force=True,
                                 lowercase=False,
                                 tokenize='none').score
    sys_retr_streams = []
    if topk_sys_retr_stream:
        assert len(topk_sys_retr_stream) == len(ref_stream)
        topk = len(topk_sys_retr_stream[0])
        for i in range(topk):
            sys_retr_stream = [
                re.sub(r'(@@ )|(@@ ?$)', '', ' '.join(o[i]))
                for o in topk_sys_retr_stream
            ]
            lratio = []
            for aa, bb in zip(sys_retr_stream, ref_stream):
                laa = len(aa.split())
                lbb = len(bb.split())
                lratio.append(max(laa / lbb, lbb / laa))
            bleu_retr = sacrebleu.corpus_bleu(sys_retr_stream,
                                              ref_streams,
                                              force=True,
                                              lowercase=False,
                                              tokenize='none').score
            sys_retr_streams.append(sys_retr_stream)
            logger.info("Retrieval top%d bleu %.2f length ratio %.2f", i + 1,
                        bleu_retr,
                        sum(lratio) / len(lratio))
        # logger.info("show some examples >>>")
        # for sample_id in [5, 6, 11, 22, 33, 44, 55, 66, 555, 666]:
        #     retrieval = [ "%d: %s"%(i, sys_retr_streams[i][sample_id]) for i in range(topk)]
        #     logger.info("%d: %s###\n generation: %s###\nretrieval:\n %s", sample_id, ref_stream[sample_id], sys_stream[sample_id], '\n'.join(retrieval))
        # logger.info("<<< show some examples")
    if dump_path is not None:
        results = {
            'sys_stream': sys_stream,
            'ref_stream': ref_stream,
            'sys_retr_streams': sys_retr_streams
        }
        with open(dump_path, 'w') as f:
            json.dump(results, f)
    return bleu
Example #9
            bleu = validate(device,
                            model,
                            test_data,
                            beam_size=args.beam_size,
                            alpha=args.alpha,
                            max_time_step=args.max_time_step,
                            dump_path=args.dump_path)
            logger.info("%s %s %.2f", test_model, args.test_data, bleu)

        if args.output_path is not None:
            start_time = time.time()
            TOT = len(test_data)
            DONE = 0
            logger.info("%d/%d", DONE, TOT)
            outs, indices = [], []
            for batch in test_data:
                batch = move_to_device(batch, device)
                res, ind = generate_batch(model, batch, args.beam_size,
                                          args.alpha, args.max_time_step)
                for out_tokens, index in zip(res, ind):
                    if args.retain_bpe:
                        out_line = ' '.join(out_tokens)
                    else:
                        out_line = re.sub(r'(@@ )|(@@ ?$)', '',
                                          ' '.join(out_tokens))
                    DONE += 1
                    if DONE % 10000 == -1 % 10000:
                        logger.info("%d/%d", DONE, TOT)
                    outs.append(out_line)
                    indices.append(index)
            end_time = time.time()
            logger.info("Time elapsed: %f", end_time - start_time)
Example #10
def main(args, local_rank):
    vocabs = dict()
    vocabs['tok'] = Vocab(args.tok_vocab, 5, [CLS])
    vocabs['lem'] = Vocab(args.lem_vocab, 5, [CLS])
    vocabs['pos'] = Vocab(args.pos_vocab, 5, [CLS])
    vocabs['ner'] = Vocab(args.ner_vocab, 5, [CLS])
    vocabs['predictable_concept'] = Vocab(args.predictable_concept_vocab, 10,
                                          [DUM, END])
    vocabs['concept'] = Vocab(args.concept_vocab, 5, [DUM, END])
    vocabs['rel'] = Vocab(args.rel_vocab, 50, [NIL])
    vocabs['word_char'] = Vocab(args.word_char_vocab, 100, [CLS, END])
    vocabs['concept_char'] = Vocab(args.concept_char_vocab, 100, [CLS, END])
    lexical_mapping = LexicalMap(args.lexical_mapping)
    if args.pretrained_word_embed is not None:
        vocab, pretrained_embs = load_pretrained_word_embed(
            args.pretrained_word_embed)
        vocabs['glove'] = vocab
    else:
        pretrained_embs = None

    for name in vocabs:
        print((name, vocabs[name].size))

    torch.manual_seed(19940117)
    torch.cuda.manual_seed_all(19940117)
    random.seed(19940117)
    device = torch.device('cuda', local_rank)
    #print(device)
    #exit()
    model = Parser(vocabs,
                   args.word_char_dim,
                   args.word_dim,
                   args.pos_dim,
                   args.ner_dim,
                   args.concept_char_dim,
                   args.concept_dim,
                   args.cnn_filters,
                   args.char2word_dim,
                   args.char2concept_dim,
                   args.embed_dim,
                   args.ff_embed_dim,
                   args.num_heads,
                   args.dropout,
                   args.snt_layers,
                   args.graph_layers,
                   args.inference_layers,
                   args.rel_dim,
                   pretrained_embs,
                   device=device)

    if args.world_size > 1:
        torch.manual_seed(19940117 + dist.get_rank())
        torch.cuda.manual_seed_all(19940117 + dist.get_rank())
        random.seed(19940117 + dist.get_rank())

    model = model.cuda(local_rank)
    train_data = DataLoader(vocabs,
                            lexical_mapping,
                            args.train_data,
                            args.train_batch_size,
                            for_train=True)
    dev_data = DataLoader(vocabs,
                          lexical_mapping,
                          args.dev_data,
                          args.dev_batch_size,
                          for_train=True)
    train_data.set_unk_rate(args.unk_rate)

    weight_decay_params = []
    no_weight_decay_params = []
    for name, param in model.named_parameters():
        if name.endswith('bias') or 'layer_norm' in name:
            no_weight_decay_params.append(param)
        else:
            weight_decay_params.append(param)
    grouped_params = [{
        'params': weight_decay_params,
        'weight_decay': 1e-4
    }, {
        'params': no_weight_decay_params,
        'weight_decay': 0.
    }]
    optimizer = AdamWeightDecayOptimizer(grouped_params,
                                         lr=args.lr,
                                         betas=(0.9, 0.999),
                                         eps=1e-6)

    batches_acm, loss_acm, concept_loss_acm, arc_loss_acm, rel_loss_acm = 0, 0, 0, 0, 0
    #model.load_state_dict(torch.load('./ckpt/epoch297_batch49999')['model'])
    discarded_batches_acm = 0
    queue = mp.Queue(10)
    train_data_generator = mp.Process(target=data_proc,
                                      args=(train_data, queue))
    train_data_generator.start()

    used_batches = 0
    if args.resume_ckpt:
        ckpt = torch.load(args.resume_ckpt)
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        batches_acm = ckpt['batches_acm']
        del ckpt

    model.train()
    epoch = 0
    while True:
        batch = queue.get()
        #print("epoch",epoch)
        #print("batches_acm",batches_acm)
        #print("used_batches",used_batches)
        if isinstance(batch, str):
            epoch += 1
            print('epoch', epoch, 'done', 'batches', batches_acm)
        else:
            batch = move_to_device(batch, model.device)
            concept_loss, arc_loss, rel_loss = model(batch)
            loss = (concept_loss + arc_loss +
                    rel_loss) / args.batches_per_update
            loss_value = loss.item()
            concept_loss_value = concept_loss.item()
            arc_loss_value = arc_loss.item()
            rel_loss_value = rel_loss.item()
            if batches_acm > args.warmup_steps and arc_loss_value > 5. * (
                    arc_loss_acm / batches_acm):
                discarded_batches_acm += 1
                print('abnormal', concept_loss.item(), arc_loss.item(),
                      rel_loss.item())
                continue
            loss_acm += loss_value
            concept_loss_acm += concept_loss_value
            arc_loss_acm += arc_loss_value
            rel_loss_acm += rel_loss_value
            loss.backward()

            used_batches += 1
            if not (used_batches % args.batches_per_update
                    == -1 % args.batches_per_update):
                continue
            batches_acm += 1

            if args.world_size > 1:
                average_gradients(model)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            update_lr(optimizer, args.embed_dim, batches_acm,
                      args.warmup_steps)
            optimizer.step()
            optimizer.zero_grad()
            if args.world_size == 1 or (dist.get_rank() == 0):
                if batches_acm % args.print_every == -1 % args.print_every:
                    print(
                        'Train Epoch %d, Batch %d, Discarded Batch %d, conc_loss %.3f, arc_loss %.3f, rel_loss %.3f'
                        % (epoch, batches_acm, discarded_batches_acm,
                           concept_loss_acm / batches_acm, arc_loss_acm /
                           batches_acm, rel_loss_acm / batches_acm))
                    model.train()

                if batches_acm % args.eval_every == -1 % args.eval_every:
                    model.eval()
                    torch.save(
                        {
                            'args': args,
                            'model': model.state_dict(),
                            'batches_acm': batches_acm,
                            'optimizer': optimizer.state_dict()
                        },
                        '%s/epoch%d_batch%d' % (args.ckpt, epoch, batches_acm))
                    model.train()
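
update_lr (also used in Example #12) is presumably the standard Transformer inverse-square-root schedule; Example #13 uses the equivalent get_inverse_sqrt_schedule_with_warmup. A sketch under that assumption:

def update_lr(optimizer, embed_dim, step, warmup_steps):
    # lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
    lr = embed_dim ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr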
Example #11
    vocabs = dict()
    vocabs['concept'] = Vocab(args.concept_vocab, 5, [CLS])
    vocabs['token'] = Vocab(args.token_vocab, 5, [STR, END])
    vocabs['predictable_token'] = Vocab(args.predictable_token_vocab, 5, [END])
    vocabs['token_char'] = Vocab(args.token_char_vocab, 100, [STR, END])
    vocabs['concept_char'] = Vocab(args.concept_char_vocab, 100, [STR, END])
    vocabs['relation'] = Vocab(args.relation_vocab, 5, [CLS, rCLS, SEL, TL])
    lexical_mapping = LexicalMap()

    train_data = DataLoader(vocabs, lexical_mapping, args.train_data, args.train_batch_size, for_train=True)
    epoch_idx = 0
    batch_idx = 0
    last = 0
    while True:
        st = time.time()
        for d in train_data:
            d = move_to_device(d, torch.device('cpu'))
            batch_idx += 1
            #if d['concept'].size(0) > 5:
            #    continue
            print (epoch_idx, batch_idx, d['concept'].size(), d['token_in'].size())
            c_len, bsz = d['concept'].size()
            t_len, bsz = d['token_in'].size()
            print (bsz, c_len*bsz, t_len * bsz) 
            #print (d['relation_bank'].size())
            #print (d['relation'].size())

            #_back_to_txt_for_check(d['concept'], vocabs['concept'])
            #for x in d['concept_depth'].t().tolist():
            #    print (x)
            #_back_to_txt_for_check(d['token_in'], vocabs['token'])
            #_back_to_txt_for_check(d['token_out'], vocabs['predictable_token'], d['local_idx2token'])
Example #12
def main(args, local_rank):
    vocabs = dict()
    vocabs['concept'] = Vocab(args.concept_vocab, 5, [CLS])
    vocabs['token'] = Vocab(args.token_vocab, 5, [STR, END])
    vocabs['predictable_token'] = Vocab(args.predictable_token_vocab, 5, [END])
    vocabs['token_char'] = Vocab(args.token_char_vocab, 100, [STR, END])
    vocabs['concept_char'] = Vocab(args.concept_char_vocab, 100, [STR, END])
    vocabs['relation'] = Vocab(args.relation_vocab, 5, [CLS, rCLS, SEL, TL])
    lexical_mapping = LexicalMap()

    for name in vocabs:
        print((name, vocabs[name].size, vocabs[name].coverage))

    torch.manual_seed(19940117)
    torch.cuda.manual_seed_all(19940117)
    random.seed(19940117)

    #device = torch.device('cpu')
    device = torch.device('cuda', local_rank)
    model = Generator(vocabs, args.token_char_dim, args.token_dim,
                      args.concept_char_dim, args.concept_dim,
                      args.cnn_filters, args.char2word_dim,
                      args.char2concept_dim, args.rel_dim,
                      args.rnn_hidden_size, args.rnn_num_layers,
                      args.embed_dim, args.ff_embed_dim, args.num_heads,
                      args.dropout, args.snt_layers, args.graph_layers,
                      args.inference_layers, args.pretrained_file, device)

    if args.world_size > 1:
        torch.manual_seed(19940117 + dist.get_rank())
        torch.cuda.manual_seed_all(19940117 + dist.get_rank())
        random.seed(19940117 + dist.get_rank())

    model = model.to(device)
    train_data = DataLoader(vocabs,
                            lexical_mapping,
                            args.train_data,
                            args.train_batch_size,
                            for_train=True)
    #dev_data = DataLoader(vocabs, lexical_mapping, args.dev_data, args.dev_batch_size, for_train=False)
    train_data.set_unk_rate(args.unk_rate)

    weight_decay_params = []
    no_weight_decay_params = []
    for name, param in model.named_parameters():
        if name.endswith('bias') or 'layer_norm' in name:
            no_weight_decay_params.append(param)
        else:
            weight_decay_params.append(param)
    grouped_params = [{
        'params': weight_decay_params,
        'weight_decay': 1e-4
    }, {
        'params': no_weight_decay_params,
        'weight_decay': 0.
    }]
    optimizer = AdamWeightDecayOptimizer(grouped_params,
                                         lr=args.lr,
                                         betas=(0.9, 0.999),
                                         eps=1e-6)

    batches_acm, loss_acm = 0, 0
    discarded_batches_acm = 0

    queue = mp.Queue(10)
    train_data_generator = mp.Process(target=data_proc,
                                      args=(train_data, queue))
    train_data_generator.start()

    model.train()
    epoch = 0
    while batches_acm < args.total_train_steps:
        batch = queue.get()
        if isinstance(batch, str):
            epoch += 1
            print('epoch', epoch, 'done', 'batches', batches_acm)
            continue
        batch = move_to_device(batch, device)
        loss = model(batch)

        loss_value = loss.item()
        if batches_acm > args.warmup_steps and loss_value > 5. * (loss_acm /
                                                                  batches_acm):
            discarded_batches_acm += 1
            print('abnormal', loss_value)
            continue
        loss_acm += loss_value
        batches_acm += 1
        loss.backward()
        if args.world_size > 1:
            average_gradients(model)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        update_lr(optimizer, args.embed_dim, batches_acm, args.warmup_steps)
        optimizer.step()
        optimizer.zero_grad()
        #------------
        if args.world_size == 1 or (dist.get_rank() == 0):
            if batches_acm % args.print_every == -1 % args.print_every:
                print(
                    'Train Epoch %d, Batch %d, Discarded Batch %d, loss %.3f' %
                    (epoch, batches_acm, discarded_batches_acm,
                     loss_acm / batches_acm))
                model.train()
            if batches_acm > args.warmup_steps and batches_acm % args.eval_every == -1 % args.eval_every:
                #model.eval()
                #bleu, chrf = validate(model, dev_data)
                #print ("epoch", "batch", "bleu", "chrf")
                #print (epoch, batches_acm, bleu, chrf)
                torch.save({
                    'args': args,
                    'model': model.state_dict()
                }, '%s/epoch%d_batch%d' % (args.ckpt, epoch, batches_acm))
                model.train()
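
Examples #10 and #12 read training batches from a multiprocessing queue filled by data_proc and treat any string item as an end-of-epoch marker. A minimal sketch consistent with that protocol (the real helper is not shown on this page):

def data_proc(data, queue):
    # Push batches forever; a string signals that an epoch finished.
    while True:
        for batch in data:
            queue.put(batch)
        queue.put('EPOCHDONE')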
Example #13
def main(args, local_rank):

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)

    vocabs = dict()
    vocabs['src'] = Vocab(args.src_vocab, 0, [BOS, EOS])
    vocabs['tgt'] = Vocab(args.tgt_vocab, 0, [BOS, EOS])

    if args.world_size == 1 or (dist.get_rank() == 0):
        logger.info(args)
        for name in vocabs:
            logger.info("vocab %s, size %d, coverage %.3f", name,
                        vocabs[name].size, vocabs[name].coverage)

    set_seed(19940117)

    #device = torch.device('cpu')
    torch.cuda.set_device(local_rank)
    device = torch.device('cuda', local_rank)

    if args.arch == 'vanilla':
        model = Generator(vocabs, args.embed_dim, args.ff_embed_dim,
                          args.num_heads, args.dropout, args.enc_layers,
                          args.dec_layers, args.label_smoothing)
    elif args.arch == 'mem':
        model = MemGenerator(vocabs, args.embed_dim, args.ff_embed_dim,
                             args.num_heads, args.dropout, args.mem_dropout,
                             args.enc_layers, args.dec_layers,
                             args.mem_enc_layers, args.label_smoothing,
                             args.use_mem_score)
    elif args.arch == 'rg':
        logger.info("start building model")
        logger.info("building retriever")
        retriever = Retriever.from_pretrained(
            args.num_retriever_heads,
            vocabs,
            args.retriever,
            args.nprobe,
            args.topk,
            local_rank,
            use_response_encoder=(args.rebuild_every > 0))

        logger.info("building retriever + generator")
        model = RetrieverGenerator(vocabs, retriever, args.share_encoder,
                                   args.embed_dim, args.ff_embed_dim,
                                   args.num_heads, args.dropout,
                                   args.mem_dropout, args.enc_layers,
                                   args.dec_layers, args.mem_enc_layers,
                                   args.label_smoothing)

    global_step = 0
    if args.resume_ckpt:
        model.load_state_dict(torch.load(args.resume_ckpt)['model'])

    if args.world_size > 1:
        set_seed(19940117 + dist.get_rank())

    model = model.to(device)

    retriever_params = [
        v for k, v in model.named_parameters() if k.startswith('retriever.')
    ]
    other_params = [
        v for k, v in model.named_parameters()
        if not k.startswith('retriever.')
    ]

    optimizer = Adam([{
        'params': retriever_params,
        'lr': args.embed_dim**-0.5 * 0.1
    }, {
        'params': other_params,
        'lr': args.embed_dim**-0.5
    }],
                     betas=(0.9, 0.98),
                     eps=1e-9)
    lr_schedule = get_inverse_sqrt_schedule_with_warmup(
        optimizer, args.warmup_steps, args.total_train_steps)
    train_data = DataLoader(vocabs,
                            args.train_data,
                            args.per_gpu_train_batch_size,
                            for_train=True,
                            rank=local_rank,
                            num_replica=args.world_size)

    model.eval()
    #dev_data = DataLoader(vocabs, cur_dev_data, args.dev_batch_size, for_train=False)
    #bleu = validate(device, model, dev_data, beam_size=5, alpha=0.6, max_time_step=10)

    step, epoch = 0, 0
    tr_stat = Statistics()
    logger.info("start training")
    model.train()

    best_dev_bleu = 0.
    while global_step <= args.total_train_steps:
        for batch in train_data:
            #step_start = time.time()
            batch = move_to_device(batch, device)
            if args.arch == 'rg':
                loss, acc = model(
                    batch,
                    update_mem_bias=(global_step >
                                     args.update_retriever_after))
            else:
                loss, acc = model(batch)

            tr_stat.update({
                'loss': loss.item() * batch['tgt_num_tokens'],
                'tokens': batch['tgt_num_tokens'],
                'acc': acc
            })
            tr_stat.step()
            loss.backward()
            #step_cost = time.time() - step_start
            #print ('step_cost', step_cost)
            step += 1
            if not (step % args.gradient_accumulation_steps
                    == -1 % args.gradient_accumulation_steps):
                continue

            if args.world_size > 1:
                average_gradients(model)

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_schedule.step()
            optimizer.zero_grad()
            global_step += 1

            if args.world_size == 1 or (dist.get_rank() == 0):
                if global_step % args.print_every == -1 % args.print_every:
                    logger.info("epoch %d, step %d, loss %.3f, acc %.3f",
                                epoch, global_step,
                                tr_stat['loss'] / tr_stat['tokens'],
                                tr_stat['acc'] / tr_stat['tokens'])
                    tr_stat = Statistics()
                if global_step % args.eval_every == -1 % args.eval_every:
                    model.eval()
                    max_time_step = 256 if global_step > 2 * args.warmup_steps else 5
                    bleus = []
                    for cur_dev_data in args.dev_data:
                        dev_data = DataLoader(vocabs,
                                              cur_dev_data,
                                              args.dev_batch_size,
                                              for_train=False)
                        bleu = validate(device,
                                        model,
                                        dev_data,
                                        beam_size=5,
                                        alpha=0.6,
                                        max_time_step=max_time_step)
                        bleus.append(bleu)
                    bleu = sum(bleus) / len(bleus)
                    logger.info("epoch %d, step %d, dev bleu %.2f", epoch,
                                global_step, bleu)
                    if bleu > best_dev_bleu:
                        testbleus = []
                        for cur_test_data in args.test_data:
                            test_data = DataLoader(vocabs,
                                                   cur_test_data,
                                                   args.dev_batch_size,
                                                   for_train=False)
                            testbleu = validate(device,
                                                model,
                                                test_data,
                                                beam_size=5,
                                                alpha=0.6,
                                                max_time_step=max_time_step)
                            testbleus.append(testbleu)
                        testbleu = sum(testbleus) / len(testbleus)
                        logger.info("epoch %d, step %d, test bleu %.2f", epoch,
                                    global_step, testbleu)
                        torch.save({
                            'args': args,
                            'model': model.state_dict()
                        }, '%s/best.pt' % (args.ckpt, ))
                        if not args.only_save_best:
                            torch.save(
                                {
                                    'args': args,
                                    'model': model.state_dict()
                                },
                                '%s/epoch%d_batch%d_devbleu%.2f_testbleu%.2f' %
                                (args.ckpt, epoch, global_step, bleu,
                                 testbleu))
                        best_dev_bleu = bleu
                    model.train()

            if args.rebuild_every > 0 and (global_step % args.rebuild_every
                                           == -1 % args.rebuild_every):
                model.retriever.drop_index()
                torch.cuda.empty_cache()
                next_index_dir = '%s/batch%d' % (args.ckpt, global_step)
                if args.world_size == 1 or (dist.get_rank() == 0):
                    model.retriever.rebuild_index(next_index_dir)
                    dist.barrier()
                else:
                    dist.barrier()
                model.retriever.update_index(next_index_dir, args.nprobe)

            if global_step > args.total_train_steps:
                break
        epoch += 1
    logger.info('rank %d, finish training after %d steps', local_rank,
                global_step)
Example #14
def validation_loop(model, dl, batch_size: int, epoch: int, device: str):
    # start validation
    total_val_iter = int(len(dl.dataset) / batch_size)
    model.eval()
    print("running validation...")
    all_detections = []
    all_losses = []
    for batch, targets in tqdm(dl, total=total_val_iter):
        batch, targets = move_to_device(batch, targets, device=device)

        detections, losses = model.validation_step(batch, targets)

        all_losses.append(losses)
        all_detections.append(detections)

    # evaluate RPN
    iou_thresholds = torch.arange(0.5, 1.0, 0.05)
    rpn_predictions = []
    rpn_ground_truths = []
    for dets in all_detections:
        rpn_predictions += dets['rpn']['predictions']
        rpn_ground_truths += dets['rpn']['ground_truths']

    rpn_recalls = roi_recalls(rpn_predictions,
                              rpn_ground_truths,
                              iou_thresholds=iou_thresholds)

    # evaluate FastRCNN
    head_predictions = []
    head_ground_truths = []
    for dets in all_detections:
        head_predictions += dets['head']['predictions']
        head_ground_truths += dets['head']['ground_truths']
    head_predictions = [pred[:, :5] for pred in head_predictions]
    head_ground_truths = [gt[:, :4] for gt in head_ground_truths]

    AP50 = calculate_AP(head_predictions,
                        head_ground_truths,
                        iou_threshold=0.5)
    AP75 = calculate_AP(head_predictions,
                        head_ground_truths,
                        iou_threshold=0.75)
    AP90 = calculate_AP(head_predictions,
                        head_ground_truths,
                        iou_threshold=0.90)
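    # Note: averaging three thresholds is a coarse stand-in for the usual
    # COCO-style AP@[.5:.95], which averages ten thresholds.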
    AP = (AP50 + AP75 + AP90) / 3
    means = caclulate_means(all_losses)

    print(f"--validation results for epoch {epoch+1} --")
    print(f"RPN mean recall at iou thresholds are:")
    for iou_threshold, rpn_recall in zip(iou_thresholds.cpu().numpy(),
                                         rpn_recalls.cpu().numpy() * 100):
        print(f"IoU={iou_threshold:.02f} recall={int(rpn_recall)}")
    print(f"HEAD AP IoU=.5 :{AP50.item()*100:.02f}")
    print(f"HEAD AP IoU=.75 :{AP75.item()*100:.02f}")
    print(f"HEAD AP IoU=.90 :{AP90.item()*100:.02f}")
    print(f"HEAD AP IoU=.5:.95 :{AP.item()*100:.02f}")

    for k, v in means.items():
        print(f"{k}: {v:.4f}")
    print("--------------------------------------------")
Example #15
def prepare_incremental_input(self, step_seq):
    token = ListsToTensor(step_seq, self.vocabs['token'])
    token_char = ListsofStringToTensor(step_seq, self.vocabs['token_char'])
    token = move_to_device(token, self.device)
    token_char = move_to_device(token_char, self.device)
    return token, token_char
Example #16
def main(args):
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger.info('Loading model...')
    device = torch.device('cuda', 0)

    vocab = Vocab(args.vocab_path, 0, [BOS, EOS])
    model_args = torch.load(args.args_path)
    model = ProjEncoder.from_pretrained(vocab, model_args, args.ckpt_path)
    model.to(device)

    logger.info('Collecting data...')

    data_r = []
    with open(args.index_file) as f:
        for line in f.readlines():
            r = line.strip()
            data_r.append(r)

    data_q = []
    data_qr = []
    with open(args.input_file, 'r') as f:
        for line in f.readlines():
            q, r = line.strip().split('\t')
            data_q.append(q)
            data_qr.append(r)

    logger.info('Collected %d instances', len(data_q))
    textq, textqr, textr = data_q, data_qr, data_r
    data_loader = DataLoader(data_q, vocab, args.batch_size)

    mips = MIPS.from_built(args.index_path, nprobe=args.nprobe)
    max_norm = torch.load(os.path.dirname(args.index_path) + '/max_norm.pt')
    mips.to_gpu()
    model.cuda()
    model = torch.nn.DataParallel(model,
                                  device_ids=list(
                                      range(torch.cuda.device_count())))
    model.eval()

    logger.info('Start search')
    cur, tot = 0, len(data_q)
    with open(args.output_file, 'w') as fo:
        for batch in asynchronous_load(data_loader):
            with torch.no_grad():
                q = move_to_device(batch, torch.device('cuda')).t()
                bsz = q.size(0)
                vecsq = model(q, batch_first=True).detach().cpu().numpy()
            vecsq = augment_query(vecsq)
            D, I = mips.search(vecsq, args.topk + 1)
            D = l2_to_ip(D, vecsq, max_norm) / (max_norm * max_norm)
            for i, (Ii, Di) in enumerate(zip(I, D)):
                item = [textq[cur + i], textqr[cur + i]]
                for pred, s in zip(Ii, Di):
                    if args.allow_hit or textr[pred] != textqr[cur + i]:
                        item.append(textr[pred])
                        item.append(str(float(s)))
                item = item[:2 + 2 * args.topk]
                assert len(item) == 2 + 2 * args.topk
                fo.write('\t'.join(item) + '\n')
            cur += bsz
            logger.info('finished %d / %d', cur, tot)
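
Each line Example #16 writes is tab-separated: the query, its gold response, then up to topk retrieved responses each followed by its score. A hypothetical reader matching that layout (the function name is illustrative, not from the source):

def read_retrieval_file(path, topk):
    with open(path) as f:
        for line in f:
            fields = line.rstrip('\n').split('\t')
            query, gold = fields[0], fields[1]
            pairs = fields[2:2 + 2 * topk]
            retrieved = [(pairs[i], float(pairs[i + 1]))
                         for i in range(0, len(pairs), 2)]
            yield query, gold, retrieved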
Example #17
def main(args, local_rank):

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)

    vocabs = dict()
    vocabs['src'] = Vocab(args.src_vocab, 0, [BOS, EOS])
    vocabs['tgt'] = Vocab(args.tgt_vocab, 0, [BOS, EOS])

    if args.world_size == 1 or (dist.get_rank() == 0):
        logger.info(args)
        for name in vocabs:
            logger.info("vocab %s, size %d, coverage %.3f", name,
                        vocabs[name].size, vocabs[name].coverage)

    set_seed(19940117)

    #device = torch.device('cpu')
    torch.cuda.set_device(local_rank)
    device = torch.device('cuda', local_rank)

    if args.resume_ckpt:
        model = MatchingModel.from_pretrained(vocabs, args.resume_ckpt)
    else:
        model = MatchingModel.from_params(vocabs, args.layers, args.embed_dim,
                                          args.ff_embed_dim, args.num_heads,
                                          args.dropout, args.output_dim,
                                          args.bow)

    if args.world_size > 1:
        set_seed(19940117 + dist.get_rank())

    model = model.to(device)

    if args.resume_ckpt:
        dev_data = DataLoader(vocabs,
                              args.dev_data,
                              args.dev_batch_size,
                              addition=args.additional_negs)
        acc = validate(model, dev_data, device)
        logger.info("initialize from %s, initial acc %.2f", args.resume_ckpt,
                    acc)

    optimizer = Adam(model.parameters(),
                     lr=args.lr,
                     betas=(0.9, 0.98),
                     eps=1e-9)
    lr_schedule = get_linear_schedule_with_warmup(optimizer, args.warmup_steps,
                                                  args.total_train_steps)
    train_data = DataLoader(vocabs,
                            args.train_data,
                            args.per_gpu_train_batch_size,
                            worddrop=args.worddrop,
                            addition=args.additional_negs)
    global_step, step, epoch = 0, 0, 0
    tr_stat = Statistics()
    logger.info("start training")
    model.train()
    while global_step <= args.total_train_steps:
        for batch in train_data:
            batch = move_to_device(batch, device)
            loss, acc, bsz = model(batch['src_tokens'], batch['tgt_tokens'],
                                   args.label_smoothing)
            tr_stat.update({
                'loss': loss.item() * bsz,
                'nsamples': bsz,
                'acc': acc * bsz
            })
            tr_stat.step()
            loss.backward()

            step += 1
            if not (step % args.gradient_accumulation_steps
                    == -1 % args.gradient_accumulation_steps):
                continue

            if args.world_size > 1:
                average_gradients(model)

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_schedule.step()
            optimizer.zero_grad()
            global_step += 1

            if args.world_size == 1 or (dist.get_rank() == 0):
                if global_step % args.print_every == -1 % args.print_every:
                    logger.info("epoch %d, step %d, loss %.3f, acc %.3f",
                                epoch, global_step,
                                tr_stat['loss'] / tr_stat['nsamples'],
                                tr_stat['acc'] / tr_stat['nsamples'])
                    tr_stat = Statistics()
                if global_step > args.warmup_steps and global_step % args.eval_every == -1 % args.eval_every:
                    dev_data = DataLoader(vocabs,
                                          args.dev_data,
                                          args.dev_batch_size,
                                          addition=args.additional_negs)
                    acc = validate(model, dev_data, device)
                    logger.info("epoch %d, step %d, dev, dev acc %.2f", epoch,
                                global_step, acc)
                    save_path = '%s/epoch%d_batch%d_acc%.2f' % (
                        args.ckpt, epoch, global_step, acc)
                    model.save(args, save_path)
                    model.train()
            if global_step > args.total_train_steps:
                break
        epoch += 1
    logger.info('rank %d, finish training after %d steps', local_rank,
                global_step)