class ModelOperator:
    def __init__(self, args):
        """Prepare directories, data, vocabulary, model and optimizer for a run.

        Args:
            args: parsed argument namespace. This constructor reads
                ``experiment_dir``, ``run_name``, ``old_model_dir`` and
                ``dataset_filename`` from it and also hands the whole
                namespace to ``ModelConfig``.
        """
        # set up output directories; makedirs creates missing parents and
        # does not fail when the directory already exists
        self.output_dir = os.path.join(args.experiment_dir, args.run_name)
        os.makedirs(self.output_dir, exist_ok=True)
        os.makedirs(os.path.join(args.experiment_dir, "runs/"), exist_ok=True)

        # initialize tensorboard writer
        self.runs_dir = os.path.join(args.experiment_dir, "runs/", args.run_name)
        self.writer = SummaryWriter(self.runs_dir)

        # initialize global steps
        self.train_gs = 0
        self.val_gs = 0

        # initialize model config
        self.config = ModelConfig(args)

        # check if there is a model to load
        if args.old_model_dir is not None:
            self.use_old_model = True
            self.load_dir = args.old_model_dir
            self.config.load_from_file(
                os.path.join(self.load_dir, "config.json"))

            # reuse the saved vocab unchanged; min_count=1 disables the
            # pruning step further down
            self.vocab = Vocab()
            self.vocab.load_from_dict(os.path.join(self.load_dir, "vocab.json"))
            self.update_vocab = False
            self.config.min_count = 1
        else:
            self.use_old_model = False

            # no vocab yet: the datasets are asked to update it
            self.vocab = None
            self.update_vocab = True

        # create data sets
        self.dataset_filename = args.dataset_filename

        # train
        self.train_dataset = DialogueDataset(
            os.path.join(self.dataset_filename, "train.csv"),
            self.config.history_len,
            self.config.response_len,
            self.vocab,
            self.update_vocab)
        self.data_loader_train = torch.utils.data.DataLoader(
            self.train_dataset, self.config.train_batch_size, shuffle=True)
        self.config.train_len = len(self.train_dataset)

        # the validation set continues from the training set's vocab
        self.vocab = self.train_dataset.vocab

        # eval
        self.val_dataset = DialogueDataset(
            os.path.join(self.dataset_filename, "val.csv"),
            self.config.history_len,
            self.config.response_len,
            self.vocab,
            self.update_vocab)
        self.data_loader_val = torch.utils.data.DataLoader(
            self.val_dataset, self.config.val_batch_size, shuffle=True)
        self.config.val_len = len(self.val_dataset)

        # update, and save vocab; both datasets end up sharing one object
        self.vocab = self.val_dataset.vocab
        self.train_dataset.vocab = self.vocab
        if self.config.min_count > 1:
            self.config.old_vocab_size = len(self.vocab)
            self.vocab.prune_vocab(self.config.min_count)
        self.vocab.save_to_dict(os.path.join(self.output_dir, "vocab.json"))
        self.vocab_size = len(self.vocab)
        self.config.vocab_size = self.vocab_size

        # print and save the config file
        self.config.print_config(self.writer)
        self.config.save_config(os.path.join(self.output_dir, "config.json"))

        # set device; fall back to the CPU instead of crashing when CUDA is
        # not available on this machine
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # create model
        self.model = Transformer(
            self.config.vocab_size,
            self.config.vocab_size,
            self.config.history_len,
            self.config.response_len,
            d_word_vec=self.config.embedding_dim,
            d_model=self.config.model_dim,
            d_inner=self.config.inner_dim,
            n_layers=self.config.num_layers,
            n_head=self.config.num_heads,
            d_k=self.config.dim_k,
            d_v=self.config.dim_v,
            dropout=self.config.dropout
        ).to(self.device)

        # create optimizer over the trainable parameters only
        self.optimizer = torch.optim.Adam(
            filter(lambda x: x.requires_grad, self.model.parameters()),
            betas=(0.9, 0.98), eps=1e-09)

        # load old model, optimizer if there is one
        if self.use_old_model:
            self.model, self.optimizer = load_checkpoint(
                os.path.join(self.load_dir, "model.bin"),
                self.model, self.optimizer, self.device)

        # create a scheduled optimizer object wrapping the raw optimizer
        self.optimizer = ScheduledOptim(
            self.optimizer, self.config.model_dim, self.config.warmup_steps)


    def train(self, num_epochs):
        metrics = {"best_epoch":0, "lowest_loss":99999999999999}

        # output an example
        #self.output_example(0)

        for epoch in range(num_epochs):
           # self.writer.add_graph(self.model)
            #self.writer.add_embedding(
            #    self.model.encoder.src_word_emb.weight, global_step=epoch)

            epoch_metrics = dict()

            # train
            epoch_metrics["train"] = self.execute_phase(epoch, "train")
            # save metrics
            metrics["epoch_{}".format(epoch)] = epoch_metrics
            with open(os.path.join(self.output_dir, "metrics.json"), "w") as f:
                json.dump(metrics, f, indent=4)

            # validate
            epoch_metrics["val"] = self.execute_phase(epoch, "val")
            # save metrics
            metrics["epoch_{}".format(epoch)] = epoch_metrics
            with open(os.path.join(self.output_dir, "metrics.json"), "w") as f:
                json.dump(metrics, f, indent=4)

            # save checkpoint
            #TODO: fix this b
            #if epoch_metrics["val"]["loss"] < metrics["lowest_loss"]:
            #if epoch_metrics["train"]["loss"] < metrics["lowest_loss"]:
            if epoch % 100 == 0:
                self.save_checkpoint(os.path.join(self.output_dir, "model_{}.bin".format(epoch)))
                #metrics["lowest_loss"] = epoch_metrics["train"]["loss"]
                #metrics["best_epoch"] = epoch

            # record metrics to tensorboard
            self.writer.add_scalar("training loss total",
                epoch_metrics["train"]["loss"], global_step=epoch)
            self.writer.add_scalar("val loss total",
                epoch_metrics["val"]["loss"], global_step=epoch)

            self.writer.add_scalar("training perplexity",
                epoch_metrics["train"]["perplexity"], global_step=epoch)
            self.writer.add_scalar("val perplexity",
                epoch_metrics["val"]["perplexity"], global_step=epoch)

            self.writer.add_scalar("training time",
                epoch_metrics["train"]["time_taken"], global_step=epoch)
            self.writer.add_scalar("val time",
                epoch_metrics["val"]["time_taken"], global_step=epoch)

            self.writer.add_scalar("train_bleu_1",
                epoch_metrics["train"]["bleu_1"], global_step=epoch)
            self.writer.add_scalar("val_bleu_1",
                epoch_metrics["val"]["bleu_1"], global_step=epoch)
            self.writer.add_scalar("train_bleu_2",
                epoch_metrics["train"]["bleu_2"], global_step=epoch)
            self.writer.add_scalar("val_bleu_2",
                epoch_metrics["val"]["bleu_2"], global_step=epoch)

            # output an example
            #self.output_example(epoch+1)

        self.writer.close()

    def execute_phase(self, epoch, phase):
        """Run one pass over the training or validation data.

        Args:
            epoch: zero-based epoch index; used to place the per-batch
                training loss on the tensorboard step axis.
            phase: ``"train"`` optimizes on the training loader; any other
                value evaluates on the validation loader.

        Returns:
            dict with keys ``loss``, ``token_accuracy``, ``perplexity``,
            ``bleu_1``, ``bleu_2`` and ``time_taken``.

        Raises:
            ValueError: if the selected data loader yields no batches.
        """
        train = phase == "train"
        if train:
            self.model.train()
            dataloader = self.data_loader_train
        else:
            self.model.eval()
            dataloader = self.data_loader_val

        # time.clock() was removed in Python 3.8; perf_counter() reports
        # elapsed wall-clock time
        start = time.perf_counter()
        phase_metrics = dict()
        epoch_loss = list()
        epoch_bleu_1 = list()
        epoch_bleu_2 = list()
        n_word_total = 0
        n_word_correct = 0
        batches_per_epoch = len(dataloader)
        for i, batch in enumerate(tqdm(dataloader,
                          mininterval=2, desc=phase, leave=False)):
            # prepare data
            src_seq, src_pos, src_seg, tgt_seq, tgt_pos = map(
                lambda x: x.to(self.device), batch)

            # the target is compared from its second position onwards
            gold = tgt_seq[:, 1:]

            # forward; only track gradients when we will backpropagate
            if train:
                self.optimizer.zero_grad()
            with torch.set_grad_enabled(train):
                pred = self.model(src_seq, src_pos, src_seg, tgt_seq, tgt_pos)

                # get loss
                loss, n_correct = cal_performance(pred, gold,
                    smoothing=self.config.label_smoothing)
            average_loss = float(loss)
            epoch_loss.append(average_loss)

            if train:
                # one step per batch: offset by the number of batches in an
                # epoch so steps from different epochs never collide
                self.writer.add_scalar("train_loss",
                    average_loss, global_step=i + epoch * batches_per_epoch)
                # backward
                loss.backward()

                # update parameters
                self.optimizer.step_and_update_lr()

            # get_bleu
            output = torch.argmax(pred.view(-1, self.config.response_len-1, self.vocab_size), dim=2)
            epoch_bleu_1.append(bleu(gold, output, 1))
            epoch_bleu_2.append(bleu(gold, output, 2))

            # get_accuracy: count non-padding target tokens and correct ones
            non_pad_mask = gold.ne(src.transformer.Constants.PAD)
            n_word_total += non_pad_mask.sum().item()
            n_word_correct += n_correct

        if not epoch_loss:
            raise ValueError(
                "no batches available for phase '{}'".format(phase))

        average_epoch_loss = np.mean(epoch_loss)
        phase_metrics["loss"] = average_epoch_loss
        # accuracy over the whole phase, not just the last batch
        phase_metrics["token_accuracy"] = (
            n_word_correct / n_word_total if n_word_total else 0.0)

        phase_metrics["perplexity"] = np.exp(average_epoch_loss)

        phase_metrics["bleu_1"] = np.mean(epoch_bleu_1)
        phase_metrics["bleu_2"] = np.mean(epoch_bleu_2)

        phase_metrics["time_taken"] = time.perf_counter() - start
        string = ' {} loss: {:.3f} '.format(phase, average_epoch_loss)
        print(string, end='\n')
        return phase_metrics

    def save_checkpoint(self, filename):
        state = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.optimizer.state_dict()
        }
        torch.save(state, filename)

    def output_example(self, epoch):
        random_index = random.randint(0, len(self.val_dataset))
        example = self.val_dataset[random_index]

        # prepare data
        src_seq, src_pos, src_seg, tgt_seq, tgt_pos = map(
            lambda x: torch.from_numpy(x).to(self.device).unsqueeze(0), example)

        # take out first token from target for some reason
        gold = tgt_seq[:, 1:]

        # forward
        pred = self.model(src_seq, src_pos, src_seg, tgt_seq, tgt_pos)
        output = torch.argmax(pred, dim=1)

        # get history text
        string = "history: "

        seg = -1
        for i, idx in enumerate(src_seg.squeeze()):
            if seg != idx.item():
                string+="\n"
                seg=idx.item()
            token = self.vocab.id2token[src_seq.squeeze()[i].item()]
            if token != '<blank>':
                string += "{} ".format(token)

        # get target text
        string += "\nTarget:\n"

        for idx in tgt_seq.squeeze():
            token = self.vocab.id2token[idx.item()]
            string += "{} ".format(token)

        # get prediction
        string += "\n\nPrediction:\n"

        for idx in output:
            token = self.vocab.id2token[idx.item()]
            string += "{} ".format(token)

        # print
        print("\n------------------------\n")
        print(string)
        print("\n------------------------\n")

        # add result to tensorboard
        self.writer.add_text("example_output", string, global_step=epoch)
        self.writer.add_histogram("example_vocab_ranking", pred, global_step=epoch)
        self.writer.add_histogram("example_vocab_choice", output,global_step=epoch)
