예제 #1
0
def main():
    """Smoke-test the question generator: a few training steps, then generation.

    Builds the baseline dataset with the "test" option, runs five optimisation
    steps while printing the loss, then generates for one further batch and
    prints the first decoded result.
    """
    from arguments.qgen_args import qgen_arguments
    from data_provider.qgen_baseline_dataset import prepare_dataset
    from process_data.tokenizer import GWTokenizer
    parser = qgen_arguments()
    # parse_known_args ignores command-line flags this parser does not define.
    args, _ = parser.parse_known_args()
    args = vars(args)
    tokenizer = GWTokenizer('./../data/dict.json')
    loader = prepare_dataset("./../data/", "test", args, tokenizer)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = QGenNetwork(args, tokenizer, device).to(device)
    optimizer = torch.optim.Adam(model.parameters(), args["lr"])
    data_iter = iter(loader)
    model.train()
    for _ in range(5):
        batch = next(data_iter)
        optimizer.zero_grad()
        model.zero_grad()
        _, loss = model(batch)
        loss.backward()
        # Clip the gradient norm before the update; the returned norm is unused.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args["clip_val"])
        optimizer.step()
        print("loss: {:.4f}".format(loss.item()))
    model.eval()
    batch = next(data_iter)
    # Generation is inference only, so do not record an autograd graph.
    with torch.no_grad():
        result, _ = model.generate(batch)
    print("generate")
    print(tokenizer.decode(result[0]))
예제 #2
0
def main():
    """Smoke-test the oracle network: a short training run and one accuracy check.

    Runs twenty optimisation steps with a cross-entropy loss against the last
    element of each batch, printing the loss, then evaluates one further batch
    and prints its accuracy.
    """
    from arguments.oracle_args import oracle_arguments
    from data_provider.oracle_dataset import prepare_dataset
    from process_data.tokenizer import GWTokenizer
    from utils.calculate_util import calculate_accuracy
    parser = oracle_arguments()
    # parse_known_args ignores command-line flags this parser does not define.
    args, _ = parser.parse_known_args()
    args = vars(args)
    tokenizer = GWTokenizer('./../data/dict.json')
    loader = prepare_dataset("./../data/", "test", args, tokenizer)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = OracleNetwork(args, tokenizer, device).to(device)
    optimizer = torch.optim.Adam(model.parameters(), args["lr"])
    data_iter = iter(loader)
    model.train()
    for _ in range(20):
        batch = next(data_iter)
        optimizer.zero_grad()
        model.zero_grad()
        output = model(batch)
        target = batch[-1].to(device).long()  # target object index
        loss = torch.nn.functional.cross_entropy(output, target)
        loss.backward()
        # Clip the gradient norm before the update; the returned norm is unused.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args["clip_val"])
        optimizer.step()
        print("loss: {:.4f}".format(loss.item()))
    model.eval()
    batch = next(data_iter)
    # Evaluation forward pass: no gradients are needed.
    with torch.no_grad():
        output = model(batch)
    target = batch[-1].to(device).long()
    _, accuracy = calculate_accuracy(output, target)
    # Four decimal places, matching the loss line above ("{:4f}" only set a
    # minimum field width).
    print("acc: {:.4f}".format(accuracy))
예제 #3
0
def main():
    """Build the question dataset and loader, print their sizes, fetch a batch."""
    from arguments.qgen_args import qgen_arguments
    from process_data.tokenizer import GWTokenizer
    data_dir = "./../data/"
    tokenizer = GWTokenizer('./../data/dict.json')
    args = vars(qgen_arguments().parse_args())
    dataset = QuestionDataset(data_dir, 'test', args, tokenizer=tokenizer)
    dataloader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=8,
        collate_fn=question_collate,
    )
    print(len(dataset), len(dataloader))
    # Pull exactly one batch to check that loading and collation work.
    batch_iter = iter(dataloader)
    batch = next(batch_iter)
예제 #4
0
파일: train_qgen.py 프로젝트: tobytyx/CSQG
def main(args):
    """Train or test a question generator chosen by ``args["model"]``.

    With ``args["option"] == "train"`` the arguments are written to
    ``args.json`` under ``save_path`` and training starts; otherwise the saved
    arguments are read back, marked as "test", and the test routine runs.

    :param args: dict of run options; reads "model" and "option", plus
        "image" and "object" when training
    """
    param_file = save_path.format("params.pth.tar")
    data_dir = "./../data/"
    model_name = args["model"].lower()
    tokenizer = GWTokenizer('./../data/dict.json')

    # NOTE(review): in test mode the network class is still chosen from the
    # command-line "model" value, not from the saved args.json - confirm.
    if model_name == "cat_base":
        from models.qgen.qgen_cat_base import QGenNetwork
    elif model_name == "hrnn":
        from models.qgen.qgen_hrnn import QGenNetwork
    elif model_name == "cat_accu":
        from models.qgen.qgen_cat_accu import QGenNetwork
    elif model_name == "cat_attn":
        from models.qgen.qgen_cat_attn import QGenNetwork
    else:
        print(model_name)
        from models.qgen.qgen_baseline import QGenNetwork
    # Every named variant shares one dataset module; anything else uses the
    # baseline dataset.
    if model_name in ("cat_base", "hrnn", "cat_accu", "cat_attn"):
        from data_provider.qgen_dataset import prepare_dataset
    else:
        from data_provider.qgen_baseline_dataset import prepare_dataset

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args["option"] != "train":
        # Testing: the arguments saved at training time replace the ones
        # passed in.
        with open(save_path.format("args.json"), mode="r") as f:
            args = json.load(f)
            args["option"] = "test"
        logger.info(args)
        model = QGenNetwork(args, tokenizer, device).to(device)
        testloader = prepare_dataset(data_dir, "test", args, tokenizer)
        test(model, args, testloader, param_file)
        return

    if args["image"] is False and args["object"] is False:
        # Neither input was requested: fall back to object features.
        print("default object")
        args.update(object=True, image_arch="rcnn", image_dim=2048)
    with open(save_path.format("args.json"), mode="w") as f:
        json.dump(args, f, indent=2, ensure_ascii=False)
    logger.info(args)
    model = QGenNetwork(args, tokenizer, device).to(device)
    train_loader, val_loader = prepare_dataset(data_dir, "train", args,
                                               tokenizer)
    train(model, args, train_loader, val_loader, param_file)
