Code Example #1
class VQA:
    def __init__(self):
        # Model
        self.model = VQAModel()

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)

        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa, self.model)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()

        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True)

        self.valid_tuple = get_data_tuple(args.valid,
                                          bs=1024,
                                          shuffle=False,
                                          drop_last=False)

        self.test_tuple = get_data_tuple(args.test,
                                         bs=1024,
                                         shuffle=False,
                                         drop_last=False)

        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple, test_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = ((lambda x: tqdm(x, total=len(loader)))
                        if args.tqdm else (lambda x: x))

        start_epoch = -1
        best_valid = 0.

        if args.RESUME:
            path_checkpoint = args.path_checkpoint  # checkpoint path, e.g. ../../model/checkpoint/lxr_best_%s.pth
            checkpoint = torch.load(path_checkpoint)  # load the checkpoint

            self.model.load_state_dict(checkpoint['model_state_dict'])

            self.optim.load_state_dict(checkpoint['optimizer_state_dict'])
            start_epoch = checkpoint['epoch']

        for epoch in range(start_epoch + 1, args.epochs):
            for i, (feats, boxes, sent, _,
                    _) in iter_wrapper(enumerate(loader)):

                # Construct in-batch negative examples: pair each image with a sentence sampled from another item in the batch.
                bs = len(sent)
                index_list = list(range(bs))

                sent_negative = []
                for j in range(bs):
                    choice = random.choice(list(set(index_list) - {j}))
                    sent_negative.append(sent[choice])

                sent = sent + sent_negative

                feats = torch.cat([feats, feats])
                boxes = torch.cat([boxes, boxes])

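                # Positives (original image-sentence pairs) get label 1; in-batch negatives get label 0.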
                target = torch.ones(bs, 1)
                target_negative = torch.zeros(bs, 1)
                target = torch.cat([target, target_negative])

                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
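                # Scale the mean-reduced BCE loss by the output width, matching the LXMERT fine-tuning recipe.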
                loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                batch_score = accuracy(logit, target)

                if i % 500 == 0:
                    print('epoch {}, Step {}/{}, Loss: {}'.format(
                        epoch, i, len(loader), loss.item()))

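            # Note: the loss and batch_score logged here come from the final batch of the epoch only.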
            log_str = "\nEpoch %d: Loss: %0.4f Train %0.2f\n" % (
                epoch, loss.item(), batch_score)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple, epoch)

                self.save_checkpoint(epoch)
                self.save(epoch)
                self.test_output(test_tuple, epoch)
                print('output done!')
                if valid_score > best_valid:
                    best_valid = valid_score

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score ) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid )

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, epoch=0):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param epoch: The epoch index used to name the dumped result file.
        :return: The evaluation score computed by eval.py on the dumped submission.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        preds = []
        query_ids = []
        product_ids = []
        with torch.no_grad():
            for i, datum_tuple in enumerate(loader):
                feats, boxes, sent, ques_id, product_id = datum_tuple  # Avoid seeing ground truth
                query_ids.extend(ques_id)
                product_ids.extend(product_id)
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                logit = torch.sigmoid(logit)
                preds.append(logit.cpu().numpy())
        preds = np.concatenate(preds)

        # deal with format
        query2product = collections.defaultdict(list)
        for i, query_id in enumerate(query_ids):
            pred = preds[i]
            product_id = product_ids[i]
            query2product[str(query_id.item())].append(
                [pred.tolist()[0], str(product_id.item())])

        with open('../../user_data/lxmert_model/result/val/val_%s.txt' % epoch,
                  'w') as f:
            for q, p in query2product.items():
                f.write(str(q) + ',' + str(p) + '\n')

        with open('submission.csv', 'w') as f:
            f.write('query-id,product1,product2,product3,product4,product5\n')
            for q, p in query2product.items():
                p = sorted(p, key=lambda x: x[0], reverse=True)
                o = ','.join([p[i][1] for i in range(5)])
                f.write(q + ',' + o + '\n')

        os.system('python eval.py submission.csv')
        with open('score.json') as f:
            score = json.load(f)["score"]
        return score

    def test_output(self, eval_tuple: DataTuple, epoch=0):
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        preds = []
        query_ids = []
        product_ids = []
        with torch.no_grad():
            for i, datum_tuple in enumerate(loader):
                feats, boxes, sent, ques_id, product_id = datum_tuple  # Avoid seeing ground truth
                query_ids.extend(ques_id)
                product_ids.extend(product_id)
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                logit = torch.sigmoid(logit)
                preds.append(logit.cpu().numpy())
        preds = np.concatenate(preds)

        # deal with format
        query2product = collections.defaultdict(list)
        for i, query_id in enumerate(query_ids):
            pred = preds[i]
            product_id = product_ids[i]
            query2product[str(query_id.item())].append(
                [pred.tolist()[0], str(product_id.item())])

        print(os.getcwd())

        with open(
                '../../user_data/lxmert_model/result/test/test_%s.txt' % epoch,
                'w') as f:
            for q, p in query2product.items():
                f.write(str(q) + ',' + str(p) + '\n')

    def evaluate(self, eval_tuple: DataTuple, epoch=0):
        """Evaluate all data in data_tuple."""
        score = self.predict(eval_tuple, epoch)
        return score

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, epoch):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % (str(epoch))))

    def save_checkpoint(self, epoch):
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optim.state_dict(),
        }

        os.makedirs('../../user_data/lxmert_model/checkpoint', exist_ok=True)
        torch.save(checkpoint,
                   '../../user_data/lxmert_model/checkpoint/lxr_best.pth')

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
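
A minimal driver for this example (a hypothetical sketch: it assumes the project's global `args` namespace is already populated by its argument parser, plus a hypothetical `args.load` checkpoint-prefix flag analogous to `args.load_lxmert`):

if __name__ == "__main__":
    # Build the trainer; data tuples, loss, and optimizer are created in __init__.
    vqa = VQA()

    # Optionally warm-start from a checkpoint saved by VQA.save() (hypothetical flag;
    # VQA.load() appends ".pth", so pass the path prefix only).
    if args.load is not None:
        vqa.load(args.load)

    # Train with the tuples constructed in __init__.
    vqa.train(vqa.train_tuple, vqa.valid_tuple, vqa.test_tuple)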
Code Example #2
    def train(self, train_tuple: DataTuple, eval_tuple: DataTuple):
        train_ld = train_tuple.loader

        # Optimizer
        from lxrt.optimization import BertAdam
        batch_per_epoch = len(train_ld)
        t_total = int(batch_per_epoch * args.epochs)
        warmup_ratio = 0.05
        warmup_iters = int(t_total * warmup_ratio)
        print("Batch per epoch: %d" % batch_per_epoch)
        print("Total Iters: %d" % t_total)
        print("Warm up Iters: %d" % warmup_iters)
        optim = BertAdam(self.model.parameters(),
                         lr=args.lr,
                         warmup=warmup_ratio,
                         t_total=t_total)
        start_epoch = 0

        if args.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            self.model, optim = amp.initialize(self.model,
                                               optim,
                                               opt_level='O1')

        # GPU Options
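        # (apex documentation recommends calling amp.initialize before wrapping the model in DataParallel.)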
        if args.multiGPU:
            self.model = nn.DataParallel(self.model)

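        # pretraining_index == 0 resumes both model and optimizer state; == 1 warm-starts model weights only (strict=False).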
        if args.start_from > 0 and args.pretraining_index == 0:
            start_path = os.path.join(
                args.output,
                "Epoch%s_LXRT.pth" % format(int(args.start_from), '02'))
            print('start training from {0}'.format(start_path))
            state = torch.load(start_path)
            self.model.load_state_dict(state['state_dict'])
            optim.load_state_dict(state['optimizer'])
            start_epoch = args.start_from
            del state
            torch.cuda.empty_cache()
        elif args.start_from > 0 and args.pretraining_index == 1:
            start_path = os.path.join(
                args.output,
                "Epoch%s_LXRT.pth" % format(int(args.start_from), '02'))
            print('start training from {0}'.format(start_path))
            state = torch.load(start_path)
            self.model.load_state_dict(state['state_dict'], strict=False)
            del state
            torch.cuda.empty_cache()

        # Train
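        # 9595. is just an arbitrarily large initial "best so far" sentinel.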
        best_eval_loss = 9595.
        for epoch in range(start_epoch, args.epochs):
            # Train
            self.model.train()
            total_loss = 0.
            total_losses = 0.
            uid2ans = {}
            for batch in tqdm(train_ld, total=len(train_ld)):
                loss, losses, logit = self.train_batch(optim, batch)
                total_loss += loss
                total_losses += losses

                if args.task_qa:
                    score, label = logit.max(1)
                    for datum, l in zip(batch, label.cpu().numpy()):
                        uid = datum.uid
                        ans = train_tuple.dataset.answer_table.id2ans(l)
                        uid2ans[uid] = ans

            print("The training loss for Epoch %d is %0.4f" %
                  (epoch, total_loss / batch_per_epoch))
            log_str = "\nThe training loss for Epoch %d is %0.4f" % (
                epoch, total_loss / batch_per_epoch)
            losses_str = "\nThe losses are "
            log_str += "\nThe losses are "
            for name, loss in zip(LOSSES_NAME, total_losses):
                losses_str += "%s: %0.4f " % (name, loss / batch_per_epoch)
                log_str += "\n %s: %0.4f " % (name, loss / batch_per_epoch)
            print(losses_str)
            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()
            if args.task_qa:
                train_tuple.evaluator.evaluate(uid2ans, pprint=True)

            # Eval
            avg_eval_loss = self.evaluate_epoch(eval_tuple, iters=-1)

            state = {
                'state_dict': self.model.state_dict(),
                'optimizer': optim.state_dict(),
            }

            # Save
            if avg_eval_loss < best_eval_loss:
                best_eval_loss = avg_eval_loss
                self.save("BEST_EVAL_LOSS", state)
            if args.pretraining_index == 0:
                self.save("Epoch%02d" % (epoch + 1), state)
            elif args.pretraining_index == 1:
                self.save("Epoch%02d" % (epoch + 1 + args.start_from), state)
Code Example #3
File: lxmert_pretrain.py  Project: uclanlp/visualbert
class LXMERT:
    def __init__(self, max_seq_length):
        super().__init__()
        self.max_seq_length = max_seq_length

        self.tokenizer = BertTokenizer.from_pretrained(
            "bert-base-uncased",
            do_lower_case=True
        )

        # Build model
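        # (num_answers below falls back to train_tuple, which appears to be a module-level global in the original lxmert_pretrain.py.)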
        self.model = LXRTPretraining.from_pretrained(
            "bert-base-uncased",
            args=args,
            task_mask_lm=args.task_mask_lm,
            task_obj_predict=args.task_obj_predict,
            task_matched=args.task_matched,
            task_qa=args.task_qa,
            visual_losses=args.visual_losses,
            num_answers=(args.num_answers if args.get("num_answers", None)
                         else train_tuple.dataset.answer_table.num_answers)
        )

        # Weight initialization and loading
        if args.from_scratch:
            print("Train from Scratch: re-initialize all BERT weights.")
            self.model.apply(self.model.init_bert_weights)

        if args.get("use_tag_symbolic_embedding", False):
            self.model.bert.embeddings.initialize_symbolic_embeddings(symbolic_vocab.get_symbolic_list(self.tokenizer))
            self.model.special_initialize_pretraining_head()
        
        if args.get("hybrid_embedding", False):
            self.model.bert.embeddings.initialize_visual_position_type_embeddings()
        
        if args.load_lxmert is not None:
            # Load lxmert would not load the answer head.
            self.load_lxmert(args.load_lxmert)
        
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model = nn.DataParallel(self.model)
        
        self.global_step = 0

    def forward(self, examples):
        
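        # Move every tensor in the batch onto the GPU; dict-valued entries hold (tensor, tensor) pairs.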
        for index, i in enumerate(examples):
            if i is not None:
                if isinstance(i, dict):
                    for key in i:
                        i[key] = (i[key][0].cuda(), i[key][1].cuda())
                else:
                    examples[index] = i.cuda()
        
        (input_ids, segment_ids, input_mask, lm_labels, feats, pos, obj_labels,
         matched_labels, ans, visual_feats_seg_ids, visual_tags, visual_tags_mask,
         visual_tags_box, visual_tags_objective, visual_tags_mismatch,
         visual_tags_segment_ids) = examples

        loss, losses, ans_logit, losses_dict = self.model(
            input_ids, segment_ids, input_mask, lm_labels,
            feats, pos, obj_labels, matched_labels, ans,
            visual_feats_seg_ids=visual_feats_seg_ids,
            visual_tags=visual_tags,
            visual_tags_mask=visual_tags_mask,
            visual_tags_box=visual_tags_box,
            visual_tags_objective=visual_tags_objective,
            visual_tags_mismatch=visual_tags_mismatch,
            visual_tags_segment_ids=visual_tags_segment_ids
        )
        return loss, losses.detach().cpu(), ans_logit, losses_dict

    def train_batch(self, optim, batch):
        
        gradient_accumulation_steps = args.get("gradient_accumulation_steps", 1)
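        # Gradients are zeroed only at accumulation boundaries, so intermediate micro-batches accumulate into .grad.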
        if (self.global_step + 1) % gradient_accumulation_steps == 0:
            optim.zero_grad()
        loss, losses, ans_logit, losses_dict = self.forward(batch)
        if args.multiGPU:
            loss = loss.mean()
            losses = losses.mean(0)
        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps

        loss.backward()
        if (self.global_step + 1) % gradient_accumulation_steps == 0:
            nn.utils.clip_grad_norm_(self.model.parameters(), 1.)
            optim.step()

        return loss.item(), losses.cpu().numpy(), ans_logit, losses_dict

    def valid_batch(self, batch):
        with torch.no_grad():
            loss, losses, ans_logit, losses_dict = self.forward(batch)
            if args.multiGPU:
                loss = loss.mean()
                losses = losses.mean(0)
        return loss.item(), losses.cpu().numpy(), ans_logit, losses_dict

    def train(self, train_tuple: DataTuple, eval_tuple: DataTuple):
        train_ld = train_tuple.loader

        # Optimizer
        from lxrt.optimization import BertAdam
        batch_per_epoch = len(train_ld)
        t_total = int(batch_per_epoch * args.epochs)
        warmup_ratio = args.get("warmup_ratio", 0.05)

        print("Total Iters: %d" % t_total)
        if args.get("t_total", None):
            t_total = args.t_total
            print("!! Changing to specified t_toal in args: {}".format(t_total))
        self.t_total = t_total
        warmup_iters = int(t_total * warmup_ratio)

        print("Batch per epoch: %d" % batch_per_epoch)
        print("Warm up Iters: %d" % warmup_iters)
        self.optim = BertAdam(self.model.parameters(), lr=args.lr, warmup=warmup_ratio, t_total=t_total)

        if args.load is not None:
            self.load(args.load, t_total = t_total)

        gradient_accumulation_steps = args.get("gradient_accumulation_steps", 1)
        # Train
        best_eval_loss = 9595.
        report_every = args.get("report_every", 100)

        custom_train_meter = TrainingMeter()
        
        for epoch in range(args.epochs):
            # Train
            self.model.train()
            total_loss = 0.
            total_losses = 0.
            uid2ans = {}

            for batch_id, batch in enumerate(tqdm(train_ld, total=len(train_ld))):
                if args.get("skip_training", False):
                    break

                loss, losses, logit, losses_dict = self.train_batch(self.optim, batch)
                total_loss += loss
                try:
                    total_losses += losses
                except Exception:
                    pass

                if args.task_qa and batch[0].sent is not None:
                    assert 0  # Not used in our experiment

                    score, label = logit.max(1)
                    for datum, l in zip(batch, label.cpu().numpy()):
                        uid = datum.uid
                        ans = train_tuple.dataset.answer_table.id2ans(l)
                        uid2ans[uid] = ans
                
                for key, value in losses_dict.items():
                    losses_dict[key] = value.mean().item()  # make the losses scalar
                
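                # Drop the Masked LM entry from reporting when it was not computed for this batch (it shows up as 0).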
                if "Masked LM" in losses_dict and losses_dict["Masked LM"] == 0:
                    del losses_dict["Masked LM"]

                custom_train_meter.update(losses_dict)

                if batch_id % report_every == 0 and batch_id > 0:
                    print("Folder: {} \n Epoch {} Iter: {}/{}".format(args.output, epoch, batch_id, len(train_ld)))
                    custom_train_meter.report()
                    custom_train_meter.clean()
                    print()
                
                if args.get("save_step", -1) != -1 and self.global_step != 0 and (self.global_step // gradient_accumulation_steps) % args.save_step == 0:
                    self.save("Step{}".format(self.global_step))
                self.global_step += 1
            
            print("The training loss for Epoch %d is %0.4f" % (epoch, total_loss / batch_per_epoch))

            if args.task_qa:
                train_tuple.evaluator.evaluate(uid2ans, pprint=True)

            # Eval
            avg_eval_loss = self.evaluate_epoch(eval_tuple, iters=-1)

            if args.get("eval_on_train", False):
                print("On train set")
                self.evaluate_epoch(train_tuple, iters=-1)


            if avg_eval_loss < best_eval_loss:
                best_eval_loss = avg_eval_loss
                self.save("BEST_EVAL_LOSS")
            self.save("Epoch%02d" % (epoch+1))

    def evaluate_epoch(self, eval_tuple: DataTuple, iters: int=-1):
        self.model.eval()
        eval_ld = eval_tuple.loader
        total_loss = 0.
        total_losses = 0.
        uid2ans = {}
        eval_meter = TrainingMeter()
        for i, batch in enumerate(tqdm(eval_ld)):
            loss, losses, logit, losses_dict = self.valid_batch(batch)
            total_loss += loss
            try:
                total_losses += losses
            except Exception:
                pass
            for key, value in losses_dict.items():
                losses_dict[key] = value.mean().item()
            eval_meter.update(losses_dict)

            if args.task_qa:
                score, label = logit.max(1)
                for datum, l in zip(batch, label.cpu().numpy()):
                    uid = datum.uid
                    ans = train_tuple.dataset.answer_table.id2ans(l)
                    uid2ans[uid] = ans
            if i == iters:
                break
        print("Evaluation:")
        eval_meter.report()
        print("\n\n\n\n\n\n\n\n")

        if args.task_qa:
            eval_tuple.evaluator.evaluate(uid2ans, pprint=True)

        return total_loss / len(eval_ld)
    
    def evaluate_epoch_text(self, eval_tuple: DataTuple, iters: int=-1):
        self.model.eval()
        eval_ld = eval_tuple.textonly
        total_loss = 0.
        total_losses = 0.
        uid2ans = {}
        eval_meter = TrainingMeter()
        for i, batch in enumerate(tqdm(eval_ld)):
            loss, losses, logit, losses_dict = self.valid_batch(batch)
            total_loss += loss
            total_losses += losses
            for key, value in losses_dict.items():
                losses_dict[key] = value.mean().item()
            eval_meter.update(losses_dict)

            if args.task_qa:
                score, label = logit.max(1)
                for datum, l in zip(batch, label.cpu().numpy()):
                    uid = datum.uid
                    ans = train_tuple.dataset.answer_table.id2ans(l)
                    uid2ans[uid] = ans
            if i == iters:
                break
        print("Evaluation text only:")
        eval_meter.report()
        print("\n\n\n\n\n\n\n\n")

        return total_loss / len(eval_ld)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(args.output, "%s_LXRT.pth" % name))
        
        if args.get("save_optimizer", False) and "Step" not in name:
            torch.save(self.optim.state_dict(),
                   os.path.join(args.output, "%s_LXRT_optimizer.pth" % name))
        

    def load(self, path, t_total):
        print("Load model from %s" % path)
        state_dict = torch.load("%s_LXRT.pth" % path)
        #self.model.load_state_dict(state_dict)
        from qa_answer_table import load_state_dict_flexible
        load_state_dict_flexible(self.model, state_dict)

        optimizer_path = "{}_LXRT_optimizer.pth".format(path)
        if os.path.exists(optimizer_path) and args.get("load_optimizer", True):
            print("Load optimizer from {}".format(optimizer_path))

            loaded_optim = torch.load(optimizer_path)
            if args.get("reset_schedule", False):
                for group in loaded_optim["param_groups"]:
                    group['lr'] = args.lr
                    group['warmup'] = args.warmup_ratio
                    group["t_total"] = t_total

                    for p in group['params']:
                        loaded_optim["state"][p]["step"] = 0
            self.optim.load_state_dict(loaded_optim)
    

    def load_lxmert(self, path):
        print("Load LXMERT model from %s" % path)
        state_dict = torch.load("%s_LXRT.pth" % path)

        # Do not load any answer head
        for key in list(state_dict.keys()):
            if 'answer' in key:
                state_dict.pop(key)

        # Change Multi GPU to single GPU
        new_state_dict = {}
        for key, value in state_dict.items():
            if key.startswith("module."):
                new_state_dict[key[len("module."):]] = value
        state_dict = new_state_dict

        load_keys = set(state_dict.keys())
        model_keys = set(self.model.state_dict().keys())
        print()
        print("Keys in loaded but not in model:")
        for key in sorted(load_keys.difference(model_keys)):
            print(key)
        print()
        print("Keys in model but not in loaded:")
        for key in sorted(model_keys.difference(load_keys)):
            print(key)
        print()

        self.model.load_state_dict(state_dict, strict=False)
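
Note the path convention shared by load() and load_lxmert(): both append "_LXRT.pth" to the argument, so callers pass a checkpoint prefix rather than a full filename. A hypothetical usage sketch (the DataTuples and the checkpoint path below are assumed, not taken from the original file):

if __name__ == "__main__":
    lxmert = LXMERT(max_seq_length=20)
    # Pass the prefix only; load_lxmert() appends "_LXRT.pth" itself.
    lxmert.load_lxmert("snap/pretrained/model")  # loads snap/pretrained/model_LXRT.pth
    # train_tuple / eval_tuple are DataTuples from the project's data pipeline (assumed).
    lxmert.train(train_tuple, eval_tuple)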