class ModelOperator:
    def __init__(self, args):
        """Prepare directories, data, vocabulary, model and optimizer for a run.

        Args:
            args: parsed argument namespace. This constructor reads
                ``experiment_dir``, ``run_name``, ``old_model_dir`` and
                ``dataset_filename`` from it and also hands the whole
                namespace to ``ModelConfig``.
        """
        # set up output directories; makedirs creates missing parents and
        # does not fail when the directory already exists
        self.output_dir = os.path.join(args.experiment_dir, args.run_name)
        os.makedirs(self.output_dir, exist_ok=True)
        os.makedirs(os.path.join(args.experiment_dir, "runs/"), exist_ok=True)

        # initialize tensorboard writer
        self.runs_dir = os.path.join(args.experiment_dir, "runs/", args.run_name)
        self.writer = SummaryWriter(self.runs_dir)

        # initialize global steps
        self.train_gs = 0
        self.val_gs = 0

        # initialize model config
        self.config = ModelConfig(args)

        # check if there is a model to load
        if args.old_model_dir is not None:
            self.use_old_model = True
            self.load_dir = args.old_model_dir
            self.config.load_from_file(
                os.path.join(self.load_dir, "config.json"))

            # reuse the saved vocab unchanged; min_count=1 disables the
            # pruning step further down
            self.vocab = Vocab()
            self.vocab.load_from_dict(os.path.join(self.load_dir, "vocab.json"))
            self.update_vocab = False
            self.config.min_count = 1
        else:
            self.use_old_model = False

            # no vocab yet: the datasets are asked to update it
            self.vocab = None
            self.update_vocab = True

        # create data sets
        self.dataset_filename = args.dataset_filename

        # train
        self.train_dataset = DialogueDataset(
            os.path.join(self.dataset_filename, "train_data.json"),
            self.config.sentence_len,
            self.vocab,
            self.update_vocab)
        self.data_loader_train = torch.utils.data.DataLoader(
            self.train_dataset, self.config.train_batch_size, shuffle=True)
        self.config.train_len = len(self.train_dataset)

        # the validation set continues from the training set's vocab
        self.vocab = self.train_dataset.vocab

        # eval
        self.val_dataset = DialogueDataset(
            os.path.join(self.dataset_filename, "val_data.json"),
            self.config.sentence_len,
            self.vocab,
            self.update_vocab)
        self.data_loader_val = torch.utils.data.DataLoader(
            self.val_dataset, self.config.val_batch_size, shuffle=True)
        self.config.val_len = len(self.val_dataset)

        # update, and save vocab; both datasets end up sharing one object
        self.vocab = self.val_dataset.vocab
        self.train_dataset.vocab = self.vocab
        if self.config.min_count > 1:
            self.config.old_vocab_size = len(self.vocab)
            self.vocab.prune_vocab(self.config.min_count)
        self.vocab.save_to_dict(os.path.join(self.output_dir, "vocab.json"))
        self.vocab_size = len(self.vocab)
        self.config.vocab_size = self.vocab_size

        # load embeddings only when a directory was configured; the original
        # check was inverted and tried to load from a missing directory
        if self.config.pretrained_embeddings_dir is not None:
            pretrained_embeddings = get_pretrained_embeddings(
                self.config.pretrained_embeddings_dir, self.vocab)
        else:
            pretrained_embeddings = None

        # print and save the config file
        self.config.print_config(self.writer)
        self.config.save_config(os.path.join(self.output_dir, "config.json"))

        # set device; fall back to the CPU instead of crashing when CUDA is
        # not available on this machine
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # create model
        self.model = Transformer(
            self.config.vocab_size,
            self.config.label_len,
            self.config.sentence_len,
            d_word_vec=self.config.embedding_dim,
            d_model=self.config.model_dim,
            d_inner=self.config.inner_dim,
            n_layers=self.config.num_layers,
            n_head=self.config.num_heads,
            d_k=self.config.dim_k,
            d_v=self.config.dim_v,
            dropout=self.config.dropout,
            pretrained_embeddings=pretrained_embeddings
        ).to(self.device)

        # create optimizer over the trainable parameters only
        self.optimizer = torch.optim.Adam(
            filter(lambda x: x.requires_grad, self.model.parameters()),
            betas=(0.9, 0.98), eps=1e-09)

        # load old model, optimizer if there is one
        if self.use_old_model:
            self.model, self.optimizer = load_checkpoint(
                os.path.join(self.load_dir, "model.bin"),
                self.model, self.optimizer, self.device)

        # create a scheduled optimizer object wrapping the raw optimizer
        self.optimizer = ScheduledOptim(
            self.optimizer, self.config.model_dim, self.config.warmup_steps)


    def train(self, num_epochs):
        metrics = {"best_epoch":0, "highest_f1":0}

        # output an example
        self.output_example(0)

        for epoch in range(num_epochs):
            #self.writer.add_graph(self.model)
            #self.writer.add_embedding(
            #    self.model.encoder.src_word_emb.weight, global_step=epoch)

            epoch_metrics = dict()

            # train
            epoch_metrics["train"] = self.execute_phase(epoch, "train")
            # save metrics
            metrics["epoch_{}".format(epoch)] = epoch_metrics
            with open(os.path.join(self.output_dir, "metrics.json"), "w") as f:
                json.dump(metrics, f, indent=4)

            # validate
            epoch_metrics["val"] = self.execute_phase(epoch, "val")
            # save metrics
            metrics["epoch_{}".format(epoch)] = epoch_metrics
            with open(os.path.join(self.output_dir, "metrics.json"), "w") as f:
                json.dump(metrics, f, indent=4)

            # save checkpoint
            #TODO: fix this b
            if epoch_metrics["val"]["avg_results"]["F1"] > metrics["highest_f1"]:
            #if epoch_metrics["train"]["loss"] < metrics["lowest_loss"]:
            #if epoch % 100 == 0:
                self.save_checkpoint(os.path.join(self.output_dir, "model.bin"))
                metrics["lowest_f1"] = epoch_metrics["val"]["avg_results"]["F1"]
                metrics["best_epoch"] = epoch

                test_results = self.get_test_predictions(
                    os.path.join(self.dataset_filename, "test_data.json"),
                    os.path.join(self.output_dir, "predictions{}.json".format(epoch)))

            # record metrics to tensorboard
            self.writer.add_scalar("training loss total",
                epoch_metrics["train"]["loss"], global_step=epoch)
            self.writer.add_scalar("val loss total",
                epoch_metrics["val"]["loss"], global_step=epoch)


            self.writer.add_scalar("training time",
                epoch_metrics["train"]["time_taken"], global_step=epoch)
            self.writer.add_scalar("val time",
                epoch_metrics["val"]["time_taken"], global_step=epoch)

            self.writer.add_scalars("train_results", epoch_metrics["train"]["avg_results"], global_step=epoch)
            self.writer.add_scalars("val_results", epoch_metrics["val"]["avg_results"],
                                    global_step=epoch)
            # output an example
            self.output_example(epoch+1)

        self.writer.close()

    def execute_phase(self, epoch, phase):
        """Run one pass over the training or validation data.

        Args:
            epoch: zero-based epoch index; used to place the per-batch
                training loss on the tensorboard step axis.
            phase: ``"train"`` optimizes on the training loader; any other
                value evaluates on the validation loader.

        Returns:
            dict with keys ``avg_results`` (mean of every list collected in
            the results dict), ``loss`` and ``time_taken``.

        Raises:
            ValueError: if the selected data loader yields no batches.
        """
        train = phase == "train"
        if train:
            self.model.train()
            dataloader = self.data_loader_train
        else:
            self.model.eval()
            dataloader = self.data_loader_val

        # time.clock() was removed in Python 3.8; perf_counter() reports
        # elapsed wall-clock time
        start = time.perf_counter()
        phase_metrics = dict()
        epoch_loss = list()
        results = {"accuracy": list(), "precision": list(), "recall": list(), "F1": list()}

        batches_per_epoch = len(dataloader)
        for i, batch in enumerate(tqdm(dataloader,
                          mininterval=2, desc=phase, leave=False)):
            # prepare data; only the first four batch entries are tensors
            # the model consumes
            src_seq, src_pos, src_seg, tgt = map(
                lambda x: x.to(self.device), batch[:4])

            # forward; only track gradients when we will backpropagate
            if train:
                self.optimizer.zero_grad()
            with torch.set_grad_enabled(train):
                pred = self.model(src_seq, src_pos, src_seg, tgt)
                # two-class view of the prediction, shared by loss and argmax
                two_class = self.prepare_pred(pred)
                loss = F.cross_entropy(two_class.view(-1, 2), tgt.view(-1))

            average_loss = float(loss)
            epoch_loss.append(average_loss)

            if train:
                # one step per batch: offset by the number of batches in an
                # epoch so steps from different epochs never collide
                self.writer.add_scalar("train_loss",
                    average_loss, global_step=i + epoch * batches_per_epoch)
                # backward
                loss.backward()

                # update parameters
                self.optimizer.step_and_update_lr()

            output = torch.argmax(two_class, 3)
            get_results(tgt.view(-1).cpu(), output.view(-1).cpu(), results)

        if not epoch_loss:
            raise ValueError(
                "no batches available for phase '{}'".format(phase))

        average_epoch_loss = np.mean(epoch_loss)
        phase_metrics["avg_results"] = {key: np.mean(value) for key, value in results.items()}
        phase_metrics["loss"] = average_epoch_loss

        phase_metrics["time_taken"] = time.perf_counter() - start
        string = ' {} loss: {:.3f} '.format(phase, average_epoch_loss)
        print(string, end='\n')
        return phase_metrics

    def get_test_predictions(self, test_filename, save_filename):
        """Evaluate on a test file and save the data annotated with predictions.

        The test file is loaded twice: once as a ``DialogueDataset`` to feed
        the model, and once as raw JSON that ``record_predictions`` is given
        together with the model output. The evaluation metrics are stored
        under the ``"results"`` key of that JSON before it is written out.

        Args:
            test_filename: path of the JSON test data.
            save_filename: path the annotated JSON is written to.

        Returns:
            dict with keys ``avg_results``, ``loss`` and ``time_taken``.

        Raises:
            ValueError: if the test data loader yields no batches.
        """
        test_dataset = DialogueDataset(
            test_filename,
            self.config.sentence_len,
            self.vocab,
            False)

        test_data_loader = torch.utils.data.DataLoader(
            test_dataset, self.config.val_batch_size, shuffle=True)

        with open(test_filename, 'r') as f:
            data = json.load(f)

        # time.clock() was removed in Python 3.8; perf_counter() reports
        # elapsed wall-clock time
        start = time.perf_counter()
        phase_metrics = dict()
        epoch_loss = list()
        results = {"accuracy": list(), "precision": list(), "recall": list(),
                   "F1": list()}

        # pure inference: evaluation mode, no gradient tracking
        self.model.eval()
        with torch.no_grad():
            for batch in tqdm(test_data_loader,
                              mininterval=2, desc='test', leave=False):
                # prepare data
                src_seq, src_pos, src_seg, tgt = map(
                    lambda x: x.to(self.device), batch[:4])

                # remaining batch entries are passed on to
                # record_predictions untouched
                ids = batch[4]
                start_end_idx = batch[5]

                # forward
                pred = self.model(src_seq, src_pos, src_seg, tgt)
                two_class = self.prepare_pred(pred)

                loss = F.cross_entropy(two_class.view(-1, 2), tgt.view(-1))
                epoch_loss.append(float(loss))

                output = torch.argmax(two_class, 3)

                record_predictions(output, data, ids, start_end_idx)

                get_results(tgt.view(-1).cpu(), output.view(-1).cpu(), results)

        if not epoch_loss:
            raise ValueError(
                "no batches available in '{}'".format(test_filename))

        average_epoch_loss = np.mean(epoch_loss)
        phase_metrics["avg_results"] = {key: np.mean(value) for key, value in
                                        results.items()}
        phase_metrics["loss"] = average_epoch_loss

        phase_metrics["time_taken"] = time.perf_counter() - start
        string = ' {} loss: {:.3f} '.format('test', average_epoch_loss)
        print(string, end='\n')

        data["results"] = phase_metrics

        with open(save_filename, 'w') as f:
            json.dump(data, f)

        return phase_metrics



    def save_checkpoint(self, filename):
        state = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.optimizer.state_dict()
        }
        torch.save(state, filename)

    def output_example(self, epoch):
        """Run the model on one random validation example and log the result.

        For every non-blank input token a line ``[token: output - target]``
        is produced; the text is printed and written to tensorboard, along
        with histograms of the raw prediction and of the two-class output.

        Args:
            epoch: step value attached to the tensorboard entries.
        """
        # randrange excludes the upper bound; randint(0, len) could pick an
        # index one past the end of the dataset
        random_index = random.randrange(len(self.val_dataset))
        example = self.val_dataset[random_index]

        # prepare data: add a batch dimension of one to the first four arrays
        src_seq, src_pos, src_seg, tgt_seq = map(
            lambda x: torch.from_numpy(x).to(self.device).unsqueeze(0), example[:4])

        # forward; this is inspection only, so no gradients are needed
        with torch.no_grad():
            pred = self.model(src_seq, src_pos, src_seg, tgt_seq)
            output = self.prepare_pred(pred).squeeze(0)

        words = src_seq.tolist()[0]
        target_strings = labels_2_mention_str(tgt_seq.squeeze(0))
        output_strings = labels_2_mention_str(torch.argmax(output, dim=2))

        # one line per non-blank input token
        lines = ["word: output - target\n"]
        for word, t, o in zip(words, target_strings, output_strings):
            token = self.vocab.id2token[word]
            if token != "<blank>":
                lines.append("[{}: {} - {}], \n".format(token, o, t))
        string = "".join(lines)

        # print
        print("\n------------------------\n")
        print(string)
        print("\n------------------------\n")

        # add result to tensorboard
        self.writer.add_text("example_output", string, global_step=epoch)
        self.writer.add_histogram("example_vocab_ranking", pred, global_step=epoch)
        self.writer.add_histogram("example_vocab_choice", output, global_step=epoch)

    def prepare_pred(self, pred):
        temp = pred
        pred = pred.view(-1)
        size = pred.size()
        nullclass = torch.ones(size, dtype=pred.dtype, device=self.device)
        nullclass -= pred
        pred = torch.stack((nullclass, pred), 1).view(-1,
                                                       self.config.sentence_len,
                                                       self.config.label_len,
                                                       2)
        return pred