예제 #5
0
def main(args):
    """Train or test the oracle network depending on ``args["option"]``.

    When training, the arguments are written to ``args.json`` under
    ``save_path``. When testing, the saved arguments are read back, marked as
    "test", and used to build the network, so the architecture matches the
    checkpoint that was trained.

    :param args: dict of run options; reads "option"
    """
    param_file = save_path.format("params.pth.tar")
    data_dir = "./../data/"
    tokenizer = GWTokenizer('./../data/dict.json')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args["option"] == "train":
        model = OracleNetwork(args, tokenizer, device).to(device)
        with open(save_path.format("args.json"), mode="w") as f:
            json.dump(args, f, indent=2, ensure_ascii=False)
        logger.info(args)
        train_loader, val_loader = prepare_dataset(data_dir, "train", args, tokenizer)
        train(model, args, train_loader, val_loader, param_file)
    else:
        with open(save_path.format("args.json"), mode="r") as f:
            saved_args = json.load(f)
        saved_args["option"] = "test"
        args = saved_args
        logger.info(args)
        # Build the network only after the saved arguments are restored;
        # building it from the command-line arguments could give a model
        # whose shape differs from the stored parameters.
        model = OracleNetwork(args, tokenizer, device).to(device)
        test_loader = prepare_dataset(data_dir, "test", args, tokenizer)
        test(model, test_loader, param_file)
예제 #6
0
def main():
    """Evaluate a qgen/oracle/guesser combination and record its success rate.

    The rounded success rate is stored in ``<out_dir>/games/test.json`` under
    a key built from the option, the turn count and the three model names.
    Entries already in that file are kept.
    """
    parser = looper_arguments()
    # parse_known_args ignores command-line flags this parser does not define.
    args, _ = parser.parse_known_args()
    args = vars(args)
    print(args)
    data_dir = "./../data"
    tokenizer = GWTokenizer("./../data/dict.json")
    out_dir = "./../out/"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    oracle = OracleWrapper(data_dir, out_dir, args["oracle_name"], tokenizer,
                           device)
    guess = GuesserWrapper(data_dir, out_dir, args["guesser_name"], tokenizer,
                           device)
    question = QuestionWrapper(data_dir, out_dir, args["qgen_name"], tokenizer,
                               device)
    loop = Looper(data_dir=data_dir,
                  oracle=oracle,
                  guesser=guess,
                  question=question,
                  args=args)
    _, success_rate = loop.eval(option=args["option"],
                                out_dir=out_dir,
                                store=not args["no_store"],
                                visualize=args["visualize"])
    result_file = os.path.join(out_dir, "games", "test.json")
    turns = "{}turns".format(args["max_turn"])
    models_name = ",".join([
        args["option"], turns, args["qgen_name"], args["oracle_name"],
        args["guesser_name"]
    ])
    print("model_name: ", models_name)
    res = {}
    if os.path.exists(result_file):
        with open(result_file, mode="r") as f:
            res = json.load(f)
    res[models_name] = round(success_rate, 4)
    # Make sure the target directory exists so the result of a finished
    # evaluation is not lost to a missing folder.
    os.makedirs(os.path.dirname(result_file), exist_ok=True)
    with open(result_file, mode="w") as f:
        json.dump(res, f, indent=2)
예제 #7
0
    def __init__(self, data_dir, args, device, logger):
        """Load the pretrained guesser, oracle and qgen models and build loaders.

        :param data_dir: dataset directory handed to LoopRlDataset
        :param args: dict; reads "oracle_name", "guesser_name", "qgen_name",
            "batch_size" and "option"
        :param device: device the models are moved to and onto which the
            stored parameters are mapped
        :param logger: logger kept on the instance
        """
        oracle_name = args["oracle_name"]
        guesser_name = args["guesser_name"]
        qgen_name = args["qgen_name"]
        self.data_dir = data_dir
        self.device = device
        self.args = args
        self.tokenizer = GWTokenizer("./../data/dict.json")
        self.logger = logger

        # Each pretrained model lives in its own folder holding the training
        # arguments (args.json) and the parameters (params.pth.tar).
        oracle_dir = os.path.join("../out/oracle", oracle_name)
        qgen_dir = os.path.join("../out/qgen", qgen_name)
        guesser_dir = os.path.join("../out/guesser", guesser_name)
        with open(os.path.join(oracle_dir, "args.json")) as f:
            self.oracle_args = json.load(f)
        with open(os.path.join(qgen_dir, "args.json")) as f:
            self.qgen_args = json.load(f)
        with open(os.path.join(guesser_dir, "args.json")) as f:
            self.guesser_args = json.load(f)

        # map_location lets parameters saved on a GPU load on whatever device
        # this run uses (for example a CPU-only host).
        self.guesser_model = GuesserNetwork(self.guesser_args, self.tokenizer,
                                            self.device).to(self.device)
        self.guesser_model.load_state_dict(
            torch.load(os.path.join(guesser_dir, "params.pth.tar"),
                       map_location=self.device))
        self.guesser_model.eval()

        self.oracle_model = OracleNetwork(self.oracle_args, self.tokenizer,
                                          self.device).to(self.device)
        self.oracle_model.load_state_dict(
            torch.load(os.path.join(oracle_dir, "params.pth.tar"),
                       map_location=self.device))
        self.oracle_model.eval()

        # Oracle class index -> answer word, and answer token -> class index.
        self.answer_dict = {0: 'Yes', 1: 'No', 2: 'N/A'}
        self.token_to_answer_idx = {
            self.tokenizer.yes_token: 0,
            self.tokenizer.no_token: 1,
            self.tokenizer.non_applicable_token: 2
        }

        # NOTE(review): any "model" value other than the three named here
        # falls through to the baseline network, and strict=False below would
        # let a mismatched checkpoint load without an error - confirm intended.
        if self.qgen_args["model"] == "cat_base":
            from models.qgen.qgen_cat_base import QGenNetwork
        elif self.qgen_args["model"] == "cat_accu":
            from models.qgen.qgen_cat_accu import QGenNetwork
        elif self.qgen_args["model"] == "hrnn":
            from models.qgen.qgen_hrnn import QGenNetwork
        else:
            from models.qgen.qgen_baseline import QGenNetwork
        self.qgen_model = QGenNetwork(args=self.qgen_args,
                                      tokenizer=self.tokenizer,
                                      device=self.device).to(self.device)
        self.qgen_model.load_state_dict(
            torch.load(os.path.join(qgen_dir, "params.pth.tar"),
                       map_location=self.device),
            strict=False)

        train_dataset = LoopRlDataset(data_dir, "train", "train",
                                      self.qgen_args, self.oracle_args,
                                      self.guesser_args, self.tokenizer)
        self.train_loader = DataLoader(train_dataset,
                                       num_workers=4,
                                       collate_fn=loop_rl_collate,
                                       shuffle=True,
                                       batch_size=args["batch_size"])
        # Ground-truth questions and answers of every training game, by id.
        self.gt_qas = {}
        for game in train_dataset.games:
            self.gt_qas[game.id] = {
                "questions": game.questions,
                "answers": game.answers
            }

        # The two evaluation modes differ only in the first split argument
        # given to the dataset.
        if args["option"] == "new_objects":
            print("new objects")
            test_split = "train"
        else:
            print("new games")
            test_split = "test"
        test_dataset = LoopRlDataset(data_dir, test_split, "test",
                                     self.qgen_args, self.oracle_args,
                                     self.guesser_args, self.tokenizer)
        self.test_loader = DataLoader(test_dataset,
                                      num_workers=4,
                                      collate_fn=loop_rl_collate,
                                      shuffle=False,
                                      batch_size=args["batch_size"])
