Exemplo n.º 1
0
    def run(self):
        """Train-worker main loop: build a fresh model, then keep running
        episodes against the environment until the shared done flag is set.
        """
        set_seed(self.seed)
        self.model = create_a3c_model(self.cfg)
        self.model.train()
        self.env.set_gpu_id(self.gpu_id)
        self.env.set_model(self.model)

        # Progress counters are global; dividing by the number of training
        # workers (num_workers - 1 — one worker slot is the validator, see
        # main()) gives a per-worker approximation for the progress bar.
        n_train_workers = self.cfg.num_workers - 1
        approx_initial = self.gstep.value // n_train_workers
        approx_total = self.cfg.num_episodes // n_train_workers
        progress = tqdm(desc=self.name,
                        initial=approx_initial,
                        total=approx_total,
                        position=self.worker_id)
        with progress as pbar:
            while not self.done.value:
                # Pull the latest shared weights before each episode.
                self.sync_model()
                self.model.cuda(self.gpu_id)
                self.env.reset()
                self.run_episode()
                pbar.update()
Exemplo n.º 2
0
    def run(self):
        """Validation-worker loop: repeatedly evaluate the shared model over
        the validation episodes, push averaged metrics to tensorboard, and
        checkpoint the weights that produced them.
        """
        self.init()
        set_seed(self.seed)
        self.model = create_a3c_model(self.cfg, self.vocab, self.stats)
        self.model.cuda(self.gpu_id)
        self.model.eval()
        self.env.set_model(self.model)
        self.env.set_gpu_id(self.gpu_id)

        metric_names = ('acc', 'solv', 'loss', 'qa_loss', 'rl_loss',
                        'policy_loss', 'entropy', 'value_loss')
        while not self.done.value:
            self.sync_model()
            step = self.gstep.value
            # Snapshot states before evaluating so the saved checkpoint
            # matches the weights that produced these numbers.
            model_state = self.model.state_dict()
            optim_state = self.optim.state_dict()
            status = {name: 0.0 for name in metric_names}
            episodes = tqdm(range(self.num_episodes),
                            desc=self.name,
                            position=self.worker_id)
            for idx in episodes:
                self.env.reset(idx)
                for key, val in self.run_episode().items():
                    status[key] += val
            # Turn the per-episode sums into averages.
            for key in status:
                status[key] /= self.num_episodes

            self._update_tensorboard(step, **status)

            self._save_checkpoint(step=step,
                                  acc=status['acc'],
                                  model_state=model_state,
                                  optim_state=optim_state)
Exemplo n.º 3
0
    def run(self):
        """Validation-worker loop for span-QA (TriviaQA-style) evaluation:
        run every episode, keep the best-scoring answer per question,
        dump predictions, score them, and checkpoint the model.
        """
        self.init()
        set_seed(self.seed)
        self.model = create_a3c_model(self.cfg)
        self.model.cuda(self.gpu_id)
        self.model.eval()
        self.env.set_model(self.model)
        self.env.set_gpu_id(self.gpu_id)

        while not self.done.value:
            self.sync_model()
            step = self.gstep.value
            # Snapshot states before evaluating so the saved checkpoint
            # matches the weights that produced these numbers.
            model_state = self.model.state_dict()
            optim_state = self.optim.state_dict()
            # Example ids look like '<qa_id>_<doc_no>'; strip the trailing
            # part to group documents that belong to the same question.
            self.id_list = self.env.dataset.example_ids
            self.qa_list = list(
                {'_'.join(doc_id.split('_')[:-1]) for doc_id in self.id_list})
            # Best (answer, score) per question so far; the sentinel score
            # is low enough that any real score replaces it.
            self.answers = {qa_id: ('', -100000000) for qa_id in self.qa_list}
            scalar_keys = ('exact', 'f1', 'acc', 'solv', 'loss', 'qa_loss',
                           'rl_loss', 'policy_loss', 'entropy', 'value_loss',
                           'reward')
            status = {key: 0.0 for key in scalar_keys}
            status['score'] = []
            status['answer'] = []
            for i in tqdm(range(self.num_episodes),
                          desc=self.name,
                          position=self.worker_id):
                self.env.reset(i)
                for key, val in self.run_episode().items():
                    # Per-episode answers/scores are collected as lists;
                    # everything else is a scalar metric to be summed.
                    if key in ('answer', 'score'):
                        status[key].append(val)
                    else:
                        status[key] += val
            for key in status:
                if key not in ('answer', 'score'):
                    status[key] /= self.num_episodes

            # Keep, for each question, the answer from its highest-scoring
            # document.
            for i in range(self.num_episodes):
                qa_id = '_'.join(self.id_list[i].split('_')[:-1])
                score = status['score'][i]
                if self.answers[qa_id][1] < score:
                    self.answers[qa_id] = (status['answer'][i], score)

            # Drop the scores; the prediction file wants only the strings.
            self.answers = {
                qa_id: pair[0]
                for qa_id, pair in self.answers.items()
            }

            with open(self.cfg.prediction_file, 'w', encoding='utf-8') as f:
                print(json.dumps(self.answers), file=f)
            results = get_score_from_trivia(self.cfg, self.cfg.valid_set)
            status['exact'] = results['exact_match']
            status['f1'] = results['f1']

            self._update_tensorboard(step, **status)

            self._save_checkpoint(step=step,
                                  exact=status['exact'],
                                  f1=status['f1'],
                                  model_state=model_state,
                                  optim_state=optim_state)

            # Reset the per-question bests for the next validation round.
            self.answers = {qa_id: ('', -100000000) for qa_id in self.qa_list}