class ChatBot:
    def __init__(self, args):
        """Load vocab, config and weights of a pre-trained model.

        Args:
            args: argument namespace. Reads ``experiment_dir``,
                ``old_model_dir``, ``device`` and ``old_model_name``;
                ``response_len`` and ``history_len`` are overwritten with the
                values stored in the model's ``config.json``.
        """
        # directory holding the pre-trained model files
        model_dir = os.path.join(args.experiment_dir, args.old_model_dir)

        # vocabulary saved next to the model
        self.vocab = Vocab()
        self.vocab.load_from_dict(os.path.join(model_dir, "vocab.json"))

        # saved hyper-parameters
        config_path = os.path.join(model_dir, "config.json")
        with open(config_path, "r") as f:
            config = json.load(f)

        # sequence lengths always come from the saved config
        args.response_len = config["response_len"]
        args.history_len = config["history_len"]

        # an empty dataset, only used to turn text into input features
        self.dataset = DialogueDataset(
            None,
            history_len=config["history_len"],
            response_len=config["response_len"],
            vocab=self.vocab,
            update_vocab=False)

        self.device = torch.device(args.device)

        # rebuild the network with the saved hyper-parameters
        vocab_size = config["vocab_size"]
        network = Transformer(
            vocab_size,
            vocab_size,
            config["history_len"],
            config["response_len"],
            d_word_vec=config["embedding_dim"],
            d_model=config["model_dim"],
            d_inner=config["inner_dim"],
            n_layers=config["num_layers"],
            n_head=config["num_heads"],
            d_k=config["dim_k"],
            d_v=config["dim_v"],
            dropout=config["dropout"],
            pretrained_embeddings=None)
        network = network.to(self.device)

        # restore the trained weights
        checkpoint_path = os.path.join(model_dir, args.old_model_name)
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        network.load_state_dict(checkpoint['model'])

        # wrap the network in the response generator
        self.chatbot = Chatbot(args, network)

        self.args = args

    def run(self):
        """Register the Telegram handlers that drive the conversation.

        Text messages are answered with a generated response; the ``/reset``
        command archives the current conversation and starts a new one.
        After every response the full history and the arguments are written
        to ``self.args.save_filename``.

        NOTE(review): nothing in this method starts polling; presumably
        ``TelegramBot`` takes care of that -- confirm.
        """
        logging.basicConfig(
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            level=logging.INFO)

        greeting_text = "Hello! I am the Hawkbot! Let me tell you about myself" \
            "... please dont hurt my feelings!" \
                        "  If you would like to reset "\
                        "the conversation, please type '/reset'. "

        # history per chat id; archived conversations live in the same dict
        # under the key "<chat_id>_history"
        history = dict()

        def greeting(bot, update):
            chat_id = update.message.chat_id

            # archive a copy of the running conversation: storing the live
            # list and then clearing it would empty the archived entry too
            if chat_id in history:
                archive_key = "{}_history".format(chat_id)
                history.setdefault(archive_key, []).append(
                    list(history[chat_id]))

            # start with a fresh, empty conversation
            history[chat_id] = list()

            # send a message
            bot.send_message(chat_id, greeting_text)

        def respond(bot, update):
            chat_id = update.message.chat_id

            # first contact: greet instead of answering
            if chat_id not in history:
                greeting(bot, update)
                return

            # get message, and add to history
            history[chat_id].append(update.message.text)
            # get response, and add to history
            response = self._print_response(history[chat_id])
            history[chat_id].append(response)
            # send response to user
            bot.send_message(chat_id, clean_response(response))

            # persist everything after each exchange
            with open(self.args.save_filename, 'w') as f:
                json.dump({
                    "history": history,
                    "args": vars(self.args)
                },
                          f,
                          indent=4)

        # queries sent to: https://api.telegram.org/bot<token>/METHOD_NAME
        TOKEN = self.args.token

        bot = TelegramBot(TOKEN)
        bot.add_handler(MessageHandler(Filters.text, respond))
        bot.add_handler(CommandHandler('reset', greeting))

    # print the response from the input
    def _print_response(self, history):

        # generate responses
        responses, scores = self._generate_responses(history)
        # chose response
        if self.args.choose_best:
            response = responses[0][0]
        else:
            # pick a random result from the n_best
            idx = random.randint(
                0,
                min(self.args.n_best, self.args.beam_size) - 1)
            response = responses[0][idx]

        # uncomment this line to see all the scores
        # print("scores in log prob: {}\n".format(scores[0]))

        # create output string
        output = ""
        for idx in response[:-1]:
            token = self.vocab.id2token[idx]
            output += "{} ".format(token)
        print(f'{history[-1]} -> {output}')
        return output

    def _generate_responses(self, history):
        # get input features for the dialogue history
        h_seq, h_pos, h_seg = self.dataset.get_input_features(history)

        # get response from model
        response = self.chatbot.translate_batch(h_seq, h_pos, h_seg)
        return response