예제 #8
0
class Looper(object):
    def __init__(self, data_dir, args, device, logger):
        """Load the pretrained guesser, oracle and qgen models and build loaders.

        :param data_dir: dataset directory handed to LoopRlDataset
        :param args: dict; reads "oracle_name", "guesser_name", "qgen_name",
            "batch_size" and "option"
        :param device: device the models are moved to and onto which the
            stored parameters are mapped
        :param logger: logger kept on the instance
        """
        oracle_name = args["oracle_name"]
        guesser_name = args["guesser_name"]
        qgen_name = args["qgen_name"]
        self.data_dir = data_dir
        self.device = device
        self.args = args
        self.tokenizer = GWTokenizer("./../data/dict.json")
        self.logger = logger

        # Each pretrained model lives in its own folder holding the training
        # arguments (args.json) and the parameters (params.pth.tar).
        oracle_dir = os.path.join("../out/oracle", oracle_name)
        qgen_dir = os.path.join("../out/qgen", qgen_name)
        guesser_dir = os.path.join("../out/guesser", guesser_name)
        with open(os.path.join(oracle_dir, "args.json")) as f:
            self.oracle_args = json.load(f)
        with open(os.path.join(qgen_dir, "args.json")) as f:
            self.qgen_args = json.load(f)
        with open(os.path.join(guesser_dir, "args.json")) as f:
            self.guesser_args = json.load(f)

        # map_location lets parameters saved on a GPU load on whatever device
        # this run uses (for example a CPU-only host).
        self.guesser_model = GuesserNetwork(self.guesser_args, self.tokenizer,
                                            self.device).to(self.device)
        self.guesser_model.load_state_dict(
            torch.load(os.path.join(guesser_dir, "params.pth.tar"),
                       map_location=self.device))
        self.guesser_model.eval()

        self.oracle_model = OracleNetwork(self.oracle_args, self.tokenizer,
                                          self.device).to(self.device)
        self.oracle_model.load_state_dict(
            torch.load(os.path.join(oracle_dir, "params.pth.tar"),
                       map_location=self.device))
        self.oracle_model.eval()

        # Oracle class index -> answer word, and answer token -> class index.
        self.answer_dict = {0: 'Yes', 1: 'No', 2: 'N/A'}
        self.token_to_answer_idx = {
            self.tokenizer.yes_token: 0,
            self.tokenizer.no_token: 1,
            self.tokenizer.non_applicable_token: 2
        }

        # NOTE(review): any "model" value other than the three named here
        # falls through to the baseline network, and strict=False below would
        # let a mismatched checkpoint load without an error - confirm intended.
        if self.qgen_args["model"] == "cat_base":
            from models.qgen.qgen_cat_base import QGenNetwork
        elif self.qgen_args["model"] == "cat_accu":
            from models.qgen.qgen_cat_accu import QGenNetwork
        elif self.qgen_args["model"] == "hrnn":
            from models.qgen.qgen_hrnn import QGenNetwork
        else:
            from models.qgen.qgen_baseline import QGenNetwork
        self.qgen_model = QGenNetwork(args=self.qgen_args,
                                      tokenizer=self.tokenizer,
                                      device=self.device).to(self.device)
        self.qgen_model.load_state_dict(
            torch.load(os.path.join(qgen_dir, "params.pth.tar"),
                       map_location=self.device),
            strict=False)

        train_dataset = LoopRlDataset(data_dir, "train", "train",
                                      self.qgen_args, self.oracle_args,
                                      self.guesser_args, self.tokenizer)
        self.train_loader = DataLoader(train_dataset,
                                       num_workers=4,
                                       collate_fn=loop_rl_collate,
                                       shuffle=True,
                                       batch_size=args["batch_size"])
        # Ground-truth questions and answers of every training game, by id.
        self.gt_qas = {}
        for game in train_dataset.games:
            self.gt_qas[game.id] = {
                "questions": game.questions,
                "answers": game.answers
            }

        # The two evaluation modes differ only in the first split argument
        # given to the dataset.
        if args["option"] == "new_objects":
            print("new objects")
            test_split = "train"
        else:
            print("new games")
            test_split = "test"
        test_dataset = LoopRlDataset(data_dir, test_split, "test",
                                     self.qgen_args, self.oracle_args,
                                     self.guesser_args, self.tokenizer)
        self.test_loader = DataLoader(test_dataset,
                                      num_workers=4,
                                      collate_fn=loop_rl_collate,
                                      shuffle=False,
                                      batch_size=args["batch_size"])

    def rl_train_epoch(self, optimizer):
        """Run one policy-gradient epoch over ``self.train_loader``.

        Every ``self.args["log_step"]`` steps the mean loss and mean reward
        since the previous log line are written to ``self.logger``.

        :param optimizer: optimizer stepped once per batch
        :return: sum of the per-batch reward losses divided by the number of
            batches in the loader
        """
        self.qgen_model.train()
        loss_sum = 0
        reward_sum = 0
        # Running sums and step index as of the latest log line, so each log
        # line reports averages over the window since the one before it.
        logged_loss_sum = 0
        logged_reward_sum = 0
        logged_step = 0
        for step, batch in enumerate(self.train_loader, start=1):
            # The first element of the batch is not used during training.
            (_, q_imgs, q_bbox, o_imgs, o_crops, o_cats, o_spas, g_imgs,
             g_obj_mask, g_cats, g_spas, targets) = batch
            self.qgen_model.zero_grad()
            optimizer.zero_grad()
            reward, reward_loss = self.rl_step(
                q_imgs, q_bbox, o_imgs, o_crops, o_cats, o_spas, g_imgs,
                g_obj_mask, g_cats, g_spas, targets)
            reward_loss.backward()
            optimizer.step()

            loss_sum += reward_loss.item()
            reward_sum += reward.mean().item()
            if step % self.args["log_step"] != 0:
                continue
            window = step - logged_step
            window_loss = (loss_sum - logged_loss_sum) / window
            window_reward = (reward_sum - logged_reward_sum) / window
            self.logger.info("Step {}, Loss {:.4f}, Reward: {}".format(
                step, window_loss, window_reward))
            logged_step = step
            logged_reward_sum = reward_sum
            logged_loss_sum = loss_sum
        return loss_sum / len(self.train_loader)

    def rl_step(self, q_imgs, q_bbox, o_imgs, o_crops, o_cats, o_spas, g_imgs,
                g_obj_mask, g_cats, g_spas, targets):
        """Play one batch of games and return the reward and the policy loss.

        For ``self.args["max_turn"]`` turns the question generator produces a
        question, the oracle answers it, and the dialogue is extended. The
        guesser then picks an object; the reward is 1.0 where the pick equals
        ``targets`` and 0.0 otherwise.

        :param q_imgs, q_bbox: question-generator inputs, passed to pg_forward
        :param o_imgs, o_crops, o_cats, o_spas: oracle inputs
        :param g_imgs, g_obj_mask, g_cats, g_spas: guesser inputs
        :param targets: tensor compared against the guesser's argmax
        :return: ``(reward, reward_loss)``
        """
        targets = targets.to(self.device)
        bsz = q_imgs.size(0)
        # Seed history: a question holding only EOS, answered with index 2
        # ('N/A' in self.answer_dict).
        question = torch.ones(bsz, 1, dtype=torch.long,
                              device=self.device) * constants.EOS
        answer = torch.tensor([2] * bsz, dtype=torch.long, device=self.device)
        dials, dial_lens = None, None
        questions, qs_lens, q_indexes, answers, categories = self.prepare_qgen(
            question, answer)
        total_gen_loss = []
        for turn in range(self.args["max_turn"]):
            turns = torch.tensor([q_indexes.size(1)] * bsz,
                                 dtype=torch.long,
                                 device=self.device)
            # qgen (the first value returned by pg_forward is not used here)
            _, gen_loss, category, question = self.qgen_model.pg_forward(
                questions, qs_lens, q_indexes, answers, categories, turns,
                q_imgs, q_bbox)
            # Later turns are down-weighted by 0.9 ** turn.
            gen_loss = torch.sum(gen_loss, dim=1) * (0.9**turn)
            total_gen_loss.append(gen_loss)
            category, question = category.detach(), question.detach()
            # oracle
            q_lens = torch.sum(torch.ne(question, 0) *
                               torch.ne(question, constants.EOS),
                               dim=1).to(dtype=torch.long, device=self.device)
            # Never hand the oracle a zero length; rl_sampling applies the
            # same lower bound of 1 before calling the oracle.
            q_lens = torch.clamp(q_lens, min=1)
            batch = [question, q_lens, o_imgs, o_crops, o_cats, o_spas]
            answer = self.oracle_model(batch)
            answer = torch.argmax(answer, dim=1).detach()
            # guesser: on the first turn dials/dial_lens are still None, which
            # makes prepare_dialogues start a fresh dialogue.
            dials, dial_lens = self.prepare_dialogues(question, answer, dials,
                                                      dial_lens)

            if turn == 0:
                # The seed history is dropped once a real question exists.
                questions, qs_lens, q_indexes, answers, categories = self.prepare_qgen(
                    question, answer, category)
            else:
                questions, qs_lens, q_indexes, answers, categories = self.prepare_qgen(
                    question, answer, category, questions, qs_lens, q_indexes,
                    answers, categories)
        batch = [dials, dial_lens, g_imgs, g_obj_mask, g_cats, g_spas]
        predict = self.guesser_model(batch)
        predict = torch.argmax(predict, dim=1)
        # The loss has no log term, so summing the per-turn losses directly is
        # somewhat problematic, and a plain average is not right either: the
        # reward should weigh more for later turns.
        reward = (predict == targets).to(dtype=torch.float,
                                         device=self.device).detach()
        total_gen_loss = torch.stack(total_gen_loss)
        # NOTE(review): after stacking, dim 0 indexes turns. If each per-turn
        # loss is one value per sample, dim=1 sums over the batch, not over
        # turns - confirm against pg_forward's output shape.
        reward_loss = torch.mean(reward * torch.sum(total_gen_loss, dim=1),
                                 dim=0)
        return reward, reward_loss

    def eval(self):
        """Play every game in ``self.test_loader`` and report the success rate.

        :return: ``(success_rate, failed, success_dials)`` where the two dicts
            map game id to the decoded dialogue of lost and won games
        :raises ZeroDivisionError: if the test loader yields no games
        """
        self.qgen_model.eval()
        failed = {}
        success_dials = {}
        success_num, total_num = 0, 0
        with torch.no_grad():
            for batch in self.test_loader:
                (game_ids, q_imgs, q_bbox, o_imgs, o_crops, o_cats, o_spas,
                 g_imgs, g_obj_mask, g_cats, g_spas, targets) = batch
                bsz = q_imgs.size(0)
                targets = targets.to(self.device)
                # Seed history: a question holding only EOS, answered with
                # index 2 ('N/A' in self.answer_dict).
                question = torch.ones(
                    bsz, 1, dtype=torch.long,
                    device=self.device) * constants.EOS
                answer = torch.tensor([2] * bsz,
                                      dtype=torch.long,
                                      device=self.device)
                dials, dial_lens = self.prepare_dialogues(question, answer)
                questions, qs_lens, q_indexes, answers, categories = self.prepare_qgen(
                    question, answer)
                for turn in range(self.args["max_turn"]):
                    turns = torch.tensor([q_indexes.size(1)] * bsz,
                                         dtype=torch.long,
                                         device=self.device)
                    try:
                        _, _, category, question = self.qgen_model.pg_forward(
                            questions, qs_lens, q_indexes, answers, categories,
                            turns, q_imgs, q_bbox)
                    except Exception as e:
                        # Dump the offending inputs for inspection, then let
                        # the error propagate.
                        print(e)
                        questions = questions.detach().cpu().tolist()
                        qs_lens = qs_lens.detach().cpu().tolist()
                        with open("debug.json", mode="w") as f:
                            json.dump(
                                {
                                    "questions": questions,
                                    "qs_lens": qs_lens
                                }, f)
                        raise
                    q_lens = torch.sum(torch.ne(question, 0) *
                                       torch.ne(question, constants.EOS),
                                       dim=1).to(dtype=torch.long,
                                                 device=self.device)
                    # Never hand the oracle a zero length; rl_sampling applies
                    # the same lower bound of 1 before calling the oracle.
                    q_lens = torch.clamp(q_lens, min=1)
                    oracle_batch = [
                        question, q_lens, o_imgs, o_crops, o_cats, o_spas
                    ]
                    answer = self.oracle_model(oracle_batch)
                    answer = torch.argmax(answer, dim=1)
                    if turn == 0:
                        # The seed dialogue and history are replaced by the
                        # first real question.
                        dials, dial_lens = self.prepare_dialogues(
                            question, answer)
                        questions, qs_lens, q_indexes, answers, categories = self.prepare_qgen(
                            question, answer, category)
                    else:
                        dials, dial_lens = self.prepare_dialogues(
                            question, answer, dials, dial_lens)
                        questions, qs_lens, q_indexes, answers, categories = self.prepare_qgen(
                            question, answer, category, questions, qs_lens,
                            q_indexes, answers, categories)
                guesser_batch = [
                    dials, dial_lens, g_imgs, g_obj_mask, g_cats, g_spas
                ]
                predict = self.guesser_model(guesser_batch)
                predict = torch.argmax(predict, dim=1)
                success = (predict == targets).to(torch.long)
                dials = dials.detach().cpu().tolist()
                dial_lens = dial_lens.detach().cpu().tolist()
                # File each decoded dialogue under won or lost games.
                for i, won in enumerate(success.tolist()):
                    dial = self.tokenizer.decode(dials[i][:dial_lens[i]])
                    bucket = success_dials if won else failed
                    bucket[game_ids[i]] = dial
                success_num += torch.sum(success).item()
                total_num += predict.size(0)

        success_rate = success_num / total_num
        return success_rate, failed, success_dials

    def prepare_qgen(self,
                     question,
                     answers,
                     category=None,
                     pre_questions=None,
                     pre_qs_lens=None,
                     pre_q_indexes=None,
                     pre_answers=None,
                     pre_categories=None):
        """Append the latest question/answer/category to the qgen history.

        Without ``pre_questions`` a fresh history is started from ``question``
        alone. Otherwise the trailing EOS of each stored sequence is
        overwritten by the new question, followed by a new EOS.

        :param question: B * len
        :param answers: B
        :param category: B / B * cate_len / None
        :param pre_questions: B * total_len
        :param pre_qs_lens: B
        :param pre_q_indexes: B * turn
        :param pre_answers: B * turn
        :param pre_categories: B * turn / B * turn * cate_len / None
        :return: (questions, qs_lens, q_indexes, answers, categories) with the
            new turn included
        """
        batch_size = question.size(0)
        multi_cate = self.qgen_args["multi_cate"]
        # Tokens that are neither padding (0) nor EOS count towards the length.
        is_word = torch.ne(question, 0) * torch.ne(question, constants.EOS)
        cur_len = torch.sum(is_word, dim=1).to(dtype=torch.long,
                                               device=self.device)
        new_lens = cur_len.tolist()

        if pre_questions is None:
            # One extra column for the closing EOS.
            width = torch.max(cur_len).item() + 1
            cur_questions = torch.zeros(batch_size,
                                        width,
                                        dtype=torch.long,
                                        device=self.device)
            for row, n_new in enumerate(new_lens):
                cur_questions[row, :n_new] = question[row, :n_new]
                cur_questions[row, n_new] = constants.EOS
            cur_qs_lens = cur_len + 1
            cur_q_indexes = cur_len.unsqueeze(1)
        else:
            width = torch.max(pre_qs_lens + cur_len).item()
            cur_questions = torch.zeros(batch_size,
                                        width,
                                        dtype=torch.long,
                                        device=self.device)
            for row, n_new in enumerate(new_lens):
                # The stored sequence ends with EOS; the new question starts
                # on top of it.
                start = pre_qs_lens[row].item() - 1
                stop = start + n_new
                cur_questions[row, :start] = pre_questions[row, :start]
                cur_questions[row, start:stop] = question[row, :n_new]
                cur_questions[row, stop] = constants.EOS
            cur_qs_lens = pre_qs_lens + cur_len
            latest_index = pre_q_indexes[:, -1] + cur_len
            cur_q_indexes = torch.cat(
                [pre_q_indexes, latest_index.unsqueeze(1)], dim=1)

        cur_answers = answers.unsqueeze(1)
        if pre_answers is not None:
            cur_answers = torch.cat([pre_answers, cur_answers], dim=1)

        if category is None:
            # Placeholder category for a turn that has none: four zeros per
            # sample in multi-category mode, otherwise the index 3.
            if multi_cate:
                category = torch.zeros(batch_size,
                                       4,
                                       dtype=torch.float,
                                       device=self.device)
            else:
                category = torch.full((batch_size, ),
                                      3,
                                      dtype=torch.long,
                                      device=self.device)
        cur_categories = category.unsqueeze(1)
        if pre_categories is not None:
            cur_categories = torch.cat([pre_categories, cur_categories], dim=1)
        return cur_questions, cur_qs_lens, cur_q_indexes, cur_answers, cur_categories

    def prepare_dialogues(self,
                          question,
                          answer,
                          pre_dials=None,
                          pre_dial_lens=None):
        """Append the latest question and its answer token to each dialogue.

        Without ``pre_dials`` a fresh dialogue is started from ``question``.

        :param question: B * len
        :param answer: B
        :param pre_dials: B * total_len
        :param pre_dial_lens: B
        :return: (dials, dial_lens) with the new question/answer included
        """
        batch_size = question.size(0)
        # Tokens that are neither padding (0) nor EOS count towards the length.
        is_word = torch.ne(question, 0) * torch.ne(question, constants.EOS)
        cur_q_len = torch.sum(is_word, dim=1).to(dtype=torch.long,
                                                 device=self.device)
        # Oracle class index -> answer word -> token id(s).
        answer_words = [self.answer_dict[idx] for idx in answer.cpu().tolist()]
        answer_ids = []
        for word in answer_words:
            answer_ids.extend(self.tokenizer.apply(word, is_answer=True))

        has_history = pre_dials is not None
        if has_history:
            content_lens = cur_q_len + pre_dial_lens
        else:
            content_lens = cur_q_len
        # One extra column / position for the answer token.
        dials = torch.zeros(batch_size,
                            torch.max(content_lens).item() + 1,
                            dtype=torch.long,
                            device=self.device)
        dial_lens = content_lens + 1

        for row, q_len in enumerate(cur_q_len.tolist()):
            start = 0
            if has_history:
                start = pre_dial_lens[row].item()
                dials[row, :start] = pre_dials[row, :start]
            dials[row, start:start + q_len] = question[row, :q_len]
            dials[row, start + q_len] = answer_ids[row]
        return dials, dial_lens

    def rl_sampling(self, game_ids, q_imgs, q_bbox, o_imgs, o_crops, o_cats,
                    o_spas, g_imgs, g_obj_mask, g_cats, g_spas, targets):
        """Play one batch of games with the current models and record them.

        For ``self.args["max_turn"]`` turns the question generator produces a
        question, the oracle answers it, and the pair is appended to both the
        running dialogue (guesser input) and the question-generator history.
        After the last turn the guesser picks an object from the complete
        dialogue. Everything runs under ``torch.no_grad()``.

        Returns:
            dict keyed by game id. Each value holds ``"questions"`` and
            ``"answers"`` (as returned by ``split_dial``), the per-turn
            ``"categories"``, and ``"successes"``: 1 when the guesser's
            argmax equals the target, 0 otherwise.
        """
        sampling_dials = {}
        with torch.no_grad():
            targets = targets.to(self.device)
            bsz = q_imgs.size(0)
            long_on_device = {"dtype": torch.long, "device": self.device}

            # Seed history: an EOS-only question paired with answer index 2.
            question = torch.ones(bsz, 1, **long_on_device) * constants.EOS
            answer = torch.tensor([2] * bsz, **long_on_device)
            dials, dial_lens = None, None
            (hist_questions, hist_q_lens, hist_q_indexes, hist_answers,
             hist_categories) = self.prepare_qgen(question, answer)

            for turn in range(self.args["max_turn"]):
                turns = torch.tensor([hist_q_indexes.size(1)] * bsz,
                                     **long_on_device)

                # Question generator: only the last two outputs are used.
                _, _, category, question = self.qgen_model.pg_forward(
                    hist_questions, hist_q_lens, hist_q_indexes, hist_answers,
                    hist_categories, turns, q_imgs, q_bbox)
                category = category.detach()
                question = question.detach()

                # Oracle: question length = tokens that are neither 0 nor EOS.
                counted = torch.ne(question, 0) * torch.ne(
                    question, constants.EOS)
                q_lens = torch.sum(counted, dim=1).to(**long_on_device)
                # Empty questions are bumped to length 1.
                is_empty = (q_lens <= 0).to(**long_on_device)
                q_lens = q_lens + is_empty
                oracle_out = self.oracle_model(
                    [question, q_lens, o_imgs, o_crops, o_cats, o_spas])
                answer = torch.argmax(oracle_out, dim=1).detach()

                # Extend the guesser dialogue first, then the qgen history;
                # the first turn starts both from scratch.
                if turn == 0:
                    dials, dial_lens = self.prepare_dialogues(question, answer)
                    (hist_questions, hist_q_lens, hist_q_indexes,
                     hist_answers, hist_categories) = self.prepare_qgen(
                         question, answer, category)
                else:
                    dials, dial_lens = self.prepare_dialogues(
                        question, answer, dials, dial_lens)
                    (hist_questions, hist_q_lens, hist_q_indexes,
                     hist_answers, hist_categories) = self.prepare_qgen(
                         question, answer, category, hist_questions,
                         hist_q_lens, hist_q_indexes, hist_answers,
                         hist_categories)

            # Guesser
            guesser_out = self.guesser_model(
                [dials, dial_lens, g_imgs, g_obj_mask, g_cats, g_spas])
            predict = torch.argmax(guesser_out, dim=1)
            success = (predict == targets).to(torch.long)

            dial_rows = dials.detach().cpu().tolist()
            dial_sizes = dial_lens.detach().cpu().tolist()
            # B * turn
            category_rows = hist_categories.detach().cpu().tolist()

            for i in range(success.size(0)):
                game_id = game_ids[i]
                # Drop the padding beyond this game's dialogue length.
                trimmed = dial_rows[i][:dial_sizes[i]]
                turn_categories = category_rows[i]
                asked, replied = split_dial(trimmed, self.token_to_answer_idx)
                sampling_dials[game_id] = {
                    "questions": asked,
                    "answers": replied,
                    "successes": success[i].item(),
                    "categories": turn_categories
                }
        return sampling_dials

    def rl_reward_step(self, q_imgs, q_bbox, all_questions, all_answers,
                       successes, all_categories):
        """Replay sampled dialogues and return the reward-weighted losses.

        Each turn's sampled question is fed to
        ``qgen_model.pg_forward_with_target`` as the target, the resulting
        losses are collected per turn, and the history is then extended with
        the sampled question/answer/category.

        Args:
            q_imgs, q_bbox: question-generator inputs; ``q_imgs.size(0)`` is
                the batch size.
            all_questions: per turn, one token list per game.
            all_answers: per turn, one answer index per game.
            successes: per game, ``1`` marks a successful game.
            all_categories: per turn, one category entry per game.

        Returns:
            ``(reward_cls_loss, reward_gen_loss)``. An objective that is
            disabled in ``self.args`` (``"cate_rl"`` / ``"gen_rl"``) yields
            ``0.0``. An enabled one yields the batch mean of
            ``coef_i * sum_j w_j * loss[j, i]``, where ``coef_i`` is 0.9 for
            a success and -0.01 otherwise, and
            ``w_j = sum_{k < max_turn - j} 0.9 ** k``.
        """
        bsz = q_imgs.size(0)
        max_turn = self.args["max_turn"]
        # Seed history: an EOS-only question paired with answer index 2.
        question = torch.ones(bsz, 1, dtype=torch.long,
                              device=self.device) * constants.EOS
        answer = torch.tensor([2] * bsz, dtype=torch.long, device=self.device)
        questions, qs_lens, q_indexes, answers, categories = self.prepare_qgen(
            question, answer)
        total_gen_loss = []
        total_cls_loss = []
        sos = self.tokenizer.start_token
        eos = self.tokenizer.stop_token
        pad = self.tokenizer.padding_token
        for turn in range(max_turn):
            turns = torch.tensor([q_indexes.size(1)] * bsz,
                                 dtype=torch.long,
                                 device=self.device)
            target_cate = all_categories[turn] if self.args["cate_rl"] else None

            # Wrap every sampled question in SOS/EOS and right-pad to the
            # longest question of this turn.
            turn_questions = all_questions[turn]
            max_len = max(len(each) for each in turn_questions)
            targets = torch.tensor(
                [[sos] + target + [eos] + [pad] * (max_len - len(target))
                 for target in turn_questions],
                dtype=torch.long,
                device=self.device)

            cls_loss, gen_loss = self.qgen_model.pg_forward_with_target(
                questions, qs_lens, q_indexes, answers, categories, turns,
                q_imgs, q_bbox, target_cate, targets)
            if isinstance(gen_loss, torch.Tensor):
                gen_loss = torch.sum(gen_loss, dim=1)
            total_gen_loss.append(gen_loss)
            total_cls_loss.append(cls_loss)

            # Extend the history with what was actually sampled this turn
            # (the target without its leading SOS).
            question = targets[:, 1:].detach()
            answer = torch.tensor(all_answers[turn],
                                  dtype=torch.long,
                                  device=self.device)
            cate_dtype = (torch.float
                          if self.qgen_args["multi_cate"] else torch.long)
            category = torch.tensor(all_categories[turn],
                                    dtype=cate_dtype,
                                    device=self.device)
            if turn == 0:
                questions, qs_lens, q_indexes, answers, categories = self.prepare_qgen(
                    question, answer, category)
            else:
                questions, qs_lens, q_indexes, answers, categories = self.prepare_qgen(
                    question, answer, category, questions, qs_lens, q_indexes,
                    answers, categories)

        # The discount weights depend only on the turn index and the reward
        # coefficients only on the game, so both are computed once instead of
        # inside a batch x turn x horizon loop of scalar tensor operations.
        turn_weights = [
            sum(0.9**k for k in range(max_turn - j)) for j in range(max_turn)
        ]
        sample_coefs = [
            0.9 if successes[i] == 1 else -0.01 for i in range(bsz)
        ]

        def _reward_weighted_mean(turn_losses):
            stacked = torch.stack(turn_losses)  # turn * B
            weights = stacked.new_tensor(turn_weights)
            coefs = stacked.new_tensor(sample_coefs)
            # Singleton trailing axes make the 1-D factors broadcast along
            # the leading axis only, whatever trailing axes the loss has.
            per_sample = (weights.view(-1, *([1] * (stacked.dim() - 1))) *
                          stacked).sum(dim=0)
            signed = coefs.view(-1, *([1] *
                                      (per_sample.dim() - 1))) * per_sample
            return signed.sum(dim=0) / bsz

        reward_gen_loss = 0.0
        reward_cls_loss = 0.0
        if self.args["gen_rl"]:
            reward_gen_loss = _reward_weighted_mean(total_gen_loss)
        if self.args["cate_rl"]:
            reward_cls_loss = _reward_weighted_mean(total_cls_loss)
        return reward_cls_loss, reward_gen_loss

    def rl_sample_reward_epoch(self, optimizer):
        """Run one RL epoch: sample dialogues, then update the question generator.

        Phase 1 puts ``self.qgen_model`` in eval mode and plays every batch of
        ``self.train_loader`` through ``self.rl_sampling``, keeping the
        resulting dialogues by game id. Phase 2 switches to train mode,
        iterates the loader again, and for each batch kept by the
        ``sample_rate`` coin flip takes one optimizer step on the losses
        returned by ``self.rl_reward_step``.

        At least one of ``self.args["cate_rl"]`` / ``self.args["gen_rl"]``
        must be enabled; otherwise the accumulated loss stays a plain ``0``
        and ``.backward()`` fails.

        Args:
            optimizer: optimizer whose ``zero_grad()`` / ``step()`` are called
                once per update.

        Returns:
            The reward loss averaged over the updates actually performed
            (``0.0`` when every batch was skipped).
        """
        # Clamp the configured rate into [0, 1].
        sample_rate = max(0, min(1, self.args["sample_rate"]))

        # Sampling phase
        sample_dials = {}
        self.qgen_model.eval()
        for batch in self.train_loader:
            (game_ids, q_imgs, q_bbox, o_imgs, o_crops, o_cats, o_spas, g_imgs,
             g_obj_mask, g_cats, g_spas, targets) = batch
            sample_dials.update(
                self.rl_sampling(game_ids, q_imgs, q_bbox, o_imgs, o_crops,
                                 o_cats, o_spas, g_imgs, g_obj_mask, g_cats,
                                 g_spas, targets))

        # Update phase
        self.qgen_model.train()
        total_reward_loss = 0
        total_gen_loss, last_gen_loss = 0, 0
        total_cls_loss, last_cls_loss = 0, 0
        # `steps` counts completed updates, so it is also the right divisor
        # for the epoch mean (the old 1-based counter over-divided by one).
        last_step, steps = 0, 0
        for batch in self.train_loader:
            if random.random() > sample_rate:
                continue
            game_ids, q_imgs, q_bbox, *_ = batch
            dialogs = [sample_dials[game_id] for game_id in game_ids]
            # Transpose from per-game lists to per-turn tuples.
            all_questions = list(zip(*[d["questions"] for d in dialogs]))
            all_answers = list(zip(*[d["answers"] for d in dialogs]))
            all_categories = list(zip(*[d["categories"] for d in dialogs]))
            successes = [d["successes"] for d in dialogs]

            self.qgen_model.zero_grad()
            optimizer.zero_grad()
            reward_cls_loss, reward_gen_loss = self.rl_reward_step(
                q_imgs, q_bbox, all_questions, all_answers, successes,
                all_categories)
            reward_loss = 0
            if self.args["cate_rl"]:
                reward_loss += reward_cls_loss
                total_cls_loss += reward_cls_loss.item()
            if self.args["gen_rl"]:
                reward_loss += reward_gen_loss
                total_gen_loss += reward_gen_loss.item()
            reward_loss.backward()
            optimizer.step()
            total_reward_loss += reward_loss.item()

            steps += 1
            if steps % self.args["log_step"] == 0:
                # Report the mean over the updates since the previous log.
                window = steps - last_step
                gen_loss = (total_gen_loss - last_gen_loss) / window
                cls_loss = (total_cls_loss - last_cls_loss) / window
                self.logger.info(
                    "Step {}, Gen Loss {:.4f}, Cls Loss {:.4f}".format(
                        steps, gen_loss, cls_loss))
                last_step = steps
                last_gen_loss = total_gen_loss
                last_cls_loss = total_cls_loss
        # Guard the divisor for the case where every batch was skipped.
        return total_reward_loss / max(steps, 1)