Exemplo n.º 4
0
def demo(cfg):
    """Interactive demo: load a checkpoint and walk through examples one by
    one, printing the context, question, gold answers, and the model's
    per-step memory decisions and final QA output.

    Runs forever; the user advances with <enter>.
    """
    if not os.path.exists(cfg.ckpt):
        print('Invalid ckpt path:', cfg.ckpt)
        exit(1)
    # Load onto CPU first; the model is moved to GPU below.
    ckpt = torch.load(cfg.ckpt, map_location=lambda storage, loc: storage)
    print(cfg.ckpt, 'loaded')
    loaded_cfg = ckpt['cfg'].__dict__

    # These settings must come from the current command line, not from the
    # training-time config stored in the checkpoint.
    del loaded_cfg['test_set']
    del loaded_cfg['use_pretrain']
    del loaded_cfg['num_workers']
    del loaded_cfg['num_episodes']
    del loaded_cfg['memory_num']
    del loaded_cfg['memory_len']

    cfg.__dict__.update(loaded_cfg)
    cfg.model = cfg.model.upper()
    pprint(cfg.__dict__)

    model = create_a3c_model(cfg)
    model.load_state_dict(ckpt['model'])
    model.cuda()

    tokenizer = BertTokenizer.from_pretrained(cfg.bert_model)

    env = Environment(cfg, cfg.test_set, tokenizer, shuffle=True)
    env.set_model(model)
    env.set_gpu_id(torch.cuda.current_device())
    print(env.dataset.path, 'loaded')
    while True:
        model.eval()
        env.reset()
        print('-' * 80)
        print('Data ID:', env.data_idx)
        print()
        print('[Context]')
        print(' '.join(tokenizer.convert_ids_to_tokens(env.data.ctx_words)))
        print()

        # De-tokenize WordPiece output: '##' marks subword continuations.
        ques = ' '.join(tokenizer.convert_ids_to_tokens(env.data.ques_words))
        ques = ques.replace(' ##', '')
        ques = ques.replace('##', '')
        print('[Question]')
        print(ques)
        print()
        answs = []
        # env.data.indices holds (start, end) token spans of gold answers
        # within the context; de-duplicate before printing.
        indices = list(set(env.data.indices))
        for indice in indices:
            s_idx = indice[0]
            e_idx = indice[1]
            answ = ' '.join(
                tokenizer.convert_ids_to_tokens(
                    env.data.ctx_words[s_idx:e_idx + 1]))
            answ = answ.replace(' ##', '')
            answ = answ.replace('##', '')
            answs.append(answ)
        print('[Answer]')
        for i in range(len(answs)):
            print('%d.' % (i + 1), answs[i])

        input('\nPress enter to continue\n')
        # Feed sentences into memory one at a time. While memory has free
        # slots, just append; once it is (nearly) full, run the memory
        # policy after each append — except for the LIFO baseline, which
        # stops appending altogether.
        while not env.is_done():
            if len(env.memory) < cfg.memory_num - 1:
                env._append_current()
                env.sent_ptr += 1
            else:
                if cfg.model == 'LIFO':
                    break
                env._append_current()
                env.sent_ptr += 1

                batch, solvable, mem_solvable = env.observe()
                batch = {k: v.cuda() for k, v in batch.items()}

                result = model.mem_forward(**batch)
                logit, value = result['logit'], result['value']

                # Greedy action: take the argmax over the policy logits.
                # NOTE(review): the action presumably selects a memory slot
                # to act on — confirm against Environment.step.
                prob = F.softmax(logit, 1)
                _, action = prob.max(1, keepdim=True)

                env.step(action=action.item(), **result)

                _print_mem_result(cfg, tokenizer, batch, prob, action, result,
                                  solvable, mem_solvable, answs)
                # input()

        assert (len(env.memory) <= cfg.memory_num)
        # With the final memory contents fixed, run the QA head and show
        # the prediction.
        result = _qa_forward(env, model)
        batch = result['batch']
        _print_qa_result(cfg, tokenizer, batch, result, answs, env)
        input()
