Example 1
def test(model, ema, args, data):
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    criterion = nn.CrossEntropyLoss()
    loss = 0
    answers = dict()
    model.eval()

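    # evaluate with the EMA weights: back up the live parameters, then copy the
    # exponential-moving-average values into the model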
    backup_params = EMA(0)
    for name, param in model.named_parameters():
        if param.requires_grad:
            backup_params.register(name, param.data)
            param.data.copy_(ema.get(name))


    for batch in iter(data.dev_iter):
        with torch.no_grad():
            p1, p2 = model(batch.c_char, batch.q_char,
                           batch.c_word[0], batch.q_word[0],
                           batch.c_word[1], batch.q_word[1])
        batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)
        loss += batch_loss.item()

        # (batch, c_len, c_len)
        batch_size, c_len = p1.size()
        ls = nn.LogSoftmax(dim=1)
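        # -inf strictly below the diagonal rules out spans whose end index precedes
        # their start index before the joint (start, end) argmax below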
        mask = (torch.ones(c_len, c_len) * float('-inf')).to(device).tril(-1).unsqueeze(0).expand(batch_size, -1, -1)
        score = (ls(p1).unsqueeze(2) + ls(p2).unsqueeze(1)) + mask
        score, s_idx = score.max(dim=1)
        score, e_idx = score.max(dim=1)
        s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze()

        for i in range(batch_size):
            id = batch.id[i]
            answer = batch.c_word[0][i][s_idx[i]:e_idx[i] + 1]
            answer = ' '.join([data.CONTEXT_WORD.vocab.itos[idx] for idx in answer])
            if answer == "<eos>":
                answer = ""
            answers[id] = answer

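    # restore the original (non-EMA) parameters so training can continue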
    for name, param in model.named_parameters():
        if param.requires_grad:
            param.data.copy_(backup_params.get(name))

    with open(args.prediction_file, 'w', encoding='utf-8') as f:
        print(json.dumps(answers), file=f)

    opts = evaluate.parse_args(args=[f"{args.dataset_file}", f"{args.prediction_file}"])

    results = evaluate.main(opts)
    return loss, results['exact'], results['f1'], results['HasAns_exact'], results['HasAns_f1'], results['NoAns_exact'], results['NoAns_f1']
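
These BiDAF-style test/train functions use an EMA object with register/get/update methods that is never defined in this dump. A minimal sketch of the interface those calls assume (an assumption, not the repository's actual implementation):

class EMA:
    """Name-keyed exponential moving average of parameter tensors (sketch)."""

    def __init__(self, mu):
        self.mu = mu          # decay rate; EMA(0) is used above only as a plain backup copy
        self.shadow = {}

    def register(self, name, val):
        self.shadow[name] = val.clone()

    def get(self, name):
        return self.shadow[name]

    def update(self, name, x):
        # shadow <- mu * shadow + (1 - mu) * x
        new_average = self.mu * self.shadow[name] + (1.0 - self.mu) * x
        self.shadow[name] = new_average.clone()
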
Example 2
    def adapt(self, num_classes):
        '''
        To allow adapting the model to a different dataset with the same semantic classifier weights
        num_classes: number of classes in the target dataset
        return: None
        '''
        if isinstance(self.model, torch.nn.DataParallel):
            self.model.module.adapt(num_classes)
        else:
            self.model.adapt(num_classes)
        self.model.to(self.device)
        self.ema = EMA(self.model, self.config.ema_alpha)
Example 3
def test(model, ema, args, data):
    device = torch.device(
        f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    criterion = nn.CrossEntropyLoss()
    loss = 0
    answers = dict()
    model.eval()

    backup_params = EMA(0)
    for name, param in model.named_parameters():
        if param.requires_grad:
            backup_params.register(name, param.data)
            param.data.copy_(ema.get(name))

    with torch.set_grad_enabled(False):
        for batch in iter(data.dev_iter):
            p1, p2 = model(batch)
            batch_loss = criterion(p1, batch.s_idx) + criterion(
                p2, batch.e_idx)
            loss += batch_loss.item()

            # (batch, c_len, c_len)
            batch_size, c_len = p1.size()
            ls = nn.LogSoftmax(dim=1)
            mask = (torch.ones(c_len, c_len) *
                    float('-inf')).to(device).tril(-1).unsqueeze(0).expand(
                        batch_size, -1, -1)
            score = (ls(p1).unsqueeze(2) + ls(p2).unsqueeze(1)) + mask
            score, s_idx = score.max(dim=1)
            score, e_idx = score.max(dim=1)
            s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze()

            for i in range(batch_size):
                id = batch.id[i]
                answer = batch.c_word[0][i][s_idx[i]:e_idx[i] + 1]
                answer = ' '.join(
                    [data.WORD.vocab.itos[idx] for idx in answer])
                answers[id] = answer

        for name, param in model.named_parameters():
            if param.requires_grad:
                param.data.copy_(backup_params.get(name))


    with open(args.prediction_file, 'w', encoding='utf-8') as f:
        print(json.dumps(answers, indent=4), file=f)

    results = evaluate.main(args, answers, data)
    return loss / len(data.dev_iter), results['exact_match'], results['f1']
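
The score-matrix block above is the joint start/end span decoding step. The same logic, isolated as a standalone illustration (not code from the repository):

import torch
import torch.nn as nn

def decode_span(p1, p2):
    # p1, p2: (batch, c_len) start/end logits; returns the best span with start <= end
    batch_size, c_len = p1.size()
    ls = nn.LogSoftmax(dim=1)
    # -inf strictly below the diagonal removes spans whose end precedes their start
    mask = (torch.ones(c_len, c_len) * float('-inf')).tril(-1)
    mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
    score = ls(p1).unsqueeze(2) + ls(p2).unsqueeze(1) + mask  # score[b, s, e]
    score, s_idx = score.max(dim=1)   # best start for every end position
    score, e_idx = score.max(dim=1)   # best end position overall
    s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze(1)
    return s_idx, e_idx

# e.g. s_idx, e_idx = decode_span(torch.randn(4, 50), torch.randn(4, 50))
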
Example 4
    def load_model_state(self, chkpt_dict_path):
        '''
        Loads model state based on a checkpoint saved by SemCo _save_checkpoint() function.
        '''
        print("Loading Model State")
        checkpoint_dict = torch.load(chkpt_dict_path, map_location=self.device)
        if 'model_state_dict' in checkpoint_dict:
            state_dict = checkpoint_dict['model_state_dict']
        else:
            print("'model_state_dict' key not present in checkpoint; loading pretrained model failed, falling back to the original model initialisation")
            return

        # handle state dicts whose keys have a 'module.' prefix (i.e. the model was wrapped in nn.DataParallel)
        if all(['module' in key for key in state_dict.keys()]):
            if all(['module' in key for key in self.model.state_dict()]):
                pass
            else:
                state_dict= {k.replace('module.',''):v for k,v in state_dict.items()}
                if 'ema_shadow' in checkpoint_dict:
                    checkpoint_dict['ema_shadow'] = {k.replace('module.',''):v for k,v in checkpoint_dict['ema_shadow'].items()}

        try:
            self.model.load_state_dict(state_dict)
        except Exception as e:
            print(f'Problem occurred during naive state_dict loading: {e}.\nTrying to only load common params')
            try:
                model_state= self.model.state_dict()
                pretrained_state = {k:v for k,v in state_dict.items() if k in model_state and v.size() == model_state[k].size()}
                unloaded_state = set(list(state_dict.keys())) - set(list(model_state.keys()))
                model_state.update(pretrained_state)
                self.model.load_state_dict(model_state)
                print(f'Success. The following params in pretrained_state_dict were not loaded: {unloaded_state}')
            except Exception as e:
                print(f'Unable to load model state due to the following error. Model will be initialised randomly.\n{e}')
        if 'ema_shadow' in checkpoint_dict:
            try:
                self.ema = EMA(self.model, self.config.ema_alpha)
                similar_params = {k:v for k,v in checkpoint_dict['ema_shadow'].items() if k in self.ema.shadow and v.size() == self.ema.shadow[k].size()}
                self.ema.shadow.update(similar_params)
                print(f'EMA shadow has been loaded successfully. {len(similar_params)} out of {len(self.ema.shadow)} params were loaded')
            except Exception as e:
                print(f'Unable to load EMA shadow. EMA will be reinitialised with current model params. {e}')
                self.ema = EMA(self.model, self.config.ema_alpha)
        else:
            print('EMA shadow is not found in checkpoint dictionary. EMA will be reinitialised with current model params.')
            self.ema = EMA(self.model, self.config.ema_alpha)
        try:
            if 'classes' in checkpoint_dict:
                classes = self.dataset_meta['classes']
                classes_model = checkpoint_dict['classes']
                if all([classes_model[i] == classes[i] for i in range(len(classes))]):
                    print(f'classes matched successfully')
                else:
                    print(
                        "Classes loaded don't match the classes used while training the model, output of softmax can't be trusted")
        except Exception as e:
            print("Can't load classes. Please check and try again.")
        return
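
load_model_state() expects the checkpoint to hold 'model_state_dict', 'ema_shadow' and 'classes' entries. A hypothetical save-side counterpart consistent with those keys (the referenced _save_checkpoint() is not shown in this dump):

import torch

def save_checkpoint_sketch(model, ema, classes, path):
    # hypothetical layout matching the keys the loader above reads
    torch.save({
        'model_state_dict': model.state_dict(),
        'ema_shadow': ema.shadow,   # dict of parameter name -> averaged tensor
        'classes': classes,         # class names fixed at training time
    }, path)
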
Example 5
    def __init__(self, config, dataset_meta, device, L='dynamic', device_ids=None):
        self.config = config
        self.dataset_meta = dataset_meta
        if 'stats' not in dataset_meta:
            self.dataset_meta['stats'] = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)  # imagenet_stats
        self.device = device
        self.parallel = config.parallel
        self.device_ids = device_ids
        self.L = L
        self.label_emb_guessor, self.emb_dim = self._get_label_guessor()
        self.model = self._set_model(config.parallel, device_ids)
        self.optim = self._get_optimiser(self.config)
        if not self.config.no_amp:
            from apex import amp
            self.model, self.optim = amp.initialize(self.model, self.optim, opt_level="O1")
        if self.config.parallel:
            self.model = nn.DataParallel(self.model)
        # initialise the exponential moving average model
        self.ema = EMA(self.model, self.config.ema_alpha)
        if self.config.use_pretrained:
            self.load_model_state(config.checkpoint_path)
        if self.config.freeze_backbone:
            self._freeze_model_backbone()
        self.logger, self.writer, self.time_stamp = self._setup_default_logging()
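
SemCo's EMA has a different interface from the BiDAF one: it wraps the whole model and keeps a shadow dict, with apply_shadow()/restore() used at evaluation time (see load_model_state in Example 4 and predict() in Example 15). A minimal sketch consistent with those call sites, assuming conventional shadow-swap semantics:

class EMA:
    def __init__(self, model, alpha):
        self.model = model
        self.alpha = alpha
        # shadow is keyed like model.state_dict(), which load_model_state() relies on
        self.shadow = {k: v.clone().detach() for k, v in model.state_dict().items()}
        self.backup = {}

    def update(self):
        # shadow <- alpha * shadow + (1 - alpha) * current weights
        for k, v in self.model.state_dict().items():
            if v.dtype.is_floating_point:
                self.shadow[k] = self.alpha * self.shadow[k] + (1.0 - self.alpha) * v.detach()
            else:
                self.shadow[k] = v.clone()  # e.g. BatchNorm counters are copied as-is

    def apply_shadow(self):
        # swap the EMA weights in for evaluation, keeping the live weights aside
        self.backup = {k: v.clone() for k, v in self.model.state_dict().items()}
        self.model.load_state_dict(self.shadow)

    def restore(self):
        # put the live (training) weights back
        self.model.load_state_dict(self.backup)
        self.backup = {}
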
Example 6
def cw_tree_attack_targeted():
    cw = CarliniL2_qa(debug=args.debugging)
    criterion = nn.CrossEntropyLoss()
    loss = 0
    tot = 0
    adv_loss = 0
    targeted_success = 0
    untargeted_success = 0
    adv_text = []
    answers = dict()
    adv_answers = dict()
    # model.eval()

    embed = torch.load(args.word_vector)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    vocab = Vocab(filename=args.dictionary,
                  data=[PAD_WORD, UNK_WORD, EOS_WORD, SOS_WORD])
    generator = Generator(args.test_data, vocab=vocab, embed=embed)
    transfered_embedding = torch.load('bidaf_transfered_embedding.pth')
    transfer_emb = torch.nn.Embedding.from_pretrained(transfered_embedding).to(
        device)
    seqback = WrappedSeqback(embed,
                             device,
                             attack=True,
                             seqback_model=generator.seqback_model,
                             vocab=vocab,
                             transfer_emb=transfer_emb)
    treelstm = generator.tree_model
    generator.load_state_dict(torch.load(args.load_ae))

    backup_params = EMA(0)
    for name, param in model.named_parameters():
        if param.requires_grad:
            backup_params.register(name, param.data)
            param.data.copy_(ema.get(name))

    class TreeModel(nn.Module):
        def __init__(self):
            super(TreeModel, self).__init__()
            self.inputs = None

        def forward(self, hidden):
            self.embedding = seqback(hidden)
            return model(batch, perturbed=self.embedding)

        def set_temp(self, temp):
            seqback.temp = temp

        def get_embedding(self):
            return self.embedding

        def get_seqback(self):
            return seqback

    tree_model = TreeModel()
    for batch in tqdm(iter(data.dev_iter), total=1000):
        p1, p2 = model(batch)
        orig_answer, orig_s_idx, orig_e_idx = write_to_ans(
            p1, p2, batch, answers)
        batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)
        loss += batch_loss.item()

        append_info = append_input(batch, vocab)
        batch_add_start = append_info['add_start']
        batch_add_end = append_info['add_end']
        batch_start_target = torch.LongTensor(
            append_info['target_start']).to(device)
        batch_end_target = torch.LongTensor(
            append_info['target_end']).to(device)
        add_sents = append_info['append_sent']

        input_embedding = model.word_emb(batch.c_word[0])
        append_info['tree'] = [generator.get_tree(append_info['tree'])]
        seqback.sentences = input_embedding.clone().detach()
        seqback.batch_trees = append_info['tree']
        seqback.batch_add_sent = append_info['ae_sent']
        seqback.start = append_info['add_start']
        seqback.end = append_info['add_end']
        seqback.adv_sent = []

        batch_tree_embedding = []
        for bi, append_sent in enumerate(append_info['ae_sent']):
            seqback.target_start = append_info['target_start'][
                0] - append_info['add_start'][0]
            seqback.target_end = append_info['target_end'][0] - append_info[
                'add_start'][0]
            sentences = [
                torch.tensor(append_sent, dtype=torch.long, device=device)
            ]
            seqback.target = sentences[0][seqback.target_start:seqback.target_end + 1]
            trees = [append_info['tree'][bi]]
            tree_embedding = treelstm(sentences, trees)[0][0].detach()
            batch_tree_embedding.append(tree_embedding)
        hidden = torch.cat(batch_tree_embedding, dim=0)
        cw.batch_info = append_info
        cw.num_classes = append_info['tot_length']

        adv_hidden = cw.run(tree_model,
                            hidden, (batch_start_target, batch_end_target),
                            input_token=input_embedding)
        seqback.adv_sent = []

        # re-test
        for bi, (add_start,
                 add_end) in enumerate(zip(batch_add_start, batch_add_end)):
            if bi in cw.o_best_sent:
                ae_words = cw.o_best_sent[bi]
                bidaf_tokens = bidaf_convert_to_idx(ae_words)
                batch.c_word[0].data[bi, add_start:add_end] = torch.LongTensor(
                    bidaf_tokens)
        p1, p2 = model(batch)
        adv_answer, adv_s_idx, adv_e_idx = write_to_ans(
            p1, p2, batch, adv_answers)
        batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)
        adv_loss += batch_loss.item()

        for bi, (start_target, end_target) in enumerate(
                zip(batch_start_target, batch_end_target)):
            start_output = adv_s_idx
            end_output = adv_e_idx
            targeted_success += int(
                compare(start_output, start_target.item(), end_output,
                        end_target.item()))
            untargeted_success += int(
                compare_untargeted(start_output, start_target.item(),
                                   end_output, end_target.item()))

        for i in range(len(add_sents)):
            logger.info(("orig:", transform(add_sents[i])))
            try:
                logger.info(("adv:", cw.o_best_sent[i]))
                adv_text.append({
                    'adv_text': cw.o_best_sent[i],
                    'qas_id': batch.id[i],
                    'adv_predict': (orig_s_idx, orig_e_idx),
                    'orig_predict': (adv_s_idx, adv_e_idx),
                    'Orig answer:': orig_answer,
                    'Adv answer:': adv_answer
                })
                joblib.dump(adv_text, root_dir + '/adv_text.pkl')
            except Exception:
                adv_text.append({
                    'adv_text': transform(add_sents[i]),
                    'qas_id': batch.id[i],
                    'adv_predict': (orig_s_idx, orig_e_idx),
                    'orig_predict': (adv_s_idx, adv_e_idx),
                    'Orig answer:': orig_answer,
                    'Adv answer:': adv_answer
                })
                joblib.dump(adv_text, root_dir + '/adv_text.pkl')
                continue
        # for batch size = 1
        tot += 1
        logger.info(("orig predict", (orig_s_idx, orig_e_idx)))
        logger.info(("adv append predict", (adv_s_idx, adv_e_idx)))
        logger.info(("targeted success rate:", targeted_success))
        logger.info(("untargeted success rate:", untargeted_success))
        logger.info(("Orig answer:", orig_answer))
        logger.info(("Adv answer:", adv_answer))
        logger.info(("tot:", tot))

    for name, param in model.named_parameters():
        if param.requires_grad:
            param.data.copy_(backup_params.get(name))

    with open(options.prediction_file, 'w', encoding='utf-8') as f:
        print(json.dumps(answers), file=f)
    with open(options.prediction_file + '_adv.json', 'w',
              encoding='utf-8') as f:
        print(json.dumps(adv_answers), file=f)
    results = evaluate.main(options)
    logger.info(tot)
    logger.info(("adv loss, results['exact_match'], results['f1']", loss,
                 results['exact_match'], results['f1']))
    return loss, results['exact_match'], results['f1']
Example 7
def cw_random_word_attack():
    cw = CarliniL2_untargeted_qa(debug=args.debugging)
    criterion = nn.CrossEntropyLoss()
    loss = 0
    adv_loss = 0
    targeted_success = 0
    untargeted_success = 0
    adv_text = []
    answers = dict()
    adv_answers = dict()

    backup_params = EMA(0)
    for name, param in model.named_parameters():
        if param.requires_grad:
            backup_params.register(name, param.data)
            param.data.copy_(ema.get(name))
    tot = 0
    for batch in tqdm(iter(data.dev_iter), total=1000):
        p1, p2 = model(batch)
        orig_answer, orig_s_idx, orig_e_idx = write_to_ans(
            p1, p2, batch, answers)
        batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)
        loss += batch_loss.item()

        append_info = append_random_input(batch)
        allow_idxs = append_info['allow_idx']
        batch_start_target = torch.LongTensor([0]).to(device)
        batch_end_target = torch.LongTensor([0]).to(device)

        input_embedding = model.word_emb(batch.c_word[0])
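        # cw_mask marks (with ones) the only embedding positions the attack may perturb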
        cw_mask = np.zeros(input_embedding.shape).astype(np.float32)
        cw_mask = torch.from_numpy(cw_mask).float().to(device)

        for bi, allow_idx in enumerate(allow_idxs):
            cw_mask[bi, np.array(allow_idx)] = 1
        cw.wv = model.word_emb.weight
        cw.inputs = batch
        cw.mask = cw_mask
        cw.batch_info = append_info
        cw.num_classes = append_info['tot_length']
        # print(transform(to_list(batch.c_word[0][0])))
        cw.run(model, input_embedding, (batch_start_target, batch_end_target))

        # re-test
        for bi, allow_idx in enumerate(allow_idxs):
            if bi in cw.o_best_sent:
                for i, idx in enumerate(allow_idx):
                    batch.c_word[0].data[bi, idx] = cw.o_best_sent[bi][i]
        p1, p2 = model(batch)
        adv_answer, adv_s_idx, adv_e_idx = write_to_ans(
            p1, p2, batch, adv_answers)
        batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)
        adv_loss += batch_loss.item()

        for bi, (start_target, end_target) in enumerate(
                zip(batch_start_target, batch_end_target)):
            start_output = adv_s_idx
            end_output = adv_e_idx
            targeted_success += int(
                compare(start_output, start_target.item(), end_output,
                        end_target.item()))
            untargeted_success += int(
                compare_untargeted(start_output, start_target.item(),
                                   end_output, end_target.item()))
        for i in range(len(allow_idxs)):
            try:
                logger.info(("adv:", transform(cw.o_best_sent[i])))
                adv_text.append({
                    'added_text': transform(cw.o_best_sent[i]),
                    'adv_text': transform(to_list(batch.c_word[0][0])),
                    'qas_id': batch.id[i],
                    'adv_predict': (orig_s_idx, orig_e_idx),
                    'orig_predict': (adv_s_idx, adv_e_idx),
                    'Orig answer:': orig_answer,
                    'Adv answer:': adv_answer
                })
                joblib.dump(adv_text, root_dir + '/adv_text.pkl')
            except Exception:
                adv_text.append({
                    'adv_text': transform(to_list(batch.c_word[0][0])),
                    'qas_id': batch.id[i],
                    'adv_predict': (orig_s_idx, orig_e_idx),
                    'orig_predict': (adv_s_idx, adv_e_idx),
                    'Orig answer:': orig_answer,
                    'Adv answer:': adv_answer
                })
                joblib.dump(adv_text, root_dir + '/adv_text.pkl')
                continue
        # for batch size = 1
        tot += 1
        logger.info(("orig predict", (orig_s_idx, orig_e_idx)))
        logger.info(("adv append predict", (adv_s_idx, adv_e_idx)))
        logger.info(("targeted success rate:", targeted_success))
        logger.info(("untargeted success rate:", untargeted_success))
        logger.info(("Orig answer:", orig_answer))
        logger.info(("Adv answer:", adv_answer))
        logger.info(("tot:", tot))

    for name, param in model.named_parameters():
        if param.requires_grad:
            param.data.copy_(backup_params.get(name))

    with open(options.prediction_file, 'w', encoding='utf-8') as f:
        print(json.dumps(answers), file=f)
    with open(options.prediction_file + '_adv.json', 'w',
              encoding='utf-8') as f:
        print(json.dumps(adv_answers), file=f)
    results = evaluate.main(options)
    logger.info(tot)
    logger.info(("adv loss, results['exact_match'], results['f1']", loss,
                 results['exact_match'], results['f1']))
    return loss, results['exact_match'], results['f1']
Example 8
    answer_append_sentences = joblib.load(
        'sampled_perturb_answer_sentences.pkl')
    question_append_sentences = joblib.load(
        'sampled_perturb_question_sentences.pkl')

    model = BiDAF(options, data.WORD.vocab.vectors).to(device)
    if options.old_model is not None:
        model.load_state_dict(
            torch.load(options.old_model,
                       map_location="cuda:{}".format(options.gpu)))
    if options.old_ema is not None:
        # ema = pickle.load(open(options.old_ema, "rb"))
        ema = torch.load(options.old_ema, map_location=device)
    else:
        ema = EMA(options.exp_decay_rate)
        for name, param in model.named_parameters():
            if param.requires_grad:
                ema.register(name, param.data)

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        if not args.cuda:
            print(
                "WARNING: You have a CUDA device, so you should probably run with --cuda"
            )
        else:
            torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)

    if args.model == 'word_attack':
Example 9
def train(args):
    db = Data(args)
    # db.build_vocab()  # each run of build_vocab may assign different ids to words with the same frequency
    db.load_vocab()
    db.build_dataset()  # builds the train_loader

    model = BiDAF(args)
    if args.cuda:
        model = model.cuda()
    if args.ema:
        ema = EMA(0.999)
        print("Register EMA ...")
        for name, param in model.named_parameters():
            if param.requires_grad:
                ema.register(name, param.data)
    init_lr = args.init_lr
    optimizer = torch.optim.Adam(params=model.parameters(), lr=init_lr)
    lr = init_lr

    batch_step = args.batch_step
    loss_fn = nn.CrossEntropyLoss()
    logger = Logger('./logs')
    step = 0

    valid_raw_article_list = db.valid_raw_article_list
    valid_answer_list = db.valid_answer_list

    print('========== Train ==============')

    for epoch in range(args.epoch_num):
        print('---Epoch', epoch, "lr:", lr)
        running_loss = 0.0
        count = 0
        print("len(db.train_loader):", len(db.train_loader))
        for article, question, answer_span, _ in db.train_loader:
            if args.cuda:
                article, question, answer_span = article.cuda(), question.cuda(
                ), answer_span.cuda()
            p1, p2 = model(article, question)
            loss_p1 = loss_fn(p1, answer_span.transpose(0, 1)[0])
            loss_p2 = loss_fn(p2, answer_span.transpose(0, 1)[1])
            running_loss += loss_p1.item()
            running_loss += loss_p2.item()

            optimizer.zero_grad()
            (loss_p1 + loss_p2).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
            optimizer.step()
            if args.ema:
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        param.data = ema(name, param.data)

            count += 1
            if count % batch_step == 0:
                rep_str = '[{}] Epoch {}, loss: {:.3f}'
                print(
                    rep_str.format(
                        datetime.datetime.now().strftime('%Y%m%d-%H%M%S'),
                        epoch, running_loss / batch_step))

                info = {'loss': running_loss / batch_step}
                running_loss = 0.0
                count = 0

                # 1. Log scalar values (scalar summary)
                for tag, value in info.items():
                    logger.scalar_summary(tag, value, step + 1)

                # 2. Log values and gradients of the parameters (histogram summary)
                for tag, value in model.named_parameters():
                    tag = tag.replace('.', '/')
                    logger.histo_summary(tag,
                                         value.data.cpu().numpy(), step + 1)
                    logger.histo_summary(tag + '/grad',
                                         value.grad.data.cpu().numpy(),
                                         step + 1)
                step += 1

        # validation set
        if args.with_valid:
            print('======== Epoch {} result ========'.format(epoch))
            print("len(db.valid_loader):", len(db.valid_loader))
            valid_result = []
            idx = 0
            for article, question, _ in db.valid_loader:
                if args.cuda:
                    article, question = article.cuda(), question.cuda()
                p1, p2 = model(article, question, is_trainning=False)

                _, p1_predicted = torch.max(p1.cpu().data, 1)
                _, p2_predicted = torch.max(p2.cpu().data, 1)
                p1_predicted = p1_predicted.numpy().tolist()
                p2_predicted = p2_predicted.numpy().tolist()
                for _p1, _p2, _raw_article, _answer in zip(
                        p1_predicted, p2_predicted,
                        valid_raw_article_list[idx:idx + len(p1_predicted)],
                        valid_answer_list[idx:idx + len(p1_predicted)]):
                    valid_result.append({
                        "ref_answer": _answer,
                        "cand_answer": "".join(_raw_article[_p1:_p2 + 1])
                    })
                idx = idx + len(p1_predicted)
            rouge_score = test_score(valid_result)
            info = {'rouge_score': rouge_score}

            for tag, value in info.items():
                logger.scalar_summary(tag, value, epoch + 1)

        lr = max(0.00001, init_lr * 0.9**(epoch + 1))
        print("lr:", lr)
        parameters = filter(lambda param: param.requires_grad,
                            model.parameters())
        optimizer = torch.optim.Adam(params=parameters,
                                     lr=lr,
                                     weight_decay=1e-7)

        # print(len(db.valid_loader))
        if epoch >= 1 and args.saved_model_file:
            torch.save(model.state_dict(),
                       args.saved_model_file + "_epoch_" + str(epoch))
            print("saved model")
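
This training loop (and the SLQA one in Example 11) applies the EMA by calling ema(name, param.data) and writing the result straight back into the parameter, so the EMA class assumed here is a callable one, roughly (a sketch, not the repository's actual code):

class EMA:
    def __init__(self, mu):
        self.mu = mu
        self.shadow = {}

    def register(self, name, val):
        self.shadow[name] = val.clone()

    def __call__(self, name, x):
        # fold x into the running average and return it; the caller assigns the result
        # back to param.data, so the live weights themselves are smoothed every step
        new_average = (1.0 - self.mu) * x + self.mu * self.shadow[name]
        self.shadow[name] = new_average.clone()
        return new_average

Note that, unlike the register/update/get variant in the BiDAF examples, this writes the averaged weights back into the live parameters after every optimizer step.
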
Example 10
def train(args, data):
    device = torch.device(
        "cuda:{}".format(args.gpu) if torch.cuda.is_available() else "cpu")
    model = BiDAF(args, data.WORD.vocab.vectors).to(device)

    ema = EMA(args.exp_decay_rate)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adadelta(parameters, lr=args.learning_rate)
    criterion = nn.CrossEntropyLoss()

    writer = SummaryWriter(log_dir='runs/' + args.model_time)

    model.train()
    loss, last_epoch = 0, -1
    max_dev_exact, max_dev_f1 = -1, -1

    iterator = data.train_iter
    for i, batch in enumerate(iterator):
        present_epoch = int(iterator.epoch)
        if present_epoch == args.epoch:
            break
        if present_epoch > last_epoch:
            print('epoch:', present_epoch + 1)
        last_epoch = present_epoch

        p1, p2 = model(batch)

        optimizer.zero_grad()
        batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)
        loss += batch_loss.item()
        batch_loss.backward()
        optimizer.step()

        for name, param in model.named_parameters():
            if param.requires_grad:
                ema.update(name, param.data)

        if (i + 1) % args.print_freq == 0:
            dev_loss, dev_exact, dev_f1 = test(model, ema, args, data)
            c = (i + 1) // args.print_freq

            writer.add_scalar('loss/train', loss, c)
            writer.add_scalar('loss/dev', dev_loss, c)
            writer.add_scalar('exact_match/dev', dev_exact, c)
            writer.add_scalar('f1/dev', dev_f1, c)
            print('train loss: {} / dev loss: {}'.format(loss, dev_loss) +
                  ' / dev EM: {} / dev F1: {}'.format(dev_exact, dev_f1))

            if dev_f1 > max_dev_f1:
                max_dev_f1 = dev_f1
                max_dev_exact = dev_exact
                best_model = copy.deepcopy(model)

            loss = 0
            model.train()

    writer.close()
    print('max dev EM: {} / max dev F1: {}'.format(max_dev_exact, max_dev_f1))

    return best_model
Example 11
def train(args):
    db = Data(args)
    # db.build_vocab()  # each run of build_vocab may assign different ids to words with the same frequency
    db.load_vocab()
    db.build_dataset()  # builds the train_loader

    # model = BiDAF(args)
    model = SLQA(args)
    first_model = "./checkpoints/SLQA_elmo_epoch_0"
    model.load_state_dict(torch.load(first_model))
    if args.cuda:
        model = model.cuda()
    if args.ema:
        ema = EMA(0.999)
        print("Register EMA ...")
        for name, param in model.named_parameters():
            if param.requires_grad:
                ema.register(name, param.data)
    init_lr = args.init_lr
    parameters = filter(lambda param: param.requires_grad, model.parameters())
    weight_decay = 1e-6
    weight_decay = 0
    optimizer = torch.optim.Adam(params=parameters,
                                 lr=init_lr,
                                 weight_decay=weight_decay)
    batch_step = args.batch_step
    loss_fn = nn.CrossEntropyLoss()
    logger = Logger('./logs')
    step = 0

    train_raw_article_list = db.train_raw_article_list
    train_raw_question_list = db.train_raw_question_list

    valid_raw_article_list = db.valid_raw_article_list
    valid_answer_list = db.valid_answer_list
    valid_raw_question_list = db.valid_raw_question_list

    question_hdf5_f = h5py.File(args.question_hdf5_path, "r")
    article_hdf5_f = h5py.File(args.article_hdf5_path, "r")
    print('========== Train ==============')
    for epoch in range(args.epoch_num):
        print('---Epoch', epoch)
        running_loss = 0.0
        count = 0
        print("len(db.train_loader):", len(db.train_loader))
        train_idx = 0
        for batch_id, (article, question, answer_span,
                       _) in enumerate(db.train_loader):
            if args.cuda:
                article, question, answer_span = article.cuda(), question.cuda(
                ), answer_span.cuda()
            # tmp_train_raw_article_list = train_raw_article_list[train_idx:train_idx + question.size()[0]]
            # tmp_train_raw_question_list = train_raw_question_list[train_idx:train_idx + question.size()[0]]
            # question_elmo = gen_elmo_by_text(question_hdf5_f, tmp_train_raw_question_list, args.max_question_len)
            # article_elmo = gen_elmo_by_text(article_hdf5_f, tmp_train_raw_article_list, args.max_article_len)
            # pickle.dump((article_elmo, question_elmo), open(elmo_save_path, "wb"))
            elmo_save_path = "/backup231/lhliu/jszn/elmo/" + str(
                batch_id) + ".pkl"
            article_elmo, question_elmo = pickle.load(
                open(elmo_save_path, "rb"))
            # print(elmo_save_path)
            article_elmo = torch.tensor(article_elmo, dtype=torch.float)
            question_elmo = torch.tensor(question_elmo, dtype=torch.float)
            # train_idx += question.size()[0]
            # continue
            if args.cuda:
                question_elmo = question_elmo.cuda()
                article_elmo = article_elmo.cuda()

            p1, p2 = model(article,
                           question,
                           article_elmo=article_elmo,
                           question_elmo=question_elmo)
            loss_p1 = loss_fn(p1, answer_span.transpose(0, 1)[0])
            loss_p2 = loss_fn(p2, answer_span.transpose(0, 1)[1])
            running_loss += loss_p1.item()
            running_loss += loss_p2.item()

            optimizer.zero_grad()
            (loss_p1 + loss_p2).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
            optimizer.step()
            if args.ema:
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        param.data = ema(name, param.data)

            count += 1
            if count % batch_step == 0:
                rep_str = '[{}] Epoch {}, loss: {:.3f}'
                print(
                    rep_str.format(
                        datetime.datetime.now().strftime('%Y%m%d-%H%M%S'),
                        epoch, running_loss / batch_step))

                # info = {'loss': running_loss / batch_step}
                running_loss = 0.0
                count = 0

                # # 1. Log scalar values (scalar summary)
                # for tag, value in info.items():
                #     logger.scalar_summary(tag, value, step + 1)

                # # 2. Log values and gradients of the parameters (histogram summary)
                # for tag, value in model.named_parameters():
                #     tag = tag.replace('.', '/')
                #     logger.histo_summary(tag, value.data.cpu().numpy(), step + 1)
                #     logger.histo_summary(tag + '/grad', value.grad.data.cpu().numpy(), step + 1)
                step += 1
            # break
        # validation set
        if args.with_valid:
            print('======== Epoch {} result ========'.format(epoch))
            print("len(db.valid_loader):", len(db.valid_loader))
            valid_result = []
            idx = 0
            for article, question, _ in db.valid_loader:
                if args.cuda:
                    article, question = article.cuda(), question.cuda()

                tmp_valid_raw_article_list = valid_raw_article_list[idx:idx + question.size()[0]]
                tmp_valid_raw_question_list = valid_raw_question_list[idx:idx + question.size()[0]]
                question_elmo = gen_elmo_by_text(question_hdf5_f,
                                                 tmp_valid_raw_question_list,
                                                 args.max_question_len)
                article_elmo = gen_elmo_by_text(article_hdf5_f,
                                                tmp_valid_raw_article_list,
                                                args.max_article_len)
                if args.cuda:
                    question_elmo = question_elmo.cuda()
                    article_elmo = article_elmo.cuda()
                p1, p2 = model(article,
                               question,
                               article_elmo,
                               question_elmo,
                               is_training=False)

                _, p1_predicted = torch.max(p1.cpu().data, 1)
                _, p2_predicted = torch.max(p2.cpu().data, 1)
                p1_predicted = p1_predicted.numpy().tolist()
                p2_predicted = p2_predicted.numpy().tolist()
                assert question.size()[0] == len(p1_predicted)
                for _p1, _p2, _raw_article, _answer in zip(
                        p1_predicted, p2_predicted,
                        valid_raw_article_list[idx:idx + len(p1_predicted)],
                        valid_answer_list[idx:idx + len(p1_predicted)]):
                    valid_result.append({
                        "ref_answer": _answer,
                        "cand_answer": "".join(_raw_article[_p1:_p2 + 1])
                    })
                idx = idx + len(p1_predicted)
            rouge_score = test_score(valid_result)
            info = {'rouge_score': rouge_score}

            for tag, value in info.items():
                logger.scalar_summary(tag, value, epoch + 1)
        # lr = init_lr
        lr = max(0.00001, init_lr * 0.9**(epoch + 1))  # consider whether to use this decay
        print("lr:", lr)
        parameters = filter(lambda param: param.requires_grad,
                            model.parameters())
        optimizer = torch.optim.Adam(params=parameters,
                                     lr=lr,
                                     weight_decay=weight_decay)

        # print(len(db.valid_loader))
        if epoch >= 0 and args.saved_model_file:
            torch.save(model.state_dict(),
                       args.saved_model_file + "_epoch_" + str(epoch))
            print("saved model")
Example 12
def train(args, data):
    model = BiDAF(args, data.WORD.vocab.vectors)
    if args.load_model != "":
        model.load_state_dict(torch.load(args.load_model))
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    ema = EMA(args.exp_decay_rate)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)
    for name, i in model.named_parameters():
        if not i.is_leaf:
            print(name,i)

    writer = SummaryWriter(log_dir='runs/' + args.model_name)
    best_model = None

    for iterator, dev_iter, dev_file_name, index, print_freq, lr in zip(
            data.train_iter, data.dev_iter, args.dev_files, range(len(data.train)),
            args.print_freq, args.learning_rate):
        optimizer = optim.Adadelta(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        model.train()
        loss, last_epoch = 0, 0
        max_dev_exact, max_dev_f1 = -1, -1
        print(f"Training with {dev_file_name}")
        print()
        for i, batch in tqdm(enumerate(iterator), total=len(iterator) * args.epoch[index], ncols=100):
            present_epoch = int(iterator.epoch)
            eva = False
            if present_epoch == args.epoch[index]:
                break
            if present_epoch > last_epoch:
                print('epoch:', present_epoch + 1)
                eva = True
            last_epoch = present_epoch

            p1, p2 = model(batch)

            optimizer.zero_grad()
            batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)
            loss += batch_loss.item()
            batch_loss.backward()
            optimizer.step()

            for name, param in model.named_parameters():
                if param.requires_grad:
                    ema.update(name, param.data)

            torch.cuda.empty_cache()
            if (i + 1) % print_freq == 0 or eva:
                dev_loss, dev_exact, dev_f1 = test(model, ema, args, data, dev_iter, dev_file_name)
                c = (i + 1) // print_freq

                writer.add_scalar('loss/train', loss, c)
                writer.add_scalar('loss/dev', dev_loss, c)
                writer.add_scalar('exact_match/dev', dev_exact, c)
                writer.add_scalar('f1/dev', dev_f1, c)
                print()
                print(f'train loss: {loss:.3f} / dev loss: {dev_loss:.3f}'
                      f' / dev EM: {dev_exact:.3f} / dev F1: {dev_f1:.3f}')

                if dev_f1 > max_dev_f1:
                    max_dev_f1 = dev_f1
                    max_dev_exact = dev_exact
                    best_model = copy.deepcopy(model)

                loss = 0
                model.train()

    writer.close()
    print(f'max dev EM: {max_dev_exact:.3f} / max dev F1: {max_dev_f1:.3f}')
    print("testing with test batch on best model")
    test_loss, test_exact, test_f1 = test(best_model, ema, args, data, list(data.test_iter)[-1], args.test_files[-1])

    print(f'test loss: {test_loss:.3f}'
          f' / test EM: {test_exact:.3f} / test F1: {test_f1:.3f}')
    return best_model
Example 13
def train(args, data):
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    model = BiDAF(args, data.CONTEXT_WORD.vocab.vectors).to(device)
    
    num = count_parameters(model)
    print(f'parameter count: {num}')

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    ema = EMA(args.exp_decay_rate)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adadelta(parameters, lr=args.learning_rate)
    criterion = nn.CrossEntropyLoss()

    writer = SummaryWriter(log_dir='runs/' + args.model_time)

    model.train()
    loss, last_epoch = 0, -1
    max_dev_exact, max_dev_f1 = -1, -1
    print('totally {} epoch'.format(args.epoch))
    
    sys.stdout.flush()
    iterator = data.train_iter
    iterator.repeat = True
    for i, batch in enumerate(iterator):

        present_epoch = int(iterator.epoch)
        if present_epoch == args.epoch:
            print('present_epoch value:',present_epoch)
            break
        if present_epoch > last_epoch:
            print('epoch:', present_epoch + 1)
        last_epoch = present_epoch

        p1, p2 = model(batch.c_char, batch.q_char,
                       batch.c_word[0], batch.q_word[0],
                       batch.c_word[1], batch.q_word[1])
        optimizer.zero_grad()
        batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)
        loss += batch_loss.item()
        batch_loss.backward()
        optimizer.step()

        for name, param in model.named_parameters():
            if param.requires_grad:
                ema.update(name, param.data)

        if (i + 1) % args.print_freq == 0:
            dev_loss, dev_exact, dev_f1, dev_hasans_exact, dev_hasans_f1, dev_noans_exact,dev_noans_f1 = test(model, ema, args, data)
            c = (i + 1) // args.print_freq

            writer.add_scalar('loss/train', loss, c)
            writer.add_scalar('loss/dev', dev_loss, c)
            writer.add_scalar('exact_match/dev', dev_exact, c)
            writer.add_scalar('f1/dev', dev_f1, c)
            print(f'train loss: {loss:.3f} / dev loss: {dev_loss:.3f}'
                  f' / dev EM: {dev_exact:.3f} / dev F1: {dev_f1:.3f}'
                  f' / dev hasans EM: {dev_hasans_exact} / dev hasans F1: {dev_hasans_f1}'
                  f' / dev noans EM: {dev_noans_exact} / dev noans F1: {dev_noans_f1}')

            if dev_f1 > max_dev_f1:
                max_dev_f1 = dev_f1
                max_dev_exact = dev_exact
                best_model = copy.deepcopy(model)

            loss = 0
            model.train() 
        sys.stdout.flush()
    writer.close()
    args.max_f1 = max_dev_f1
    print(f'max dev EM: {max_dev_exact:.3f} / max dev F1: {max_dev_f1:.3f}')

    return best_model
Example 14
def train(args, data):
    device = torch.device(
        f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    model = BiDAF(args, data.WORD.vocab.vectors).to(device)

    ema = EMA(args.exp_decay_rate)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adadelta(parameters, lr=args.learning_rate)
    criterion = nn.CrossEntropyLoss()

    writer = SummaryWriter(log_dir='runs/' + args.model_time)

    model.train()
    loss, last_epoch = 0, -1
    max_dev_exact, max_dev_f1 = -1, -1

    iterator = data.train_iter
    num_batch = len(iterator)
    for present_epoch in range(args.epoch):
        print('epoch', present_epoch + 1)
        for i, batch in enumerate(iterator):

            p1, p2 = model(batch)

            optimizer.zero_grad()

            if len(p1.size()) == 1:
                p1 = p1.reshape(1, -1)
            if len(p2.size()) == 1:
                p2 = p2.reshape(1, -1)
            batch_loss = criterion(p1, batch.s_idx) + criterion(
                p2, batch.e_idx)
            loss += batch_loss.item()
            batch_loss.backward()
            optimizer.step()

            for name, param in model.named_parameters():
                if param.requires_grad:
                    ema.update(name, param.data)

            best_model = copy.deepcopy(model)

            if i + 1 == num_batch:
                dev_loss, dev_exact, dev_f1 = test(model, ema, args, data)
                c = (i + 1) // args.print_freq

                writer.add_scalar('loss/train', loss / num_batch, c)
                writer.add_scalar('loss/dev', dev_loss, c)
                writer.add_scalar('exact_match/dev', dev_exact, c)
                writer.add_scalar('f1/dev', dev_f1, c)
                print(
                    f'train loss: {loss/num_batch:.3f} / dev loss: {dev_loss:.3f}'
                    f' / dev EM: {dev_exact:.3f} / dev F1: {dev_f1:.3f}')

                if dev_f1 > max_dev_f1:
                    max_dev_f1 = dev_f1
                    max_dev_exact = dev_exact
                    best_model = copy.deepcopy(model)

                loss = 0
                model.train()

    writer.close()
    print(f'max dev EM: {max_dev_exact:.3f} / max dev F1: {max_dev_f1:.3f}')

    return best_model
Example 15
class SemCo:

    def __init__(self, config, dataset_meta, device, L='dynamic', device_ids=None):
        self.config = config
        self.dataset_meta = dataset_meta
        if 'stats' not in dataset_meta:
            self.dataset_meta['stats'] = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)  # imagenet_stats
        self.device = device
        self.parallel = config.parallel
        self.device_ids = device_ids
        self.L = L
        self.label_emb_guessor, self.emb_dim = self._get_label_guessor()
        self.model = self._set_model(config.parallel, device_ids)
        self.optim = self._get_optimiser(self.config)
        if not self.config.no_amp:
            from apex import amp
            self.model, self.optim = amp.initialize(self.model, self.optim, opt_level="O1")
        if self.config.parallel:
            self.model = nn.DataParallel(self.model)
        # initialise the exponential moving average model
        self.ema = EMA(self.model, self.config.ema_alpha)
        if self.config.use_pretrained:
            self.load_model_state(config.checkpoint_path)
        if self.config.freeze_backbone:
            self._freeze_model_backbone()
        self.logger, self.writer, self.time_stamp = self._setup_default_logging()

    def train(self, labelled_data, valid_data=None, training_config=None, save_best_model=False):
        """
        SemCo training function.
        labelled_data: dictionary holding labelled data in the form {'train/img1.png' : 'classA', ...}. This is relative
        to the dataset directory

        valid_data: dictionary holding validation data in the form {'train/img1.png' : 'classA', ...}. This is relative
        to the dataset directory

        training_config: to override parser config entirely if needed.

        save_best_model: if true, the best model (model state, ema state, optimizer state, classes) will be saved under
        './saved_models' directory.

        """
        # to allow overriding training params for different runs of training
        if training_config is None:
            training_config = self.config
        else:
            self.config = training_config

        L = len(labelled_data)
        n_iters_per_epoch, n_iters_all = self._init_training(training_config, L)
        # define criterion for upper(semantic embedding) and lower(discrete label) paths
        crit_lower = lambda inp, targ: F.cross_entropy(inp, targ, reduction='none')
        crit_upper = lambda inp, targ: 1 - F.cosine_similarity(inp, targ)
        optim = self.optim

        if n_iters_all == 0:
            n_iters_all = 1 # to avoid division by zero if we choose to set epochs to zero to skip a round
        lr_schdlr = WarmupCosineLrScheduler(optim, max_iter=n_iters_all, warmup_iter=0)

        num_workers = 0 if self.device == 'cpu' else training_config.num_workers_per_gpu * torch.cuda.device_count() if self.parallel else 4

        dltrain_x, dltrain_u = self._get_train_loaders(labelled_data, n_iters_per_epoch, num_workers, pin_memory=True,
                                                       cache_imgs=training_config.cache_imgs)
        print(f'Num of Labeled Training Data: {len(dltrain_x.dataset)}\nNum of Unlabeled Training Data:{len(dltrain_u.dataset)}')
        if valid_data:
            dlvalid = self._get_val_loader(valid_data, num_workers, pin_memory=True, cache_imgs=training_config.cache_imgs)
            print(f'Num of Validation Data: {len(dlvalid.dataset)}')

        train_args = dict(n_iters=n_iters_per_epoch, optim=optim, crit_lower=crit_lower,
                          crit_upper=crit_upper, lr_schdlr=lr_schdlr, dltrain_x=dltrain_x,
                          dltrain_u=dltrain_u)
        best_acc = -1
        best_epoch = 0
        best_loss = 1e6
        early_stopping_counter = 0
        best_metric = best_acc if training_config.es_metric == 'accuracy' else best_loss


        self.logger.info('-----------start training--------------')
        epochs_iterator = range(training_config.n_epoches) if not self.config.no_progress_bar else \
            tqdm(range(training_config.n_epoches),desc='Epoch')  # so that it displays the bar per epoch not per iteration
        for epoch in epochs_iterator:
            # training starts here
            train_loss, loss_x, loss_u, mask_mean, \
            loss_emb_x, loss_emb_u, mask_emb, mask_combined = \
                self._train_one_epoch(epoch, **train_args)
            if valid_data:
                top1, top5, valid_loss, top1_emb, top5_emb, top1_combined = self._evaluate(dlvalid, crit_lower)

            if valid_data:
                self.writer.add_scalars('train/1.loss', {'train': train_loss,
                                                         'test': valid_loss}, epoch)

            else:
                self.writer.add_scalar('train/1.loss', train_loss, epoch)
            self.writer.add_scalar('train/2.train_loss_x', loss_x, epoch)
            self.writer.add_scalar('train/2.train_loss_emb_x', loss_emb_x, epoch)
            self.writer.add_scalar('train/3.train_loss_u', loss_u, epoch)
            self.writer.add_scalar('train/3.train_loss_emb_u', loss_emb_u, epoch)
            self.writer.add_scalar('train/5.mask_mean', mask_mean, epoch)
            self.writer.add_scalar('train/5.mask_emb_mean', mask_emb, epoch)
            self.writer.add_scalar('train/5.mask_combined_mean', mask_combined, epoch)
            if valid_data:
                self.writer.add_scalars('test/1.test_acc', {'top1': top1, 'top5': top5, 'top1_emb': top1_emb,
                                                            'top5_emb': top5_emb, 'top1_combined': top1_combined},
                                        epoch)

                best_current = top1 if training_config.es_metric == 'accuracy' else valid_loss
                # only start looking for best model after min_wait period has expired
                if epoch >= training_config.min_wait_before_es:
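                    # despite its name, isworse(best, current) returns True when `current` is at
                    # least as good as `best` (higher accuracy or lower loss), i.e. when the
                    # best-so-far bookkeeping below should run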
                    isworse = lambda best,current: best <= current if training_config.es_metric == 'accuracy' else best >= current
                    if isworse(best_metric, best_current):
                        best_metric = best_current
                        best_epoch = epoch
                        if training_config.early_stopping_epochs:
                            best_model_state = self.model.state_dict()
                            best_ema_state = {k:v.clone().detach() for k,v in self.ema.shadow.items()}
                            early_stopping_counter = 0
                        if save_best_model:
                            try:
                                self._save_checkpoint()
                            except Exception as e:
                                print(f'Failed to save checkpoint: {e}')
                    elif training_config.early_stopping_epochs:
                        early_stopping_counter += 1
                else:
                    print('Minimum wait period still not expired. Leaving best epoch and best metric to default values')

                self.logger.info(
                    "Epoch {}. Top1: {:.4f}. Top5: {:.4f}. Top1_emb: {:.4f}. Top5_emb: {:.4f}. Top1_comb: {:.4f}. best_metric: {:.4f} in epoch {}".
                        format(epoch, top1, top5, top1_emb, top5_emb, top1_combined, best_metric, best_epoch))
                # check if early stopping is to be activated
                if training_config.early_stopping_epochs and early_stopping_counter == training_config.early_stopping_epochs:
                    self.logger.info(f"Early stopping activated, loading best models and ending training. "
                                     f"{training_config.early_stopping_epochs} epochs with no improvement.")
                    self.model.load_state_dict(best_model_state)
                    self.ema.shadow = best_ema_state
                    break
            # this will only be activated in last epoch to decide whether best model should be loaded or not before ending training
            if epoch == training_config.n_epoches-1:
                self.logger.info(f"Break epoch is reached")
                # in case early stopping is configured, load best model before exiting (edge case for early stopping)
                if training_config.early_stopping_epochs and valid_data and epoch >= training_config.min_wait_before_es:
                    self.logger.info(f"Loading best model and ending training (since early stopping is set)")
                    self.model.load_state_dict(best_model_state)
                    self.ema.shadow = best_ema_state
        self.writer.close()

    def predict(self):
        num_work = 0 if self.device == 'cpu' else 4
        dataloader = self._get_test_loader(num_work, pin_memory=True, cache_imgs=self.config.cache_imgs)
        # using EMA params to evaluate performance
        self.ema.apply_shadow()
        self.ema.model.eval()
        self.ema.model.to(self.device)

        predictions = []
        with torch.no_grad():
            for ims in dataloader:
                ims = ims.to(self.device)
                logits, _, _ = self.ema.model(ims)
                probs = torch.softmax(logits, dim=1)
                scores, lbs_guess = torch.max(probs, dim=1)
                predictions.append(lbs_guess)
            predictions = torch.cat(predictions).cpu().detach().numpy()

        predictions = [self.dataset_meta['classes'][elem] for elem in predictions]
        filenames = [name.split('/')[-1] for name in dataloader.dataset.data]
        df = pd.DataFrame({'id': filenames, 'class': predictions})

        # roll the model's live params back so training can continue
        self.ema.restore()

        return df

    def load_model_state(self, chkpt_dict_path):
        '''
        Loads model state based on a checkpoint saved by SemCo _save_checkpoint() function.
        '''
        print("Loading Model State")
        checkpoint_dict = torch.load(chkpt_dict_path, map_location=self.device)
        if 'model_state_dict' in checkpoint_dict:
            state_dict = checkpoint_dict['model_state_dict']
        else:
            print("'model_state_dict' key not found in checkpoint; failed to load pretrained weights, keeping the original model initialization")
            return

        # handle state dicts whose keys carry a 'module.' prefix (added when the model was wrapped in nn.DataParallel)
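        # e.g. a checkpoint saved from an nn.DataParallel model stores 'module.layer1.conv1.weight',
        # which must become 'layer1.conv1.weight' before it can be loaded into an unwrapped model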
        if all(['module' in key for key in state_dict.keys()]):
            if all(['module' in key for key in self.model.state_dict()]):
                pass
            else:
                state_dict= {k.replace('module.',''):v for k,v in state_dict.items()}
                if 'ema_shadow' in checkpoint_dict:
                    checkpoint_dict['ema_shadow'] = {k.replace('module.',''):v for k,v in checkpoint_dict['ema_shadow'].items()}

        try:
            self.model.load_state_dict(state_dict)
        except Exception as e:
            print(f'Problem occurred during naive state_dict loading: {e}.\nTrying to only load common params')
            try:
                model_state= self.model.state_dict()
                pretrained_state = {k:v for k,v in state_dict.items() if k in model_state and v.size() == model_state[k].size()}
                unloaded_state = set(state_dict.keys()) - set(pretrained_state.keys())
                model_state.update(pretrained_state)
                self.model.load_state_dict(model_state)
                print(f'Success. The following params in pretrained_state_dict were not loaded: {unloaded_state}')
            except Exception as e:
                print(f'Unable to load model state due to following error. Model will be initialised randomly. \n {e}')
        if 'ema_shadow' in checkpoint_dict:
            try:
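                # re-initialise the EMA from the current model, then overwrite only the shadow entries
                # whose name and shape match the checkpoint (partial, shape-safe loading)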
                self.ema = EMA(self.model, self.config.ema_alpha)
                similar_params = {k:v for k,v in checkpoint_dict['ema_shadow'].items() if k in self.ema.shadow and v.size() == self.ema.shadow[k].size()}
                self.ema.shadow.update(similar_params)
                print(f'EMA shadow has been loaded successfully. {len(similar_params)} out of {len(self.ema.shadow)} params were loaded')
            except Exception as e:
                print(f'Unable to load EMA shadow. EMA will be reinitialised with current model params. {e}')
                self.ema = EMA(self.model, self.config.ema_alpha)
        else:
            print('EMA shadow is not found in checkpoint dictionary. EMA will be reinitialised with current model params.')
            self.ema = EMA(self.model, self.config.ema_alpha)
        try:
            if 'classes' in checkpoint_dict:
                classes = self.dataset_meta['classes']
                classes_model = checkpoint_dict['classes']
                if len(classes_model) == len(classes) and all(classes_model[i] == classes[i] for i in range(len(classes))):
                    print(f'classes matched successfully')
                else:
                    print(
                        "Classes loaded don't match the classes used while training the model, output of softmax can't be trusted")
        except Exception as e:
            print("can't load classes file. Pls check and try again.")
        return

    def adapt(self, num_classes):
        '''
        To allow adapting the model to a different dataset with the same semantic classifier weights
        num_classes: number of classes in the target dataset
        return: None
        '''
        if isinstance(self.model, torch.nn.DataParallel):
            self.model.module.adapt(num_classes)
        else:
            self.model.adapt(num_classes)
        self.model.to(self.device)
        self.ema = EMA(self.model, self.config.ema_alpha)

    def _evaluate(self, dataloader, criterion):

        # using EMA params to evaluate performance
        self.ema.apply_shadow()
        self.ema.model.eval()
        self.ema.model.to(self.device)

        loss_meter = AverageMeter()
        top1_meter = AverageMeter()
        top5_meter = AverageMeter()
        top1_emb_meter = AverageMeter()
        top5_emb_meter = AverageMeter()
        top1_combined_meter = AverageMeter()

        with torch.no_grad():
            for ims, lbs in dataloader:
                ims = ims.to(self.device)
                lbs = lbs.to(self.device)
                logits, logits_emb, _ = self.ema.model(ims)
                sim = F.cosine_similarity(logits_emb.unsqueeze(1), self.label_emb_guessor.embedding_matrix.unsqueeze(0),
                                          dim=-1)
                sim = sim * self.label_emb_guessor.sharpening_factor
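                # the sharpening factor acts like an inverse temperature: it makes the softmax over
                # cosine similarities peakier before it is interpreted as a class distribution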
                loss = criterion(logits, lbs).mean()
                scores_emb = torch.softmax(sim, -1)
                scores = torch.softmax(logits, dim=1)
                top1, top5 = accuracy(scores, lbs, (1, 5))
                top1_emb, top5_emb = accuracy(scores_emb, lbs, (1, 5))
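                # simple two-head ensemble: average the classifier softmax with the embedding-similarity softmax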
                scores_combined = torch.mean(torch.stack([scores_emb, scores]), dim=0)
                top1_combined, _ = accuracy(scores_combined, lbs, (1, 5))
                loss_meter.update(loss.item())
                top1_meter.update(top1.item())
                top5_meter.update(top5.item())
                top1_emb_meter.update(top1_emb.item())
                top5_emb_meter.update(top5_emb.item())
                top1_combined_meter.update(top1_combined.item())

        # note roll back model current params to continue training
        self.ema.restore()
        return top1_meter.avg, top5_meter.avg, loss_meter.avg, top1_emb_meter.avg, top5_emb_meter.avg, top1_combined_meter.avg

    def _set_model(self, parallel, device_ids):
        classes = self.dataset_meta['classes']
        n = len(classes)
        if self.config.model_backbone is not None:
            if self.config.model_backbone == 'wres':
                model = WideResnetWithEmbeddingHead(num_classes=n, k=self.config.wres_k, n=28, emb_dim=self.emb_dim)
            elif self.config.model_backbone == 'resnet18':
                model = ResNet18WithEmbeddingHead(num_classes=n, emb_dim=self.emb_dim,
                                                  pretrained=not self.config.no_imgnet_pretrained)
            elif self.config.model_backbone == 'resnet50':
                model = ResNet50WithEmbeddingHead(num_classes=n, emb_dim=self.emb_dim,
                                                  pretrained=not self.config.no_imgnet_pretrained)
        # if no backbone is passed in args, auto infer based on im size
        elif self.config.im_size <= 64:
            model = WideResnetWithEmbeddingHead(num_classes=n, k=self.config.wres_k, n=28, emb_dim=self.emb_dim)
        else:
            model = ResNet50WithEmbeddingHead(num_classes=n, emb_dim=self.emb_dim,
                                              pretrained=not self.config.no_imgnet_pretrained)

        model.to(self.device)

        return model

    def _freeze_model_backbone(self):
        for name, param in self.model.named_parameters():
            if 'fc_emb' in name or 'fc_classes' in name:
                param.requires_grad = True
                print(f'{name} parameter is unfrozen')
            else:
                param.requires_grad = False
        print('All remaining parameters are frozen.')

    def _train_one_epoch(self, epoch, n_iters, optim, crit_lower, crit_upper,
                         lr_schdlr, dltrain_x, dltrain_u):

        # note: _x denotes supervised (labelled) and _u denotes unsupervised (unlabelled) quantities
        # note: the '_emb' suffix denotes the same quantity computed on the upper (embedding) path

        # Renaming for consistency
        criteria_x = crit_lower
        criteria_u = crit_lower
        criteria_x_emb = crit_upper
        criteria_u_emb = crit_upper
        if not self.config.no_amp:
            from apex import amp

        self.model.train()
        loss_meter = AverageMeter()
        loss_x_meter = AverageMeter()
        loss_u_meter = AverageMeter()
        loss_emb_x_meter = AverageMeter()
        loss_emb_u_meter = AverageMeter()
        # counts how many strongly augmented unlabelled samples actually contribute to the gradient (i.e. clear the confidence threshold)
        n_strong_aug_meter = AverageMeter()
        max_score = AverageMeter()
        max_score_emb = AverageMeter()
        mask_meter = AverageMeter()
        mask_emb_meter = AverageMeter()
        mask_combined_meter = AverageMeter()

        epoch_start = time.time()  # start time
        dl_x, dl_u = iter(dltrain_x), iter(dltrain_u)
        iterator = range(n_iters) if self.config.no_progress_bar else tqdm(range(n_iters), desc='Epoch {}'.format(epoch))
        for it in iterator:
            ims_x_weak, ims_x_strong, lbs_x = next(dl_x)
            ims_u_weak, ims_u_strong = next(dl_u)
            lbs_x = lbs_x.to(self.device)

            bt = ims_x_weak.size(0)
            mu = int(ims_u_weak.size(0) // bt)
            imgs = torch.cat([ims_x_weak, ims_u_weak, ims_u_strong], dim=0).to(self.device)
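            # interleave/de_interleave reshuffle the concatenated labelled + unlabelled batch so that every
            # chunk seen by the network is a representative mix (keeps BatchNorm statistics consistent,
            # as in common FixMatch-style implementations)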
            imgs = interleave(imgs, 2 * mu + 1)
            logits, logits_emb, _ = self.model(imgs)
            del imgs
            logits = de_interleave(logits, 2 * mu + 1)
            logits_x = logits[:bt]
            logits_u_w, logits_u_s = torch.split(logits[bt:], bt * mu)
            del logits

            logits_emb = de_interleave(logits_emb, 2 * mu + 1)
            logits_emb_x = logits_emb[:bt]
            logits_emb_u_w, logits_emb_u_s = torch.split(logits_emb[bt:], bt * mu)
            del logits_emb

            # supervised loss for upper and lower paths
            loss_x = criteria_x(logits_x, lbs_x).mean()
            loss_x_emb = criteria_x_emb(logits_emb_x, self.label_emb_guessor.embedding_matrix[lbs_x]).mean()

            # guessing the labels for upper and lower paths
            with torch.no_grad():
                probs = torch.softmax(logits_u_w, dim=1)
                scores, lbs_u_guess = torch.max(probs, dim=1)
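                # keep a pseudo-label only where the classifier's top probability clears the confidence threshold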
                mask = scores.ge(self.config.thr).float()
                # get label guesses and mask based on embedding predictions (upper path)
                lbs_emb_u_guess, mask_emb, scores_emb, lbs_guess_help = self.label_emb_guessor(logits_emb_u_w)

            # combining the losses via co-training (blind version)
            mask_combined = mask.bool() | mask_emb.bool()
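            # union of the two confidence masks (tracked for logging below): a sample counts as confident
            # if either the classifier head or the embedding head clears its threshold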
            # each loss path will have two components (co-training implementation)
            loss_u = (criteria_u(logits_u_s, lbs_u_guess) * mask).mean() + \
                     (criteria_u(logits_u_s, lbs_guess_help) * mask_emb).mean() * (self.config.lambda_emb) / 3

            loss_u_emb = (criteria_u_emb(logits_emb_u_s, lbs_emb_u_guess) * mask_emb).mean() + \
                         (criteria_u_emb(logits_emb_u_s,
                                         self.label_emb_guessor.embedding_matrix[lbs_u_guess]) * mask).mean()
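            # cross-supervision: each head is trained on its own confident pseudo-labels plus the
            # pseudo-labels of the other head (masked by that head's confidence)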

            loss_lower = loss_x + self.config.lam_u * loss_u
            loss_upper = loss_x_emb + self.config.lam_u * loss_u_emb
            loss = loss_lower + self.config.lambda_emb * loss_upper

            optim.zero_grad()
            if not self.config.no_amp:
                with amp.scale_loss(loss, optim) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optim.step()
            self.ema.update_params()
            lr_schdlr.step()

            loss_meter.update(loss.item())
            loss_x_meter.update(loss_x.item())
            loss_u_meter.update(loss_u.item())
            mask_meter.update(mask.mean().item())
            n_strong_aug_meter.update(mask_emb.sum().item())
            max_score.update(scores.mean().item())
            max_score_emb.update(scores_emb.mean().item())
            loss_emb_x_meter.update(loss_x_emb.item())
            loss_emb_u_meter.update(loss_u_emb.item())
            mask_combined_meter.update(mask_combined.float().mean().item())
            mask_emb_meter.update(mask_emb.mean().item())

            if (it + 1) % 512 == 0:
                t = time.time() - epoch_start

                lr_log = [pg['lr'] for pg in optim.param_groups]
                lr_log = sum(lr_log) / len(lr_log)

                self.logger.info("epoch:{}, iter: {}. loss: {:.4f}. loss_u: {:.4f}. loss_x: {:.4f}. max_score:{:.4f}. "
                                 " Mask:{:.4f} loss_u_emb:{:.4f}. loss_x_emb:{:.4f}. mask_emb:{:.4f}. max_score_emb:{:.4f}. mask_emb_count:{:.4f}. mask_combined:{:.4f}. . LR: {:.4f}. Time: {:.2f}".format(
                    epoch, it + 1, loss_meter.avg, loss_u_meter.avg, loss_x_meter.avg, max_score.avg,
                    mask_meter.avg, loss_emb_u_meter.avg, loss_emb_x_meter.avg, mask_emb_meter.avg, max_score_emb.avg,
                    n_strong_aug_meter.avg, mask_combined_meter.avg, lr_log, t))

                epoch_start = time.time()

        self.ema.update_buffer()
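        # update_params() above folds the new weights into the EMA after every iteration; update_buffer()
        # is expected to sync the non-trainable buffers (e.g. BatchNorm running stats) once per epoch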
        return loss_meter.avg, loss_x_meter.avg, loss_u_meter.avg, mask_meter.avg, \
               loss_emb_x_meter.avg, loss_emb_u_meter.avg, mask_emb_meter.avg, mask_combined_meter.avg

    def _get_train_loaders(self, labelled_data, n_iters_per_epoch, num_workers, pin_memory, cache_imgs):
        mean, std = self.dataset_meta['stats']
        kwargs = dict(dataset_path=self.config.dataset_path, classes=self.dataset_meta['classes'],
                      labelled_data=labelled_data, batch_size=self.config.batch_size, mu=self.config.mu,
                      n_iters_per_epoch=n_iters_per_epoch, size=self.config.im_size, cropsize=self.config.cropsize,
                      mean=mean, std=std, num_workers=num_workers, pin_memory=pin_memory, cache_imgs=cache_imgs)

        return get_train_loaders(**kwargs)

    def _get_val_loader(self, valid_data, num_workers, pin_memory, cache_imgs):
        mean, std = self.dataset_meta['stats']
        kwargs = dict(dataset_path=self.config.dataset_path, classes=self.dataset_meta['classes'],
                      labelled_data=valid_data, batch_size=3 * self.config.batch_size,
                      size=self.config.im_size, cropsize=self.config.cropsize,
                      mean=mean, std=std, num_workers=num_workers, pin_memory=pin_memory, cache_imgs=cache_imgs)

        return get_val_loader(**kwargs)

    def _get_test_loader(self, num_workers, pin_memory, cache_imgs):
        mean, std = self.dataset_meta['stats']
        kwargs = dict(dataset_path=self.config.dataset_path, classes=self.dataset_meta['classes'],
                      batch_size=3 * self.config.batch_size, size=self.config.im_size,
                      cropsize=self.config.cropsize, mean=mean, std=std, num_workers=num_workers,
                      pin_memory=pin_memory, cache_imgs=cache_imgs)

        return get_test_loader(**kwargs)

    def _get_label_guessor(self):
        classes = self.dataset_meta['classes']
        class_2_embeddings_dict = get_labels2wv_dict(classes, self.config.word_vec_path)
        emb_dim = len(list(class_2_embeddings_dict.values())[0])
        if self.config.eps is None:
            eps = 0.15 if emb_dim < 100 else 0.2 if emb_dim < 256 else 0.28  # for label grouping clustering
        else:
            eps = self.config.eps
        label_group_idx, gr_mapping = get_grouping(class_2_embeddings_dict, eps=eps, return_mapping=True)
        label_guessor = LabelEmbeddingGuessor(classes, label_group_idx, class_2_embeddings_dict, self.config.thr_emb,
                                              self.device)
        return label_guessor, emb_dim

    def _setup_default_logging(self, default_level=logging.INFO):

        format = "%(asctime)s - %(levelname)s - %(name)s -   %(message)s"
        dataset_name = get_dataset_name(self.config.dataset_path)
        output_dir = os.path.join(dataset_name, f'x{self.L}')
        os.makedirs(output_dir, exist_ok=True)

        writer = SummaryWriter(comment=f'{dataset_name}_{self.L}')

        logger = logging.getLogger('train')
        logger.setLevel(default_level)

        time_stamp = time_str()
        logging.basicConfig(  # unlike the root logger, a custom logger can’t be configured using basicConfig()
            filename=os.path.join(output_dir, f'{time_stamp}_{self.L}_labelled_instances.log'),
            format=format,
            datefmt="%m/%d/%Y %H:%M:%S",
            level=default_level)
        # to avoid double printing when creating new instances of class
        if not logger.handlers:
            console_handler = logging.StreamHandler(sys.stdout)
            console_handler.setLevel(default_level)
            console_handler.setFormatter(logging.Formatter(format))
            logger.addHandler(console_handler)
        #
        logger.info(dict(self.config._get_kwargs()))
        if self.device != 'cpu':
            logger.info(f'Device used: {self.device}_{torch.cuda.get_device_name(self.device)}')
        logger.info(f'Model:  {self.model.module.__class__ if isinstance(self.model, torch.nn.DataParallel) else self.model.__class__}')
        logger.info(f'Num_labels: {self.L}')
        logger.info(f'Image_size: {self.config.im_size}')
        logger.info(f'Cropsize: {self.config.cropsize}')
        logger.info("Total params: {:.2f}M".format(
            sum(p.numel() for p in self.model.parameters()) / 1e6))

        return logger, writer, time_stamp

    def _init_training(self, training_config, L):

        n_iters_per_epoch = training_config.n_imgs_per_epoch // training_config.batch_size
        n_iters_all = n_iters_per_epoch * training_config.n_epoches
        if training_config.seed > 0:
            torch.manual_seed(training_config.seed)
            random.seed(training_config.seed)
            np.random.seed(training_config.seed)

        self.logger.info("***** Running training *****")
        self.logger.info(f"  Num Epochs = {training_config.n_epoches}")
        self.logger.info(f"  Early Stopping Epochs Patience = "
                         f"{training_config.early_stopping_epochs if training_config.early_stopping_epochs else None}")
        self.logger.info(f"  Minimum Wait before ES = {training_config.min_wait_before_es} epochs")
        self.logger.info(f"  Batch size Labelled = {training_config.batch_size}")
        self.logger.info(f"  Total optimization steps = {n_iters_all}")

        return n_iters_per_epoch, n_iters_all

    def _get_optimiser(self, training_config):
        # set weight decay to zero for batch-norm layers
        wd_params, non_wd_params = [], []
        for name, param in self.model.named_parameters():
            if 'bn' in name:
                non_wd_params.append(param)  # batch-norm weights and biases: excluded from weight decay
            else:
                wd_params.append(param)
        param_list = [{'params': wd_params}, {'params': non_wd_params, 'weight_decay': 0}]
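        # the per-group 'weight_decay': 0 overrides the optimizer-level weight_decay for these parameters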
        optim = torch.optim.SGD(param_list, lr=training_config.lr, weight_decay=training_config.weight_decay,
                                momentum=training_config.momentum, nesterov=True)

        return optim

    def _save_checkpoint(self):
        save_dir = 'saved_models' #os.path.abspath(os.path.join(self.config.checkpoint_path, os.pardir))
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        dataset_name = get_dataset_name(self.config.dataset_path)
        model_name = self.model.module._get_name() if isinstance(self.model, torch.nn.DataParallel) else self.model._get_name()
        checkpoint = {'ema_shadow':self.ema.shadow,
                      'model_state_dict': self.model.state_dict(),
                      'classes': self.dataset_meta['classes']}
        fpath = f'{save_dir}/{model_name}_{dataset_name}_{self.time_stamp}_checkpoint_dict.pth'
        torch.save(checkpoint,fpath)
        self.logger.info(f'Model Saved in: {fpath}')
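
# --- Illustrative sketch only (not part of the repository above): a minimal EMA helper with the
# interface the trainer assumes (a `shadow` dict, update_params/update_buffer, apply_shadow/restore).
# Names and details are hypothetical; the actual implementation may differ. ---
class EMASketch:
    def __init__(self, model, alpha):
        self.model = model  # a torch.nn.Module
        self.alpha = alpha
        # running average of every trainable parameter, keyed by name
        self.shadow = {n: p.detach().clone() for n, p in model.named_parameters() if p.requires_grad}
        self.backup = {}

    def update_params(self):
        # shadow <- alpha * shadow + (1 - alpha) * current, called after every optimizer step
        for n, p in self.model.named_parameters():
            if n in self.shadow:
                self.shadow[n].mul_(self.alpha).add_(p.detach(), alpha=1 - self.alpha)

    def update_buffer(self):
        # buffers (e.g. BatchNorm running stats) are copied as-is, typically once per epoch
        self.buffers = {n: b.detach().clone() for n, b in self.model.named_buffers()}

    def apply_shadow(self):
        # back up the live weights and load the averaged ones for evaluation
        for n, p in self.model.named_parameters():
            if n in self.shadow:
                self.backup[n] = p.detach().clone()
                p.data.copy_(self.shadow[n])

    def restore(self):
        # roll back to the live (non-averaged) weights so training can continue
        for n, p in self.model.named_parameters():
            if n in self.backup:
                p.data.copy_(self.backup[n])
        self.backup = {}
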
Example n. 16
0
def train(args, data):
    device = torch.device(
        f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    model = BiDAF(args).to(device)

    D_batch = args.train_batch_size
    ema = EMA(args.exp_decay_rate)
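    # register a shadow copy of every trainable parameter; ema.update() below folds the current
    # values into these copies (exponential moving average with decay exp_decay_rate)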
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adadelta(parameters, lr=args.learning_rate)
    criterion = nn.CrossEntropyLoss()

    # writer = SummaryWriter(log_dir='runs/' + args.model_time)

    model.train()
    loss, last_epoch = 0, -1
    max_dev_exact, max_dev_f1 = -1, -1
    best_model = model  # fallback: with the periodic evaluation below commented out, the final model is returned
    i = 0
    # iterator = data.train_iter
    while i + D_batch < len(data.data):
        b_id = i
        e_id = i + D_batch
        # present_epoch = int(iterator.epoch)
        # if present_epoch == args.epoch:
        #     break
        # if present_epoch > last_epoch:
        #     print('epoch:', present_epoch + 1)
        # last_epoch = present_epoch

        p1, p2 = model(data, b_id, e_id)

        optimizer.zero_grad()
        s_idx, e_idx = data.get_targ(b_id, e_id)
        batch_loss = criterion(p1, s_idx) + criterion(p2, e_idx)
        loss += batch_loss.item()
        batch_loss.backward()
        optimizer.step()
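        # after each optimizer step, fold the freshly updated weights into the EMA shadow copies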

        for name, param in model.named_parameters():
            if param.requires_grad:
                ema.update(name, param.data)

        # if (i + 1) % args.print_freq == 0:
        #     dev_loss, dev_exact, dev_f1 = test(model, ema, args, data)
        #     c = (i + 1) // args.print_freq

        #     # writer.add_scalar('loss/train', loss, c)
        #     # writer.add_scalar('loss/dev', dev_loss, c)
        #     # writer.add_scalar('exact_match/dev', dev_exact, c)
        #     # writer.add_scalar('f1/dev', dev_f1, c)
        #     # print(f'train loss: {loss:.3f} / dev loss: {dev_loss:.3f}'
        #     #       f' / dev EM: {dev_exact:.3f} / dev F1: {dev_f1:.3f}')

        #     if dev_f1 > max_dev_f1:
        #         max_dev_f1 = dev_f1
        #         max_dev_exact = dev_exact
        #         best_model = copy.deepcopy(model)

        #     loss = 0
        #     model.train()

        i += D_batch

    # writer.close()
    print(f'max dev EM: {max_dev_exact:.3f} / max dev F1: {max_dev_f1:.3f}')

    return best_model
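
# --- Illustrative sketch only (not from the snippet above): a minimal name-keyed EMA of the kind
# used here, storing one shadow tensor per registered parameter name. Hypothetical class; the
# real EMA(exp_decay_rate) implementation may differ. ---
class NameKeyedEMA:
    def __init__(self, decay):
        self.decay = decay
        self.shadow = {}

    def register(self, name, val):
        # start the running average at the parameter's current value
        self.shadow[name] = val.detach().clone()

    def update(self, name, val):
        # shadow <- decay * shadow + (1 - decay) * val
        self.shadow[name] = self.decay * self.shadow[name] + (1.0 - self.decay) * val.detach()

    def get(self, name):
        # read back the averaged value (e.g. to copy into the model before evaluation)
        return self.shadow[name]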