예제 #9
0
import torch
import os
import json
from utils import constants
from models.guesser.baseline_model import GuesserNetwork
from process_data.tokenizer import GWTokenizer
from data_provider.gw_dataset import ImageProvider
from data_provider.loop_rl_dataset import get_bbox

# Names of the trained runs to load; each selects a sub-directory under
# ../out/qgen and ../out/guesser respectively.
qgen_name = "cat_v_rcnn_cls_prior_cate"
guesser_name = "baseline"
data_dir = "./../data"

# GPU auto-selection is disabled below; everything is pinned to the CPU.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
tokenizer = GWTokenizer("./../data/dict.json")
# Load the argument dicts stored as args.json in each run directory.
with open(os.path.join("../out/qgen", qgen_name, "args.json")) as f:
    qgen_args = json.load(f)
with open(os.path.join("../out/guesser", guesser_name, "args.json")) as f:
    guesser_args = json.load(f)
# Build the guesser from its saved arguments, load its saved parameters onto
# `device`, and switch it to evaluation mode.
guesser_model = GuesserNetwork(guesser_args, tokenizer, device).to(device)
guesser_model.load_state_dict(
    torch.load(os.path.join("../out/guesser", guesser_name, "params.pth.tar"),
               map_location=device))
guesser_model.eval()