Exemplo n.º 5
0
def test(cfg):
    """Evaluate a trained A3C checkpoint on the test set.

    Loads the checkpoint, merges its saved config into ``cfg`` (keeping the
    current runtime settings), restores vocab/stats from the preprocessing
    directory, fans the test set out to ``cfg.num_workers`` worker
    processes, and aggregates per-example metrics from the shared queue.

    Returns:
        Total error in percent (100 - accuracy).
    """
    set_seed(cfg.seed)

    if not os.path.exists(cfg.ckpt):
        print('Invalid ckpt path:', cfg.ckpt)
        exit(1)
    # Load onto CPU; workers move the model to their own devices.
    ckpt = torch.load(cfg.ckpt, map_location=lambda storage, loc: storage)

    print(cfg.ckpt, 'loaded')
    loaded_cfg = ckpt['cfg'].__dict__
    if loaded_cfg.get('num_workers') is not None:
        # Runtime-only settings must come from the current cfg, not from
        # the checkpoint. pop() with a default so that a checkpoint saved
        # without one of these keys cannot raise KeyError (the original
        # unconditional `del` could).
        for key in ('num_workers', 'test_set', 'pre_ckpt'):
            loaded_cfg.pop(key, None)

    cfg.__dict__.update(loaded_cfg)
    cfg.model = cfg.model.upper()
    pprint(cfg.__dict__)

    prepro_dir = os.path.join(cfg.prepro_dir, 'task%s' % (cfg.task_id))
    with open(os.path.join(prepro_dir, 'vocab.pk'), 'rb') as f:
        vocab = pickle.load(f)
        print()
        print(f.name, 'loaded')

    with open(os.path.join(prepro_dir, 'stats.pk'), 'rb') as f:
        stats = pickle.load(f)
        print(f.name, 'loaded')
        # Questions are single sentences, so reuse the sentence bound.
        stats['max_ques_len'] = stats['max_sent_len']

    model = create_a3c_model(cfg, vocab, stats)
    model.load_state_dict(ckpt['model'])

    env = create_env(cfg, cfg.test_set, vocab, stats, shuffle=False)
    env.set_model(model)

    print(env.dataset.path, 'loaded')

    queue = mp.Queue()

    procs = []
    for i in range(cfg.num_workers):
        p = TestWorker(cfg, i, model, env, queue)
        p.start()
        procs.append(p)

    # Each worker pushes one metrics dict per example; sum them here.
    num_examples = len(env.dataset)
    dataset = dict()
    for _ in range(num_examples):
        example = queue.get()
        for key, val in example.items():
            dataset[key] = dataset.get(key, 0) + val

    for p in procs:
        p.join()

    # Give worker tqdm bars a moment to finish printing before the summary.
    from time import sleep
    sleep(3)
    print(f'All processes is finished ({num_examples} examples).')

    print()
    acc = dataset['acc'] / num_examples * 100
    solv = dataset['solv'] / num_examples * 100
    total_error = 100 - acc
    print('[Total]')
    print(f'Acc (Solv): {acc:.2f} ({solv:.2f})')
    print(f'Error : {total_error:.2f}')
    print(f'{total_error:.3f}')

    print()
    print(cfg.log_dir)

    return total_error
Exemplo n.º 6
0
def test(cfg):
    """Evaluate on a span-QA (TriviaQA-style) test set across multiple
    worker processes, then print exact-match/F1, solvability, and
    memory-action statistics.

    If ``cfg.ckpt`` is given, the model is restored from that checkpoint;
    otherwise a fresh model is warm-started from pretrained BERT weights.
    """
    if cfg.ckpt is not None:
        if not os.path.exists(cfg.ckpt):
            print('Invalid ckpt path:', cfg.ckpt)
            exit(1)
        # Load onto CPU; workers move the model to their own devices.
        ckpt = torch.load(cfg.ckpt, map_location=lambda storage, loc: storage)
        print(cfg.ckpt, 'loaded')
        loaded_cfg = ckpt['cfg'].__dict__

        # These settings must come from the current command line, not from
        # the training-time config stored in the checkpoint.
        del loaded_cfg['num_workers']
        del loaded_cfg['test_set']
        del loaded_cfg['log_dir']
        del loaded_cfg['prediction_file']
        del loaded_cfg['num_episodes']
        del loaded_cfg['use_pretrain']
        del loaded_cfg['memory_num']
        del loaded_cfg['memory_len']
        del loaded_cfg['prepro_dir']
        del loaded_cfg['debug']

        cfg.__dict__.update(loaded_cfg)
        cfg.model = cfg.model.upper()

        print('Merged Config')
        pprint(cfg.__dict__)

        os.makedirs(cfg.log_dir)

        model = create_a3c_model(cfg)
        model.load_state_dict(ckpt['model'])
    else:
        os.makedirs(cfg.log_dir)
        model = create_a3c_model(cfg)

        print(
            "LOAD pretrain parameter for BERT from ./pretrain/pytorch_model.bin..."
        )
        pretrain_param = torch.load('./pretrain/pytorch_model.bin',
                                    map_location=lambda storage, loc: storage)
        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # Re-key the pretrained state dict under the 'model.' prefix so it
        # matches this wrapper's parameter names.
        new_pretrain_param = pretrain_param.copy()
        for k, v in pretrain_param.items():
            new_key = 'model.' + k
            new_pretrain_param[new_key] = v
            del new_pretrain_param[k]
        pretrain_param = new_pretrain_param.copy()

        # NOTE(review): getattr runs on the copied dict, which likely does
        # not carry a '_metadata' attribute, and the assignment writes the
        # value back to the same object it was read from — this block looks
        # like a no-op remnant of the HuggingFace loading boilerplate;
        # confirm before relying on _metadata being propagated.
        metadata = getattr(pretrain_param, '_metadata', None)
        if metadata is not None:
            pretrain_param._metadata = metadata

        def load(module, prefix=''):
            # Recursively load each submodule's slice of the state dict,
            # collecting (rather than raising on) missing/unexpected keys.
            local_metadata = {} if metadata is None else metadata.get(
                prefix[:-1], {})
            module._load_from_state_dict(pretrain_param, prefix,
                                         local_metadata, True, missing_keys,
                                         unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(model, prefix='')
        print("Weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, missing_keys))
        print("Weights from pretrained model not used in {}: {}".format(
            model.__class__.__name__, unexpected_keys))

    tokenizer = BertTokenizer.from_pretrained(cfg.bert_model)

    env = Environment(cfg, cfg.test_set, tokenizer, shuffle=False)
    print(env.dataset.path, 'loaded')

    queue = mp.Queue()

    procs = []
    for i in range(cfg.num_workers):
        p = TestWorker(cfg, i, model, env, queue, tokenizer)
        if cfg.debug:
            # Debug mode runs the worker inline in this process.
            p.run()
        else:
            p.start()
        procs.append(p)

    # Drain the result queue while each worker is alive so it never blocks
    # on a full queue; stop once the worker has died and the queue is
    # empty. NOTE(review): this is a busy-wait loop with no sleep.
    results = []
    for p in procs:
        while True:
            running = p.is_alive()
            if not queue.empty():
                result = queue.get()
                results.append(result)
            else:
                if not running:
                    break

    for p in procs:
        p.join()

    # Aggregate per-example metrics and the per-slot memory-action counts.
    exact_list = []
    f1_list = []
    full_action = [0 for _ in range(cfg.memory_num)]
    full_solvable = []
    id_list = []
    for i in range(len(results)):
        id_list.append(results[i]['doc'])
        full_solvable.append(results[i]['solvable'])
        exact_list.append(results[i]['exact'])
        f1_list.append(results[i]['f1'])
        for j in range(cfg.memory_num):
            full_action[j] += results[i]['actions'][j]
    # Example ids look like '<qa_id>_<doc_no>'; strip the trailing part to
    # group documents that belong to the same question.
    qa_list = list(
        set(['_'.join(doc_id.split('_')[:-1]) for doc_id in id_list]))
    answers = dict()
    for qa_id in qa_list:
        # Sentinel score low enough that any real score replaces it.
        answers[qa_id] = ('', -100000000)

    # Keep, for each question, the answer from its highest-scoring document.
    for i in range(len(results)):
        qa_id = '_'.join(id_list[i].split('_')[:-1])
        score = results[i]['score']
        answer = results[i]['answer']

        if answers[qa_id][1] < score:
            answers[qa_id] = (answer, score)

    # Drop the scores; the prediction file wants only the answer strings.
    for qa_id in answers.keys():
        answers[qa_id] = answers[qa_id][0]

    # NOTE(review): dict keys are already unique; set() here only perturbs
    # the order of key_list.
    key_list = list(set(answers.keys()))
    # A question counts as solvable if ANY of its documents was solvable.
    solvables = [[] for i in range(len(key_list))]
    for i in range(len(full_solvable)):
        id_ = '_'.join(id_list[i].split('_')[:-1])
        solv = full_solvable[i]
        idx = key_list.index(id_)
        solvables[idx].append(solv)

    for i in range(len(solvables)):
        if 1 in solvables[i]:
            solvables[i] = 1
        else:
            solvables[i] = 0

    with open(cfg.prediction_file, 'w', encoding='utf-8') as f:
        print(json.dumps(answers), file=f)
    # Official scorer reads the prediction file written just above.
    results = get_score_from_trivia(cfg, cfg.test_set)
    exact = results['exact_match']
    f1 = results['f1']

    # Normalize action counts into a per-slot usage distribution.
    total_action_num = 0
    for i in range(cfg.memory_num):
        total_action_num += full_action[i]
    avg_action = [0 for _ in range(cfg.memory_num)]
    for i in range(cfg.memory_num):
        avg_action[i] += full_action[i] / total_action_num
    print('All processes is finished.')
    # Per-document (micro) scores first, then the official aggregate.
    print('ExactMatch: %.2f' % (sum(exact_list) / len(exact_list) * 100))
    print('F1score: %.2f' % (sum(f1_list) / len(f1_list) * 100))
    print()
    print('ExactMatch: %.2f' % (exact * 100))
    print('F1score: %.2f' % (f1 * 100))
    print()
    print('Solvables: %.2f' % (sum(full_solvable) / len(full_solvable) * 100))
    print('Non duplicated Solvables: %.2f' %
          (sum(solvables) / len(solvables) * 100))
    print()
    print('Total number of actions: %d' % (total_action_num))
    for i in range(cfg.memory_num):
        print('Action %d : %.2f' % (i, avg_action[i] * 100))
Exemplo n.º 7
0
def main(cfg):
    """Entry point for A3C training: build the shared model and optimizer,
    optionally resume from a checkpoint and/or warm-start from a
    pretrained one, then spawn validation, training, and tensorboard
    worker processes and wait for them to finish.
    """
    ckpt = None
    if cfg.ckpt:
        if not os.path.exists(cfg.ckpt):
            print('Invalid ckpt path:', cfg.ckpt)
            exit(1)
        # Load onto CPU; workers move the model to their own devices.
        ckpt = torch.load(cfg.ckpt, map_location=lambda storage, loc: storage)

        print(cfg.ckpt, 'loaded')
        loaded_cfg = ckpt['cfg'].__dict__
        pprint(loaded_cfg)

        # These settings must come from the current command line, not from
        # the config stored in the checkpoint.
        del loaded_cfg['num_episodes']
        del loaded_cfg['num_workers']
        del loaded_cfg['test_set']
        del loaded_cfg['pre_ckpt']

        cfg.__dict__.update(loaded_cfg)
        cfg.model = cfg.model.upper()

        print()
        print('Merged Config')
        pprint(cfg.__dict__)
    else:
        os.makedirs(os.path.join(cfg.log_dir, 'ckpt'))

    prepro_dir = os.path.join(cfg.prepro_dir, 'task%s' % (cfg.task_id))
    with open(os.path.join(prepro_dir, 'vocab.pk'), 'rb') as f:
        vocab = pickle.load(f)

    with open(os.path.join(prepro_dir, 'stats.pk'), 'rb') as f:
        stats = pickle.load(f)
        # Questions are single sentences, so reuse the sentence bound.
        stats['max_ques_len'] = stats['max_sent_len']

    shared_model = create_a3c_model(cfg, vocab, stats)

    if cfg.pre_ckpt is not None:
        # Warm-start from a pretraining checkpoint, tolerating
        # missing/unexpected keys instead of raising.
        pretrain_param = torch.load(cfg.pre_ckpt,
                                    map_location=lambda storage, loc: storage)
        pretrain_param = pretrain_param['model']
        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # NOTE(review): this double .copy() performs no key remapping
        # (unlike the BERT-prefixing variant of this code elsewhere) —
        # it looks vestigial.
        new_pretrain_param = pretrain_param.copy()
        pretrain_param = new_pretrain_param.copy()

        # NOTE(review): getattr runs on the copied dict and the assignment
        # writes the value back to the object it was read from — this
        # block looks like a no-op; confirm before relying on _metadata.
        metadata = getattr(pretrain_param, '_metadata', None)
        if metadata is not None:
            pretrain_param._metadata = metadata

        def load(module, prefix=''):
            # Recursively load each submodule's slice of the state dict,
            # collecting (rather than raising on) missing/unexpected keys.
            local_metadata = {} if metadata is None else metadata.get(
                prefix[:-1], {})
            module._load_from_state_dict(pretrain_param, prefix,
                                         local_metadata, True, missing_keys,
                                         unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(shared_model, prefix='')
        print("Weights of {} not initialized from pretrained model: {}".format(
            shared_model.__class__.__name__, missing_keys))
        print("Weights from pretrained model not used in {}: {}".format(
            shared_model.__class__.__name__, unexpected_keys))

    # A full checkpoint (resume) overrides any pretraining warm-start.
    if ckpt is not None:
        shared_model.load_state_dict(ckpt['model'])
    # Place parameters in shared memory so all workers update one model.
    shared_model.share_memory()

    params = filter(lambda p: p.requires_grad, shared_model.parameters())
    optim = SharedAdam(params, lr=cfg.lr)

    if ckpt is not None:
        optim.load_state_dict(ckpt['optim'])
    optim.share_memory()

    set_seed(cfg.seed)

    # Cross-process coordination: done flag, global step counter, and a
    # queue for worker -> tensorboard messages.
    done = mp.Value('i', False)
    if ckpt is not None:
        gstep = mp.Value('i', ckpt['step'])
    else:
        gstep = mp.Value('i', 0)
    queue = mp.Queue()

    train_env = create_env(cfg, 'train', vocab, stats, shuffle=True)
    # Shuffle validation only when sub-sampling episodes (0 = use all).
    valid_shuffle = False if cfg.num_valid_episodes == 0 else True
    valid_env = create_env(cfg, 'valid', vocab, stats, shuffle=valid_shuffle)

    procs = []
    if cfg.debug:
        # Debug mode: run one worker inline in this process, no forking.
        p = TrainWorker(cfg, len(procs), done, shared_model, optim, vocab,
                        stats, train_env, queue, gstep)
        # p = ValidWorker(cfg, len(procs), done, shared_model, optim, vocab, stats, valid_env, gstep)
        p.run()
        return

    # Process layout: 1 validation worker + (num_workers - 1) training
    # workers + 1 tensorboard logger.
    p = ValidWorker(cfg, len(procs), done, shared_model, optim, vocab, stats,
                    valid_env, gstep)
    p.start()
    procs.append(p)

    for _ in range(cfg.num_workers - 1):
        p = TrainWorker(cfg, len(procs), done, shared_model, optim, vocab,
                        stats, train_env, queue, gstep)
        p.start()
        procs.append(p)

    p = TensorboardWorker(cfg, len(procs), queue, done, gstep)
    p.start()
    procs.append(p)

    for p in procs:
        p.join()
    print('All processes is finished:', cfg.log_dir)