# Answer index -> display string.
answer_dict = {0: 'Yes', 1: 'No', 2: 'N/A'}
token_to_answer_idx = {
    tokenizer.yes_token: 0,
    tokenizer.no_token: 1,
    tokenizer.non_applicable_token: 2
예제 #10
0
import os
import json
from PIL import Image, ImageDraw, ImageFont
from process_data.data_preprocess import Game, get_games
from utils.draw import mp_show
from process_data.tokenizer import GWTokenizer
from process_data.image_process import get_transform
from process_data.data_preprocess import get_games
# Locations are relative paths, so they resolve against the current working
# directory at run time.
data_dir = "./../data"
image_dir = data_dir + "/imgs"
dict_file = os.path.join(data_dir, "dict.json")
# Tokenizer built from the dictionary file in the data directory.
tokenizer = GWTokenizer(dict_file)
out_dir = "./../out/"
model_dir = "./models/"

# NOTE(review): presumably the image preprocessing pipeline for a 224x224
# input size; confirm against process_data.image_process.get_transform.
transform = get_transform((224, 224))


def plot_game(game, model_name=None):
    print(game.id)
    filename = os.path.join(image_dir, game.img.filename)
    im = Image.open(filename).convert('RGB')
    qas = len(game.questions)
    w, h = im.size
    image = Image.new('RGB', (w, h+60+qas*30), (255, 255, 255))
    image.paste(im)
    # draw实体
    draw = ImageDraw.Draw(image)
    # font实体
    font = ImageFont.truetype('Arial.ttf', 30)
    # 圈出target object, 蓝色