Code Example #1
File: test_dict.py  Project: ahiroto/ParlAI
    def test_basic_parse(self):
        """Check that the dictionary is correctly adding and parsing short
        sentence.
        """
        from parlai.core.dict import DictionaryAgent
        from parlai.core.params import ParlaiParser

        argparser = ParlaiParser()
        DictionaryAgent.add_cmdline_args(argparser)
        opt = argparser.parse_args()
        dictionary = DictionaryAgent(opt)
        num_builtin = len(dictionary)

        dictionary.observe({'text': 'hello world'})
        dictionary.act()
        assert len(dictionary) - num_builtin == 2

        vec = dictionary.parse('hello world')
        assert len(vec) == 2
        assert vec[0] == num_builtin
        assert vec[1] == num_builtin + 1

        vec = dictionary.parse('hello world', vec_type=list)
        assert len(vec) == 2
        assert vec[0] == num_builtin
        assert vec[1] == num_builtin + 1

        vec = dictionary.parse('hello world', vec_type=tuple)
        assert len(vec) == 2
        assert vec[0] == num_builtin
        assert vec[1] == num_builtin + 1
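For reference, the same dictionary round trip can be written outside a test class. This is only a sketch built from calls that appear elsewhere in this listing (txt2vec/vec2txt are used by the agents below); the parser boilerplate mirrors the tests and is version-dependent, as the next two examples show:

    from parlai.core.dict import DictionaryAgent
    from parlai.core.params import ParlaiParser

    parser = ParlaiParser()
    DictionaryAgent.add_cmdline_args(parser)
    opt = parser.parse_args([])
    dictionary = DictionaryAgent(opt)

    # teach the dictionary two new tokens, exactly as the test does
    dictionary.observe({'text': 'hello world'})
    dictionary.act()

    # txt2vec/vec2txt are the text <-> index conversions the agents below rely on
    vec = dictionary.txt2vec('hello world')   # e.g. [num_builtin, num_builtin + 1]
    print(dictionary.vec2txt(vec))            # 'hello world'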
Code Example #2
File: test_dict.py  Project: simplecoka/cortx
    def test_basic_parse(self):
        """
        Check that the dictionary is correctly adding and parsing a short sentence.
        """
        parser = ParlaiParser()
        DictionaryAgent.add_cmdline_args(parser, partial_opt=None)
        opt = parser.parse_args([])
        dictionary = DictionaryAgent(opt)
        num_builtin = len(dictionary)

        dictionary.observe({'text': 'hello world'})
        dictionary.act()
        assert len(dictionary) - num_builtin == 2

        vec = dictionary.parse('hello world')
        assert len(vec) == 2
        assert vec[0] == num_builtin
        assert vec[1] == num_builtin + 1

        vec = dictionary.parse('hello world', vec_type=list)
        assert len(vec) == 2
        assert vec[0] == num_builtin
        assert vec[1] == num_builtin + 1

        vec = dictionary.parse('hello world', vec_type=tuple)
        assert len(vec) == 2
        assert vec[0] == num_builtin
        assert vec[1] == num_builtin + 1
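The only substantive difference from Code Example #1 is the option parsing: here add_cmdline_args receives an explicit partial_opt=None and parse_args is given an empty argument list, which reflects a later ParlaiParser interface, while Example #3 below passes print_args=False instead. The dictionary behaviour under test is identical, so which variant runs depends on the ParlAI version installed.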
Code Example #3
    def test_basic_parse(self):
        """Check that the dictionary is correctly adding and parsing short
        sentence.
        """
        from parlai.core.dict import DictionaryAgent
        from parlai.core.params import ParlaiParser

        argparser = ParlaiParser()
        DictionaryAgent.add_cmdline_args(argparser)
        opt = argparser.parse_args(print_args=False)
        dictionary = DictionaryAgent(opt)
        num_builtin = len(dictionary)

        dictionary.observe({'text': 'hello world'})
        dictionary.act()
        assert len(dictionary) - num_builtin == 2

        vec = dictionary.parse('hello world')
        assert len(vec) == 2
        assert vec[0] == num_builtin
        assert vec[1] == num_builtin + 1

        vec = dictionary.parse('hello world', vec_type=list)
        assert len(vec) == 2
        assert vec[0] == num_builtin
        assert vec[1] == num_builtin + 1

        vec = dictionary.parse('hello world', vec_type=tuple)
        assert len(vec) == 2
        assert vec[0] == num_builtin
        assert vec[1] == num_builtin + 1
Code Example #4
class Seq2seqAgent(Agent):
    """Agent which takes an input sequence and produces an output sequence.

    For more information, see Sequence to Sequence Learning with Neural
    Networks `(Sutskever et al. 2014) <https://arxiv.org/abs/1409.3215>`_.
    """

    OPTIM_OPTS = {
        'adadelta': optim.Adadelta,
        'adagrad': optim.Adagrad,
        'adam': optim.Adam,
        'adamax': optim.Adamax,
        'asgd': optim.ASGD,
        'lbfgs': optim.LBFGS,
        'rmsprop': optim.RMSprop,
        'rprop': optim.Rprop,
        'sgd': optim.SGD,
    }

    ENC_OPTS = {'rnn': nn.RNN, 'gru': nn.GRU, 'lstm': nn.LSTM}

    @staticmethod
    def add_cmdline_args(argparser):
        """Add command-line arguments specifically for this agent."""
        DictionaryAgent.add_cmdline_args(argparser)
        agent = argparser.add_argument_group('Seq2Seq Arguments')
        agent.add_argument('-hs', '--hiddensize', type=int, default=128,
                           help='size of the hidden layers')
        agent.add_argument('-emb', '--embeddingsize', type=int, default=128,
                           help='size of the token embeddings')
        agent.add_argument('-nl', '--numlayers', type=int, default=2,
                           help='number of hidden layers')
        agent.add_argument('-lr', '--learningrate', type=float, default=0.5,
                           help='learning rate')
        agent.add_argument('-dr', '--dropout', type=float, default=0.1,
                           help='dropout rate')
        agent.add_argument('-att', '--attention', type=int, default=0,
                           help='if greater than 0, use attention of specified'
                                ' length while decoding')
        agent.add_argument('--no-cuda', action='store_true', default=False,
                           help='disable GPUs even if available')
        agent.add_argument('--gpu', type=int, default=-1,
                           help='which GPU device to use')
        agent.add_argument('-rc', '--rank-candidates', type='bool',
                           default=False,
                           help='rank candidates if available. this is done by'
                                ' computing the mean score per token for each '
                                'candidate and selecting the highest scoring.')
        agent.add_argument('-tr', '--truncate', type='bool', default=True,
                           help='truncate input & output lengths to speed up '
                           'training (may reduce accuracy). This fixes all '
                           'input and output to have a maximum length and to '
                           'be similar in length to one another by throwing '
                           'away extra tokens. This reduces the total amount '
                           'of padding in the batches.')
        agent.add_argument('-enc', '--encoder', default='gru',
                           choices=Seq2seqAgent.ENC_OPTS.keys(),
                           help='Choose between different encoder modules.')
        agent.add_argument('-dec', '--decoder', default='same',
                           choices=['same', 'shared'] + list(Seq2seqAgent.ENC_OPTS.keys()),
                           help='Choose between different decoder modules. '
                                'Default "same" uses same class as encoder, '
                                'while "shared" also uses the same weights.')
        agent.add_argument('-opt', '--optimizer', default='sgd',
                           choices=Seq2seqAgent.OPTIM_OPTS.keys(),
                           help='Choose between pytorch optimizers. '
                                'Any member of torch.optim is valid and will '
                                'be used with default params except learning '
                                'rate (as specified by -lr).')

    def __init__(self, opt, shared=None):
        """Set up model if shared params not set, otherwise no work to do."""
        super().__init__(opt, shared)
        if not shared:
            # this is not a shared instance of this class, so do full
            # initialization. if shared is set, only set up shared members.

            # check for cuda
            self.use_cuda = not opt.get('no_cuda') and torch.cuda.is_available()
            if self.use_cuda:
                print('[ Using CUDA ]')
                torch.cuda.set_device(opt['gpu'])

            if opt.get('model_file') and os.path.isfile(opt['model_file']):
                # load model parameters if available
                print('Loading existing model params from ' + opt['model_file'])
                new_opt, self.states = self.load(opt['model_file'])
                # override options with stored ones
                opt = self.override_opt(new_opt)

            self.dict = DictionaryAgent(opt)
            self.id = 'Seq2Seq'
            # we use START markers to start our output
            self.START = self.dict.start_token
            self.START_TENSOR = torch.LongTensor(self.dict.parse(self.START))
            # we use END markers to end our output
            self.END = self.dict.end_token
            self.END_TENSOR = torch.LongTensor(self.dict.parse(self.END))
            # get index of null token from dictionary (probably 0)
            self.NULL_IDX = self.dict.txt2vec(self.dict.null_token)[0]

            # store important params directly
            hsz = opt['hiddensize']
            emb = opt['embeddingsize']
            self.hidden_size = hsz
            self.emb_size = emb
            self.num_layers = opt['numlayers']
            self.learning_rate = opt['learningrate']
            self.rank = opt['rank_candidates']
            self.longest_label = 1
            self.truncate = opt['truncate']
            self.attention = opt['attention']

            # set up tensors
            self.zeros = torch.zeros(self.num_layers, 1, hsz)
            self.xs = torch.LongTensor(1, 1)
            self.ys = torch.LongTensor(1, 1)
            self.cands = torch.LongTensor(1, 1, 1)
            self.cand_scores = torch.FloatTensor(1)
            self.cand_lengths = torch.LongTensor(1)

            # set up modules
            self.criterion = nn.NLLLoss()
            # lookup table stores word embeddings
            self.lt = nn.Embedding(len(self.dict), emb,
                                   padding_idx=self.NULL_IDX,
                                   scale_grad_by_freq=True)
            self.lt2enc = nn.Linear(emb, hsz)
            self.lt2dec = nn.Linear(emb, hsz)
            # encoder captures the input text
            enc_class = Seq2seqAgent.ENC_OPTS[opt['encoder']]
            self.encoder = enc_class(hsz, hsz, opt['numlayers'])
            # decoder produces our output states
            if opt['decoder'] == 'shared':
                self.decoder = self.encoder
            elif opt['decoder'] == 'same':
                self.decoder = enc_class(hsz, hsz, opt['numlayers'])
            else:
                dec_class = Seq2seqAgent.ENC_OPTS[opt['decoder']]
                self.decoder = dec_class(hsz, hsz, opt['numlayers'])
            # linear layer helps us produce outputs from final decoder state
            self.h2o = nn.Linear(hsz, len(self.dict))
            # dropout on the linear layer helps us generalize
            self.dropout = nn.Dropout(opt['dropout'])

            self.use_attention = False
            # if attention is greater than 0, set up additional members
            if self.attention > 0:
                self.use_attention = True
                self.max_length = self.attention
                # combines input and previous hidden output layer
                self.attn = nn.Linear(hsz * 2, self.max_length)
                # combines attention weights with encoder outputs
                self.attn_combine = nn.Linear(hsz * 2, hsz)

            # set up optims for each module
            lr = opt['learningrate']

            optim_class = Seq2seqAgent.OPTIM_OPTS[opt['optimizer']]
            self.optims = {
                'lt': optim_class(self.lt.parameters(), lr=lr),
                'lt2enc': optim_class(self.lt2enc.parameters(), lr=lr),
                'lt2dec': optim_class(self.lt2dec.parameters(), lr=lr),
                'encoder': optim_class(self.encoder.parameters(), lr=lr),
                'decoder': optim_class(self.decoder.parameters(), lr=lr),
                'h2o': optim_class(self.h2o.parameters(), lr=lr),
            }

            if hasattr(self, 'states'):
                # set loaded states if applicable
                self.set_states(self.states)

            if self.use_cuda:
                self.cuda()

        self.reset()

    def override_opt(self, new_opt):
        """Set overridable opts from loaded opt file.

        Print out each added key and each overridden key.
        Only override args specific to the model.
        """
        model_args = {'hiddensize', 'embeddingsize', 'numlayers', 'optimizer',
                      'encoder', 'decoder'}
        for k, v in new_opt.items():
            if k not in model_args:
                # skip non-model args
                continue
            if k not in self.opt:
                print('Adding new option [ {k}: {v} ]'.format(k=k, v=v))
            elif self.opt[k] != v:
                print('Overriding option [ {k}: {old} => {v}]'.format(
                      k=k, old=self.opt[k], v=v))
            self.opt[k] = v
        return self.opt

    def parse(self, text):
        """Convert string to token indices."""
        return self.dict.txt2vec(text)

    def v2t(self, vec):
        """Convert token indices to string of tokens."""
        return self.dict.vec2txt(vec)

    def cuda(self):
        """Push parameters to the GPU."""
        self.START_TENSOR = self.START_TENSOR.cuda(non_blocking=True)
        self.END_TENSOR = self.END_TENSOR.cuda(non_blocking=True)
        self.zeros = self.zeros.cuda(non_blocking=True)
        self.xs = self.xs.cuda(non_blocking=True)
        self.ys = self.ys.cuda(non_blocking=True)
        self.cands = self.cands.cuda(non_blocking=True)
        self.cand_scores = self.cand_scores.cuda(non_blocking=True)
        self.cand_lengths = self.cand_lengths.cuda(non_blocking=True)
        self.criterion.cuda()
        self.lt.cuda()
        self.lt2enc.cuda()
        self.lt2dec.cuda()
        self.encoder.cuda()
        self.decoder.cuda()
        self.h2o.cuda()
        self.dropout.cuda()
        if self.use_attention:
            self.attn.cuda()
            self.attn_combine.cuda()

    def hidden_to_idx(self, hidden, dropout=False):
        """Convert hidden state vectors into indices into the dictionary."""
        if hidden.size(0) > 1:
            raise RuntimeError('bad dimensions of tensor:', hidden)
        hidden = hidden.squeeze(0)
        scores = self.h2o(hidden)
        if dropout:
            scores = self.dropout(scores)
        scores = F.log_softmax(scores)
        _max_score, idx = scores.max(1)
        return idx, scores

    def zero_grad(self):
        """Zero out optimizers."""
        for optimizer in self.optims.values():
            optimizer.zero_grad()

    def update_params(self):
        """Do one optimization step."""
        for optimizer in self.optims.values():
            optimizer.step()

    def reset(self):
        """Reset observation and episode_done."""
        self.observation = None
        self.episode_done = True

    def observe(self, observation):
        """Save observation for act.
        If multiple observations are from the same episode, concatenate them.
        """
        # shallow copy observation (deep copy can be expensive)
        observation = observation.copy()
        if not self.episode_done:
            # if the last example wasn't the end of an episode, then we need to
            # recall what was said in that example
            prev_dialogue = self.observation['text']
            observation['text'] = prev_dialogue + '\n' + observation['text']
        self.observation = observation
        self.episode_done = observation['episode_done']
        return observation

    def _encode(self, xs, dropout=False):
        """Call encoder and return output and hidden states."""
        batchsize = len(xs)

        # first encode context
        xes = self.lt(xs)
        if dropout:
            xes = self.dropout(xes)
        # project from emb_size to hidden_size dimensions
        xes = self.lt2enc(xes).transpose(0, 1)

        if self.zeros.size(1) != batchsize:
            self.zeros.resize_(self.num_layers, batchsize, self.hidden_size).fill_(0)
        h0 = Variable(self.zeros)
        if type(self.encoder) == nn.LSTM:
            encoder_output, hidden = self.encoder(xes, (h0, h0))
            if type(self.decoder) != nn.LSTM:
                hidden = hidden[0]
        else:
            encoder_output, hidden = self.encoder(xes, h0)
            if type(self.decoder) == nn.LSTM:
                hidden = (hidden, h0)
        encoder_output = encoder_output.transpose(0, 1)

        if self.use_attention:
            if encoder_output.size(1) > self.max_length:
                offset = encoder_output.size(1) - self.max_length
                encoder_output = encoder_output.narrow(1, offset, self.max_length)

        return encoder_output, hidden


    def _apply_attention(self, xes, encoder_output, encoder_hidden):
        """Apply attention to encoder hidden layer."""
        attn_weights = F.softmax(self.attn(torch.cat((xes[0], encoder_hidden[-1]), 1)))

        if attn_weights.size(1) > encoder_output.size(1):
            attn_weights = attn_weights.narrow(1, 0, encoder_output.size(1))

        attn_applied = torch.bmm(attn_weights.unsqueeze(1), encoder_output).squeeze(1)

        output = torch.cat((xes[0], attn_applied), 1)
        output = self.attn_combine(output).unsqueeze(0)
        output = F.relu(output)

        return output


    def _decode_and_train(self, batchsize, xes, ys, encoder_output, hidden):
        # update the model based on the labels
        self.zero_grad()
        loss = 0

        output_lines = [[] for _ in range(batchsize)]

        # keep track of longest label we've ever seen
        self.longest_label = max(self.longest_label, ys.size(1))
        for i in range(ys.size(1)):
            output = self._apply_attention(xes, encoder_output, hidden) if self.use_attention else xes

            output, hidden = self.decoder(output, hidden)
            preds, scores = self.hidden_to_idx(output, dropout=True)
            y = ys.select(1, i)
            loss += self.criterion(scores, y)
            # use the true token as the next input instead of predicted
            # this produces a biased prediction but better training
            xes = self.lt2dec(self.lt(y).unsqueeze(0))
            for b in range(batchsize):
                # convert the output scores to tokens
                token = self.v2t([preds.data[b]])
                output_lines[b].append(token)

        loss.backward()
        self.update_params()

        if random.random() < 0.1:
            # sometimes output a prediction for debugging
            print('prediction:', ' '.join(output_lines[0]),
                  '\nlabel:', self.dict.vec2txt(ys.data[0]))

        return output_lines

    def _decode_only(self, batchsize, xes, ys, encoder_output, hidden):
        # just produce a prediction without training the model
        done = [False for _ in range(batchsize)]
        total_done = 0
        max_len = 0

        output_lines = [[] for _ in range(batchsize)]

        # now, generate a response from scratch
        while (total_done < batchsize) and max_len < self.longest_label:
            # keep producing tokens until we hit END or max length for each
            # example in the batch
            output = self._apply_attention(xes, encoder_output, hidden) if self.use_attention else xes

            output, hidden = self.decoder(output, hidden)
            preds, scores = self.hidden_to_idx(output, dropout=False)

            xes = self.lt2dec(self.lt(preds.unsqueeze(0)))
            max_len += 1
            for b in range(batchsize):
                if not done[b]:
                    # only add more tokens for examples that aren't done yet
                    token = self.v2t([preds.data[b]])
                    if token == self.END:
                        # if we produced END, we're done
                        done[b] = True
                        total_done += 1
                    else:
                        output_lines[b].append(token)

        if random.random() < 0.1:
            # sometimes output a prediction for debugging
            print('prediction:', ' '.join(output_lines[0]))

        return output_lines

    def _score_candidates(self, cands, xe, encoder_output, hidden):
        # score each candidate separately

        # cands are exs_with_cands x cands_per_ex x words_per_cand
        # cview is total_cands x words_per_cand
        cview = cands.view(-1, cands.size(2))
        cands_xes = xe.expand(xe.size(0), cview.size(0), xe.size(2))
        sz = hidden.size()
        cands_hn = (
            hidden.view(sz[0], sz[1], 1, sz[2])
            .expand(sz[0], sz[1], cands.size(1), sz[2])
            .contiguous()
            .view(sz[0], -1, sz[2])
        )

        sz = encoder_output.size()
        cands_encoder_output = (
            encoder_output.contiguous()
            .view(sz[0], 1, sz[1], sz[2])
            .expand(sz[0], cands.size(1), sz[1], sz[2])
            .contiguous()
            .view(-1, sz[1], sz[2])
        )

        cand_scores = Variable(
                    self.cand_scores.resize_(cview.size(0)).fill_(0))
        cand_lengths = Variable(
                    self.cand_lengths.resize_(cview.size(0)).fill_(0))

        for i in range(cview.size(1)):
            output = self._apply_attention(cands_xes, cands_encoder_output, cands_hn) \
                    if self.use_attention else cands_xes

            output, cands_hn = self.decoder(output, cands_hn)
            preds, scores = self.hidden_to_idx(output, dropout=False)
            cs = cview.select(1, i)
            non_nulls = cs.ne(self.NULL_IDX)
            cand_lengths += non_nulls.long()
            score_per_cand = torch.gather(scores, 1, cs.unsqueeze(1))
            cand_scores += score_per_cand.squeeze() * non_nulls.float()
            cands_xes = self.lt2dec(self.lt(cs).unsqueeze(0))

        # set empty scores to -1, so when divided by 0 they become -inf
        cand_scores -= cand_lengths.eq(0).float()
        # average the scores per token
        cand_scores /= cand_lengths.float()

        cand_scores = cand_scores.view(cands.size(0), cands.size(1))
        srtd_scores, text_cand_inds = cand_scores.sort(1, True)
        text_cand_inds = text_cand_inds.data

        return text_cand_inds

    def predict(self, xs, ys=None, cands=None):
        """Produce a prediction from our model.

        Update the model using the targets if available; otherwise just
        predict, also ranking any candidates that are provided.
        """
        batchsize = len(xs)
        text_cand_inds = None
        is_training = ys is not None
        encoder_output, hidden = self._encode(xs, dropout=is_training)

        # next we use START as an input to kick off our decoder
        x = Variable(self.START_TENSOR)
        xe = self.lt2dec(self.lt(x).unsqueeze(1))
        xes = xe.expand(xe.size(0), batchsize, xe.size(2))

        # list of output tokens for each example in the batch
        output_lines = None

        if is_training:
            output_lines = self._decode_and_train(batchsize, xes, ys,
                                                  encoder_output, hidden)

        else:
            if cands is not None:
                text_cand_inds = self._score_candidates(cands, xe,
                                                        encoder_output, hidden)

            output_lines = self._decode_only(batchsize, xes, ys,
                                             encoder_output, hidden)

        return output_lines, text_cand_inds

    def batchify(self, observations):
        """Convert a list of observations into input & target tensors."""
        # valid examples
        exs = [ex for ex in observations if 'text' in ex]
        # the indices of the valid (non-empty) tensors
        valid_inds = [i for i, ex in enumerate(observations) if 'text' in ex]

        # set up the input tensors
        batchsize = len(exs)
        # tokenize the text
        xs = None
        if batchsize > 0:
            parsed = [self.parse(ex['text']) for ex in exs]
            max_x_len = max([len(x) for x in parsed])
            if self.truncate:
                # shrink xs to limit batch computation
                min_x_len = min([len(x) for x in parsed])
                max_x_len = min(min_x_len + 12, max_x_len, 48)
                parsed = [x[-max_x_len:] for x in parsed]
            xs = torch.LongTensor(batchsize, max_x_len).fill_(0)
            # pack the data to the right side of the tensor for this model
            for i, x in enumerate(parsed):
                offset = max_x_len - len(x)
                for j, idx in enumerate(x):
                    xs[i][j + offset] = idx
            if self.use_cuda:
                # copy to gpu
                self.xs.resize_(xs.size())
                self.xs.copy_(xs, non_blocking=True)
                xs = Variable(self.xs)
            else:
                xs = Variable(xs)

        # set up the target tensors
        ys = None
        if batchsize > 0 and any(['labels' in ex for ex in exs]):
            # randomly select one of the labels to update on, if multiple
            # append END to each label
            labels = [random.choice(ex.get('labels', [''])) + ' ' + self.END for ex in exs]
            parsed = [self.parse(y) for y in labels]
            max_y_len = max(len(y) for y in parsed)
            if self.truncate:
                # shrink ys to limit batch computation
                min_y_len = min(len(y) for y in parsed)
                max_y_len = min(min_y_len + 12, max_y_len, 48)
                parsed = [y[:max_y_len] for y in parsed]
            ys = torch.LongTensor(batchsize, max_y_len).fill_(0)
            for i, y in enumerate(parsed):
                for j, idx in enumerate(y):
                    ys[i][j] = idx
            if self.use_cuda:
                # copy to gpu
                self.ys.resize_(ys.size())
                self.ys.copy_(ys, non_blocking=True)
                ys = Variable(self.ys)
            else:
                ys = Variable(ys)

        # set up candidates
        cands = None
        valid_cands = None
        if ys is None and self.rank:
            # only do ranking when no targets available and ranking flag set
            parsed = []
            valid_cands = []
            for i in valid_inds:
                if 'label_candidates' in observations[i]:
                    # each candidate tuple is a pair of the parsed version and
                    # the original full string
                    cs = list(observations[i]['label_candidates'])
                    parsed.append([self.parse(c) for c in cs])
                    valid_cands.append((i, cs))
            if len(parsed) > 0:
                # TODO: store lengths of cands separately, so don't have zero
                # padding for varying number of cands per example
                # found cands, pack them into tensor
                max_c_len = max(max(len(c) for c in cs) for cs in parsed)
                max_c_cnt = max(len(cs) for cs in parsed)
                cands = torch.LongTensor(len(parsed), max_c_cnt, max_c_len).fill_(0)
                for i, cs in enumerate(parsed):
                    for j, c in enumerate(cs):
                        for k, idx in enumerate(c):
                            cands[i][j][k] = idx
                if self.use_cuda:
                    # copy to gpu
                    self.cands.resize_(cands.size())
                    self.cands.copy_(cands, non_blocking=True)
                    cands = Variable(self.cands)
                else:
                    cands = Variable(cands)

        return xs, ys, valid_inds, cands, valid_cands

    def batch_act(self, observations):
        batchsize = len(observations)
        # initialize a table of replies with this agent's id
        batch_reply = [{'id': self.getID()} for _ in range(batchsize)]

        # convert the observations into batches of inputs and targets
        # valid_inds tells us the indices of all valid examples
        # e.g. for input [{}, {'text': 'hello'}, {}, {}], valid_inds is [1]
        # since the other three elements had no 'text' field
        xs, ys, valid_inds, cands, valid_cands = self.batchify(observations)

        if xs is None:
            # no valid examples, just return the empty responses we set up
            return batch_reply

        # produce predictions either way, but use the targets if available

        predictions, text_cand_inds = self.predict(xs, ys, cands)

        for i in range(len(predictions)):
            # map the predictions back to non-empty examples in the batch
            # we join with spaces since we produce tokens one at a time
            curr = batch_reply[valid_inds[i]]
            curr['text'] = ' '.join(c for c in predictions[i] if c != self.END
                                    and c != self.dict.null_token)

        if text_cand_inds is not None:
            for i in range(len(valid_cands)):
                order = text_cand_inds[i]
                batch_idx, curr_cands = valid_cands[i]
                curr = batch_reply[batch_idx]
                curr['text_candidates'] = [curr_cands[idx] for idx in order
                                           if idx < len(curr_cands)]

        return batch_reply

    def act(self):
        # call batch_act with this batch of one
        return self.batch_act([self.observation])[0]

    def save(self, path=None):
        path = self.opt.get('model_file', None) if path is None else path

        if path and hasattr(self, 'lt'):
            model = {}
            model['lt'] = self.lt.state_dict()
            model['lt2enc'] = self.lt2enc.state_dict()
            model['lt2dec'] = self.lt2dec.state_dict()
            model['encoder'] = self.encoder.state_dict()
            model['decoder'] = self.decoder.state_dict()
            model['h2o'] = self.h2o.state_dict()
            model['optims'] = {k: v.state_dict()
                               for k, v in self.optims.items()}
            model['longest_label'] = self.longest_label
            model['opt'] = self.opt

            with open(path, 'wb') as write:
                torch.save(model, write)

    def shutdown(self):
        """Save the state of the model when shutdown."""
        path = self.opt.get('model_file', None)
        if path is not None:
            self.save(path + '.shutdown_state')
        super().shutdown()

    def load(self, path):
        """Return opt and model states."""
        with open(path, 'rb') as read:
            model = torch.load(read)

        return model['opt'], model

    def set_states(self, states):
        """Set the state dicts of the modules from saved states."""
        self.lt.load_state_dict(states['lt'])
        self.lt2enc.load_state_dict(states['lt2enc'])
        self.lt2dec.load_state_dict(states['lt2dec'])
        self.encoder.load_state_dict(states['encoder'])
        self.decoder.load_state_dict(states['decoder'])
        self.h2o.load_state_dict(states['h2o'])
        for k, v in states['optims'].items():
            self.optims[k].load_state_dict(v)
        self.longest_label = states['longest_label']
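As a rough illustration of how this agent is driven, the sketch below runs one observe/act exchange using only the methods defined above. It assumes the Seq2seqAgent class from this example is available in the current module and that CPU execution is acceptable; in practice ParlAI's own training scripts handle option parsing, batching and saving, so treat this as a sketch rather than the supported entry point:

    from parlai.core.params import ParlaiParser

    parser = ParlaiParser()
    Seq2seqAgent.add_cmdline_args(parser)        # also registers DictionaryAgent args
    opt = parser.parse_args(['--hiddensize', '32', '--numlayers', '1', '--no-cuda'])
    agent = Seq2seqAgent(opt)

    # one training-style step: the label makes predict() take the _decode_and_train path
    agent.observe({'text': 'hello world',
                   'labels': ['hi there'],
                   'episode_done': True})
    reply = agent.act()
    print(reply['text'])        # space-joined tokens produced by the decoder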
Code Example #5
class Seq2seqAgent(Agent):
    """Simple agent which uses an LSTM to process incoming text observations."""
    @staticmethod
    def add_cmdline_args(argparser):
        DictionaryAgent.add_cmdline_args(argparser)
        agent = argparser.add_argument_group('Seq2Seq Arguments')
        agent.add_argument('-hs',
                           '--hiddensize',
                           type=int,
                           default=64,
                           help='size of the hidden layers and embeddings')
        agent.add_argument('-nl',
                           '--numlayers',
                           type=int,
                           default=2,
                           help='number of hidden layers')
        agent.add_argument('-lr',
                           '--learningrate',
                           type=float,
                           default=0.5,
                           help='learning rate')
        agent.add_argument('-dr',
                           '--dropout',
                           type=float,
                           default=0.1,
                           help='dropout rate')
        agent.add_argument('--no-cuda',
                           action='store_true',
                           default=False,
                           help='disable GPUs even if available')
        agent.add_argument('--gpu',
                           type=int,
                           default=-1,
                           help='which GPU device to use')

    def __init__(self, opt, shared=None):
        super().__init__(opt, shared)
        opt['cuda'] = not opt['no_cuda'] and torch.cuda.is_available()
        if opt['cuda']:
            print('[ Using CUDA ]')
            torch.cuda.set_device(opt['gpu'])
        if not shared:
            self.dict = DictionaryAgent(opt)
            self.id = 'Seq2Seq'
            hsz = opt['hiddensize']
            self.EOS = self.dict.eos_token
            self.EOS_TENSOR = torch.LongTensor(self.dict.parse(self.EOS))
            self.hidden_size = hsz
            self.num_layers = opt['numlayers']
            self.learning_rate = opt['learningrate']
            self.use_cuda = opt.get('cuda', False)
            self.longest_label = 1

            self.criterion = nn.NLLLoss()
            self.lt = nn.Embedding(len(self.dict),
                                   hsz,
                                   padding_idx=0,
                                   scale_grad_by_freq=True)
            self.encoder = nn.GRU(hsz, hsz, opt['numlayers'])
            self.decoder = nn.GRU(hsz, hsz, opt['numlayers'])
            self.d2o = nn.Linear(hsz, len(self.dict))
            self.dropout = nn.Dropout(opt['dropout'])
            self.softmax = nn.LogSoftmax()

            lr = opt['learningrate']
            self.optims = {
                'lt': optim.SGD(self.lt.parameters(), lr=lr),
                'encoder': optim.SGD(self.encoder.parameters(), lr=lr),
                'decoder': optim.SGD(self.decoder.parameters(), lr=lr),
                'd2o': optim.SGD(self.d2o.parameters(), lr=lr),
            }
            if self.use_cuda:
                self.cuda()
            if opt.get('model_file') and os.path.isfile(opt['model_file']):
                print('Loading existing model parameters from ' +
                      opt['model_file'])
                self.load(opt['model_file'])

        self.episode_done = True

    def parse(self, text):
        return torch.LongTensor(self.dict.txt2vec(text))

    def v2t(self, vec):
        return self.dict.vec2txt(vec)

    def cuda(self):
        self.criterion.cuda()
        self.lt.cuda()
        self.encoder.cuda()
        self.decoder.cuda()
        self.d2o.cuda()
        self.dropout.cuda()
        self.softmax.cuda()

    def hidden_to_idx(self, hidden, drop=False):
        if hidden.size(0) > 1:
            raise RuntimeError('bad dimensions of tensor:', hidden)
        hidden = hidden.squeeze(0)
        scores = self.d2o(hidden)
        if drop:
            scores = self.dropout(scores)
        scores = self.softmax(scores)
        _max_score, idx = scores.max(1)
        return idx, scores

    def zero_grad(self):
        for optimizer in self.optims.values():
            optimizer.zero_grad()

    def update_params(self):
        for optimizer in self.optims.values():
            optimizer.step()

    def init_zeros(self, bsz=1):
        t = torch.zeros(self.num_layers, bsz, self.hidden_size)
        if self.use_cuda:
            t = t.cuda(non_blocking=True)
        return Variable(t)

    def init_rand(self, bsz=1):
        t = torch.FloatTensor(self.num_layers, bsz, self.hidden_size)
        t.uniform_(0.05)
        if self.use_cuda:
            t = t.cuda(non_blocking=True)
        return Variable(t)

    def observe(self, observation):
        observation = copy.deepcopy(observation)
        if not self.episode_done:
            # if the last example wasn't the end of an episode, then we need to
            # recall what was said in that example
            prev_dialogue = self.observation['text']
            observation['text'] = prev_dialogue + '\n' + observation['text']
        self.observation = observation
        self.episode_done = observation['episode_done']
        return observation

    def update(self, xs, ys):
        batchsize = len(xs)

        # first encode context
        xes = self.lt(xs).t()
        h0 = self.init_zeros(batchsize)
        _output, hn = self.encoder(xes, h0)

        # start with EOS tensor for all
        x = self.EOS_TENSOR
        if self.use_cuda:
            x = x.cuda(non_blocking=True)
        x = Variable(x)
        xe = self.lt(x).unsqueeze(1)
        xes = xe.expand(xe.size(0), batchsize, xe.size(2))

        output_lines = [[] for _ in range(batchsize)]

        self.zero_grad()
        # update model
        loss = 0
        self.longest_label = max(self.longest_label, ys.size(1))
        for i in range(ys.size(1)):
            output, hn = self.decoder(xes, hn)
            preds, scores = self.hidden_to_idx(output, drop=True)
            y = ys.select(1, i)
            loss += self.criterion(scores, y)
            # use the true token as the next input
            xes = self.lt(y).unsqueeze(0)
            # hn = self.dropout(hn)
            for j in range(preds.size(0)):
                token = self.v2t([preds.data[j][0]])
                output_lines[j].append(token)

        loss.backward()
        self.update_params()

        if random.random() < 0.1:
            true = self.v2t(ys.data[0])
            #print('loss:', round(loss.data[0], 2),
            #      ' '.join(output_lines[0]), '(true: {})'.format(true))
        return output_lines

    def predict(self, xs):
        batchsize = len(xs)

        # first encode context
        xes = self.lt(xs).t()
        h0 = self.init_zeros(batchsize)
        _output, hn = self.encoder(xes, h0)

        # start with EOS tensor for all
        x = self.EOS_TENSOR
        if self.use_cuda:
            x = x.cuda(non_blocking=True)
        x = Variable(x)
        xe = self.lt(x).unsqueeze(1)
        xes = xe.expand(xe.size(0), batchsize, xe.size(2))

        done = [False for _ in range(batchsize)]
        total_done = 0
        max_len = 0
        output_lines = [[] for _ in range(batchsize)]

        while (total_done < batchsize) and max_len < self.longest_label:
            output, hn = self.decoder(xes, hn)
            preds, scores = self.hidden_to_idx(output, drop=False)
            xes = self.lt(preds.t())
            max_len += 1
            for i in range(preds.size(0)):
                if not done[i]:
                    token = self.v2t(preds.data[i])
                    if token == self.EOS:
                        done[i] = True
                        total_done += 1
                    else:
                        output_lines[i].append(token)
        if random.random() < 0.1:
            print('prediction:', ' '.join(output_lines[0]))
        return output_lines

    def batchify(self, obs):
        exs = [ex for ex in obs if 'text' in ex]
        valid_inds = [i for i, ex in enumerate(obs) if 'text' in ex]

        batchsize = len(exs)
        parsed = [self.parse(ex['text']) for ex in exs]
        max_x_len = max([len(x) for x in parsed])
        xs = torch.LongTensor(batchsize, max_x_len).fill_(0)
        for i, x in enumerate(parsed):
            offset = max_x_len - len(x)
            for j, idx in enumerate(x):
                xs[i][j + offset] = idx
        if self.use_cuda:
            xs = xs.cuda(non_blocking=True)
        xs = Variable(xs)

        ys = None
        if 'labels' in exs[0]:
            labels = [
                random.choice(ex['labels']) + ' ' + self.EOS for ex in exs
            ]
            parsed = [self.parse(y) for y in labels]
            max_y_len = max(len(y) for y in parsed)
            ys = torch.LongTensor(batchsize, max_y_len).fill_(0)
            for i, y in enumerate(parsed):
                for j, idx in enumerate(y):
                    ys[i][j] = idx
            if self.use_cuda:
                ys = ys.cuda(non_blocking=True)
            ys = Variable(ys)
        return xs, ys, valid_inds

    def batch_act(self, observations):
        batchsize = len(observations)
        batch_reply = [{'id': self.getID()} for _ in range(batchsize)]

        xs, ys, valid_inds = self.batchify(observations)

        if len(xs) == 0:
            return batch_reply

        # Either train or predict
        if ys is not None:
            predictions = self.update(xs, ys)
        else:
            predictions = self.predict(xs)

        for i in range(len(predictions)):
            batch_reply[valid_inds[i]]['text'] = ' '.join(
                c for c in predictions[i] if c != self.EOS)

        return batch_reply

    def act(self):
        return self.batch_act([self.observation])[0]

    def save(self, path):
        model = {}
        model['lt'] = self.lt.state_dict()
        model['encoder'] = self.encoder.state_dict()
        model['decoder'] = self.decoder.state_dict()
        model['d2o'] = self.d2o.state_dict()
        model['longest_label'] = self.longest_label

        with open(path, 'wb') as write:
            torch.save(model, write)

    def load(self, path):
        with open(path, 'rb') as read:
            model = torch.load(read)

        self.lt.load_state_dict(model['lt'])
        self.encoder.load_state_dict(model['encoder'])
        self.decoder.load_state_dict(model['decoder'])
        self.d2o.load_state_dict(model['d2o'])
        self.longest_label = model['longest_label']
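Compared with Code Example #4, this is the stripped-down variant of the same agent: the encoder and decoder are fixed GRUs, the optimizer is plain SGD, and there is no attention, no candidate ranking and no option overriding when an existing model file is loaded. It is easier to follow end to end, at the cost of exposing far fewer of the command-line switches shown above.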
Code Example #6
File: memnn.py  Project: vigneshkalai/ParlAI
class MemnnAgent(Agent):
    """ Memory Network agent.
    """
    @staticmethod
    def add_cmdline_args(argparser):
        DictionaryAgent.add_cmdline_args(argparser)
        arg_group = argparser.add_argument_group('MemNN Arguments')
        arg_group.add_argument(
            '--init-model',
            type=str,
            default=None,
            help='load dict/features/weights/opts from this file')
        arg_group.add_argument('-lr',
                               '--learning-rate',
                               type=float,
                               default=0.01,
                               help='learning rate')
        arg_group.add_argument('--embedding-size',
                               type=int,
                               default=128,
                               help='size of token embeddings')
        arg_group.add_argument('--hops',
                               type=int,
                               default=3,
                               help='number of memory hops')
        arg_group.add_argument('--mem-size',
                               type=int,
                               default=100,
                               help='size of memory')
        arg_group.add_argument('--time-features',
                               type='bool',
                               default=True,
                               help='use time features for memory embeddings')
        arg_group.add_argument(
            '--position-encoding',
            type='bool',
            default=False,
            help='use position encoding instead of bag of words embedding')
        arg_group.add_argument('--output',
                               type=str,
                               default='rank',
                               help='type of output (rank|generate)')
        arg_group.add_argument(
            '--rnn-layers',
            type=int,
            default=2,
            help='number of hidden layers in RNN decoder for generative output'
        )
        arg_group.add_argument(
            '--dropout',
            type=float,
            default=0.1,
            help='dropout probability for RNN decoder training')
        arg_group.add_argument('--optimizer',
                               default='adam',
                               help='optimizer type (sgd|adam)')
        arg_group.add_argument('--no-cuda',
                               action='store_true',
                               default=False,
                               help='disable GPUs even if available')
        arg_group.add_argument('--gpu',
                               type=int,
                               default=-1,
                               help='which GPU device to use')
        arg_group.add_argument('-histr',
                               '--history-replies',
                               default='label',
                               type=str,
                               choices=['none', 'model', 'label'],
                               help='Keep replies in the history, or not.')

    def __init__(self, opt, shared=None):
        opt['cuda'] = not opt['no_cuda'] and torch.cuda.is_available()
        if opt['cuda']:
            print('[ Using CUDA ]')
            torch.cuda.set_device(opt['gpu'])

        if not shared:
            self.id = 'MemNN'
            self.dict = DictionaryAgent(opt)
            self.answers = [None] * opt['batchsize']
            self.model = MemNN(opt, len(self.dict))

        else:
            self.dict = shared['dict']
            # model is shared during hogwild
            if 'threadindex' in shared:
                self.model = shared['model']
                self.decoder = shared['decoder']
                self.answers = [None] * opt['batchsize']
            else:
                self.answers = shared['answers']

        if hasattr(self, 'model'):
            self.opt = opt
            self.mem_size = opt['mem_size']
            self.loss_fn = CrossEntropyLoss()

            self.decoder = None
            self.longest_label = 1
            self.END = self.dict.end_token
            self.END_TENSOR = torch.LongTensor(self.dict.parse(self.END))
            self.START = self.dict.start_token
            self.START_TENSOR = torch.LongTensor(self.dict.parse(self.START))

            if opt['output'] == 'generate' or opt['output'] == 'g':
                self.decoder = Decoder(opt['embedding_size'],
                                       opt['embedding_size'],
                                       opt['rnn_layers'], opt, self.dict)
            elif opt['output'] != 'rank' and opt['output'] != 'r':
                raise NotImplementedError('Output type not supported.')

            if opt['cuda'] and not shared:
                self.model.share_memory()
                if self.decoder is not None:
                    self.decoder.cuda()

            optim_params = [
                p for p in self.model.parameters() if p.requires_grad
            ]
            lr = opt['learning_rate']
            if opt['optimizer'] == 'sgd':
                self.optimizers = {'memnn': optim.SGD(optim_params, lr=lr)}
                if self.decoder is not None:
                    self.optimizers['decoder'] = optim.SGD(
                        self.decoder.parameters(), lr=lr)
            elif opt['optimizer'] == 'adam':
                self.optimizers = {'memnn': optim.Adam(optim_params, lr=lr)}
                if self.decoder is not None:
                    self.optimizers['decoder'] = optim.Adam(
                        self.decoder.parameters(), lr=lr)
            else:
                raise NotImplementedError('Optimizer not supported.')

            # check first for 'init_model' for loading model from file
            if opt.get('init_model') and os.path.isfile(opt['init_model']):
                init_model = opt['init_model']
            # next check for 'model_file'
            elif opt.get('model_file') and os.path.isfile(opt['model_file']):
                init_model = opt['model_file']
            else:
                init_model = None
            if init_model is not None:
                print('Loading existing model parameters from ' + init_model)
                self.load(init_model)

        self.history = {}
        self.episode_done = True
        self.last_cands, self.last_cands_list = None, None
        super().__init__(opt, shared)

    def share(self):
        shared = super().share()
        shared['answers'] = self.answers
        shared['dict'] = self.dict
        if self.opt.get('numthreads', 1) > 1:
            shared['model'] = self.model
            self.model.share_memory()
            shared['decoder'] = self.decoder
        return shared

    def observe(self, observation):
        """Save observation for act.
        If multiple observations are from the same episode, concatenate them.
        """
        self.episode_done = observation['episode_done']
        # shallow copy observation (deep copy can be expensive)
        obs = observation.copy()
        batch_idx = self.opt.get('batchindex', 0)

        obs['text'] = (maintain_dialog_history(
            self.history,
            obs,
            reply=self.answers[batch_idx]
            if self.answers[batch_idx] is not None else '',
            historyLength=self.opt['mem_size'] + 1,
            useReplies=self.opt['history_replies'],
            dict=self.dict,
            useStartEndIndices=False,
            splitSentences=True))

        self.observation = obs
        self.answers[batch_idx] = None
        return obs

    def predict(self, xs, cands, ys=None):
        is_training = ys is not None
        if is_training:
            # Subsample to reduce training time
            cands = [
                list(set(random.sample(c, min(32, len(c))) + self.labels))
                for c in cands
            ]
        else:
            # rank all cands to increase accuracy
            cands = [list(set(c)) for c in cands]

        self.model.train(mode=is_training)
        # Organize inputs for network (see contents of xs and ys in batchify method)
            inputs = [Variable(x, volatile=not is_training) for x in xs]
        output_embeddings = self.model(*inputs)

        if self.decoder is None:
            scores = self.score(cands, output_embeddings)
            if is_training:
                label_inds = [
                    cand_list.index(self.labels[i])
                    for i, cand_list in enumerate(cands)
                ]
                if self.opt['cuda']:
                    label_inds = Variable(torch.cuda.LongTensor(label_inds))
                else:
                    label_inds = Variable(torch.LongTensor(label_inds))
                loss = self.loss_fn(scores, label_inds)
            predictions = self.ranked_predictions(cands, scores)
        else:
            self.decoder.train(mode=is_training)

            output_lines, loss = self.decode(output_embeddings, ys)
            predictions = self.generated_predictions(output_lines)

        if is_training:
            for o in self.optimizers.values():
                o.zero_grad()
            loss.backward()
            for o in self.optimizers.values():
                o.step()
        return predictions

    def score(self, cands, output_embeddings):
        last_cand = None
        max_len = max([len(c) for c in cands])
        scores = Variable(output_embeddings.data.new(len(cands), max_len))
        for i, cand_list in enumerate(cands):
            if last_cand != cand_list:
                candidate_lengths, candidate_indices = to_tensors(
                    cand_list, self.dict)
                candidate_lengths, candidate_indices = Variable(
                    candidate_lengths), Variable(candidate_indices)
                candidate_embeddings = self.model.answer_embedder(
                    candidate_lengths, candidate_indices)
                if self.opt['cuda']:
                    candidate_embeddings = candidate_embeddings.cuda()
                last_cand = cand_list
            scores[i, :len(cand_list)] = self.model.score.one_to_many(
                output_embeddings[i].unsqueeze(0),
                candidate_embeddings).squeeze()
        return scores

    def ranked_predictions(self, cands, scores):
        # return [' '] * len(self.answers)
        _, inds = scores.data.sort(descending=True, dim=1)
        return [[cands[i][j] for j in r if j < len(cands[i])]
                for i, r in enumerate(inds)]

    def decode(self, output_embeddings, ys=None):
        batchsize = output_embeddings.size(0)
        hn = output_embeddings.unsqueeze(0).expand(self.opt['rnn_layers'],
                                                   batchsize,
                                                   output_embeddings.size(1))
        x = self.model.answer_embedder(Variable(torch.LongTensor([1])),
                                       Variable(self.START_TENSOR))
        xes = x.unsqueeze(1).expand(x.size(0), batchsize, x.size(1))

        loss = 0
        output_lines = [[] for _ in range(batchsize)]
        done = [False for _ in range(batchsize)]
        total_done = 0
        idx = 0
        while (total_done < batchsize) and idx < self.longest_label:
            # keep producing tokens until we hit END or max length for each ex
            if self.opt['cuda']:
                xes = xes.cuda()
                hn = hn.contiguous()
            preds, scores = self.decoder(xes, hn)
            if ys is not None:
                y = Variable(ys[0][:, idx])
                temp_y = y.cuda() if self.opt['cuda'] else y
                loss += self.loss_fn(scores, temp_y)
            else:
                y = preds
            # use the true token as the next input for better training
            xes = self.model.answer_embedder(
                Variable(torch.LongTensor(preds.numel()).fill_(1)),
                y).unsqueeze(0)

            for b in range(batchsize):
                if not done[b]:
                    token = self.dict.vec2txt(preds.data[b])
                    if token == self.END:
                        done[b] = True
                        total_done += 1
                    else:
                        output_lines[b].append(token)
            idx += 1
        return output_lines, loss

    def generated_predictions(self, output_lines):
        return [[
            ' '.join(c for c in o
                     if c != self.END and c != self.dict.null_token)
        ] for o in output_lines]

    def parse(self, memory):
        """Returns:
            query = tensor (vector) of token indices for query
            query_length = length of query
            memory = tensor (matrix) where each row contains token indices for a memory
            memory_lengths = tensor (vector) with lengths of each memory
        """
        query = memory.pop()
        query = torch.LongTensor(query)
        query_length = torch.LongTensor([len(query)])

        if len(memory) == 0:
            memory.append(self.dict.null_token)

        memory = [torch.LongTensor(m) for m in memory]
        memory_lengths = torch.LongTensor([len(m) for m in memory])
        memory = torch.cat(memory)
        return (query, memory, query_length, memory_lengths)

    def batchify(self, obs):
        """Returns:
            xs = [memories, queries, memory_lengths, query_lengths]
            ys = [labels, label_lengths] (if available, else None)
            cands = list of candidates for each example in batch
            valid_inds = list of indices for examples with valid observations
        """
        exs = [ex for ex in obs if 'text' in ex and len(ex['text']) > 0]
        valid_inds = [
            i for i, ex in enumerate(obs)
            if 'text' in ex and len(ex['text']) > 0
        ]
        if not exs:
            return [None] * 4

        parsed = [self.parse(ex['text']) for ex in exs]
        queries = torch.cat([x[0] for x in parsed])
        memories = torch.cat([x[1] for x in parsed])
        query_lengths = torch.cat([x[2] for x in parsed])
        memory_lengths = torch.LongTensor(len(exs), self.mem_size).zero_()
        for i in range(len(exs)):
            if len(parsed[i][3]) > 0:
                memory_lengths[i, -len(parsed[i][3]):] = parsed[i][3]
        xs = [memories, queries, memory_lengths, query_lengths]

        ys = None
        self.labels = [
            random.choice(ex['labels']) for ex in exs if 'labels' in ex
        ]
        if len(self.labels) == len(exs):
            parsed = [self.dict.txt2vec(l) for l in self.labels]
            parsed = [torch.LongTensor(p) for p in parsed]
            label_lengths = torch.LongTensor([len(p)
                                              for p in parsed]).unsqueeze(1)
            self.longest_label = max(self.longest_label, label_lengths.max())
            padded = [
                torch.cat(
                    (p, torch.LongTensor(self.longest_label - len(p)).fill_(
                        self.END_TENSOR[0]))) for p in parsed
            ]
            labels = torch.stack(padded)
            ys = [labels, label_lengths]

        cands = [
            ex['label_candidates'] for ex in exs if 'label_candidates' in ex
        ]
        # Use words in dict as candidates if no candidates are provided
        if len(cands) < len(exs):
            cands = build_cands(exs, self.dict)
        # Avoid rebuilding the candidate list every batch if it's the same
        if self.last_cands != cands:
            self.last_cands = cands
            self.last_cands_list = [list(c) for c in cands]
        cands = self.last_cands_list
        return xs, ys, cands, valid_inds

    def batch_act(self, observations):
        batchsize = len(observations)
        batch_reply = [{'id': self.getID()} for _ in range(batchsize)]

        xs, ys, cands, valid_inds = self.batchify(observations)

        if xs is None or len(xs[1]) == 0:
            return batch_reply

        # Either train or predict
        predictions = self.predict(xs, cands, ys)

        for i in range(len(valid_inds)):
            self.answers[valid_inds[i]] = predictions[i][0]
            batch_reply[valid_inds[i]]['text'] = predictions[i][0]
            batch_reply[valid_inds[i]]['text_candidates'] = predictions[i]
        return batch_reply

    def act(self):
        return self.batch_act([self.observation])[0]

    def save(self, path=None):
        path = self.opt.get('model_file', None) if path is None else path

        if path:
            checkpoint = {}
            checkpoint['memnn'] = self.model.state_dict()
            checkpoint['memnn_optim'] = self.optimizers['memnn'].state_dict()
            if self.decoder is not None:
                checkpoint['decoder'] = self.decoder.state_dict()
                checkpoint['decoder_optim'] = self.optimizers[
                    'decoder'].state_dict()
                checkpoint['longest_label'] = self.longest_label
            with open(path, 'wb') as write:
                torch.save(checkpoint, write)

    def load(self, path):
        with open(path, 'rb') as read:
            checkpoint = torch.load(read)
        self.model.load_state_dict(checkpoint['memnn'])
        self.optimizers['memnn'].load_state_dict(checkpoint['memnn_optim'])
        if self.decoder is not None:
            self.decoder.load_state_dict(checkpoint['decoder'])
            self.optimizers['decoder'].load_state_dict(
                checkpoint['decoder_optim'])
            self.longest_label = checkpoint['longest_label']
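
The save/load pair above writes the memory network, the optional decoder, and their optimizer state dicts into a single checkpoint file. A minimal sketch of inspecting such a checkpoint outside the agent, assuming a hypothetical path 'memnn.chk':

import torch

# Hypothetical checkpoint path; the keys mirror the ones written by save() above.
checkpoint = torch.load('memnn.chk')
print(sorted(checkpoint.keys()))
# e.g. ['decoder', 'decoder_optim', 'longest_label', 'memnn', 'memnn_optim']

# Restoring follows the same pattern as load():
#   model.load_state_dict(checkpoint['memnn'])
#   optimizers['memnn'].load_state_dict(checkpoint['memnn_optim'])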
Code example #7
class Seq2seqAgent(Agent):
    """Simple agent which uses an RNN to process incoming text observations.
    The RNN generates a vector which is used to represent the input text,
    conditioning on the context to generate an output token-by-token.

    For more information, see Sequence to Sequence Learning with Neural Networks
    `(Sutskever et al. 2014) <https://arxiv.org/abs/1409.3215>`_.
    """
    @staticmethod
    def add_cmdline_args(argparser):
        """Add command-line arguments specifically for this agent."""
        DictionaryAgent.add_cmdline_args(argparser)
        agent = argparser.add_argument_group('Seq2Seq Arguments')
        agent.add_argument('-hs',
                           '--hiddensize',
                           type=int,
                           default=128,
                           help='size of the hidden layers and embeddings')
        agent.add_argument('-nl',
                           '--numlayers',
                           type=int,
                           default=2,
                           help='number of hidden layers')
        agent.add_argument('-lr',
                           '--learningrate',
                           type=float,
                           default=0.5,
                           help='learning rate')
        agent.add_argument('-dr',
                           '--dropout',
                           type=float,
                           default=0.1,
                           help='dropout rate')
        # agent.add_argument('-bi', '--bidirectional', type='bool', default=False,
        #     help='whether to encode the context with a bidirectional RNN')
        agent.add_argument('--no-cuda',
                           action='store_true',
                           default=False,
                           help='disable GPUs even if available')
        agent.add_argument('--gpu',
                           type=int,
                           default=-1,
                           help='which GPU device to use')
        agent.add_argument(
            '-r',
            '--rank-candidates',
            type='bool',
            default=False,
            help='rank candidates if available. this is done by computing the'
            + ' mean score per token for each candidate and selecting the ' +
            'highest scoring one.')

    def __init__(self, opt, shared=None):
        # initialize defaults first
        super().__init__(opt, shared)
        if not shared:
            # this is not a shared instance of this class, so do full
            # initialization. if shared is set, only set up shared members.

            # check for cuda
            self.use_cuda = not opt.get('no_cuda') and torch.cuda.is_available(
            )
            if self.use_cuda:
                print('[ Using CUDA ]')
                torch.cuda.set_device(opt['gpu'])

            if opt.get('model_file') and os.path.isfile(opt['model_file']):
                # load model parameters if available
                print('Loading existing model params from ' +
                      opt['model_file'])
                new_opt, self.states = self.load(opt['model_file'])
                # override options with stored ones
                opt = self.override_opt(new_opt)

            self.dict = DictionaryAgent(opt)
            self.id = 'Seq2Seq'
            # we use END markers to break input and output and end our output
            self.END = self.dict.end_token
            self.observation = {'text': self.END, 'episode_done': True}
            self.END_TENSOR = torch.LongTensor(self.dict.parse(self.END))
            # get index of null token from dictionary (probably 0)
            self.NULL_IDX = self.dict.txt2vec(self.dict.null_token)[0]

            # store important params directly
            hsz = opt['hiddensize']
            self.hidden_size = hsz
            self.num_layers = opt['numlayers']
            self.learning_rate = opt['learningrate']
            self.rank = opt['rank_candidates']
            self.longest_label = 1

            # set up modules
            self.criterion = nn.NLLLoss()
            # lookup table stores word embeddings
            self.lt = nn.Embedding(len(self.dict),
                                   hsz,
                                   padding_idx=self.NULL_IDX,
                                   scale_grad_by_freq=True)
            # encoder captures the input text
            self.encoder = nn.GRU(hsz, hsz, opt['numlayers'])
            # decoder produces our output states
            self.decoder = nn.GRU(hsz, hsz, opt['numlayers'])
            # linear layer helps us produce outputs from final decoder state
            self.h2o = nn.Linear(hsz, len(self.dict))
            # dropout on the linear layer helps us generalize
            self.dropout = nn.Dropout(opt['dropout'])
            # softmax maps output scores to probabilities
            self.softmax = nn.LogSoftmax()

            # set up optims for each module
            lr = opt['learningrate']
            self.optims = {
                'lt': optim.SGD(self.lt.parameters(), lr=lr),
                'encoder': optim.SGD(self.encoder.parameters(), lr=lr),
                'decoder': optim.SGD(self.decoder.parameters(), lr=lr),
                'h2o': optim.SGD(self.h2o.parameters(), lr=lr),
            }

            if hasattr(self, 'states'):
                # set loaded states if applicable
                self.set_states(self.states)

            if self.use_cuda:
                self.cuda()

        self.episode_done = True

    def override_opt(self, new_opt):
        """Print out each added key and each overriden key."""
        for k, v in new_opt.items():
            if k not in self.opt:
                print('Adding new option [ {k}: {v} ]'.format(k=k, v=v))
            elif self.opt[k] != v:
                print('Overriding option [ {k}: {old} => {v}]'.format(
                    k=k, old=self.opt[k], v=v))
            self.opt[k] = v
        return self.opt

    def parse(self, text):
        return self.dict.txt2vec(text)

    def v2t(self, vec):
        return self.dict.vec2txt(vec)

    def cuda(self):
        self.END_TENSOR = self.END_TENSOR.cuda(async=True)
        self.criterion.cuda()
        self.lt.cuda()
        self.encoder.cuda()
        self.decoder.cuda()
        self.h2o.cuda()
        self.dropout.cuda()
        self.softmax.cuda()

    def hidden_to_idx(self, hidden, dropout=False):
        """Converts hidden state vectors into indices into the dictionary."""
        if hidden.size(0) > 1:
            raise RuntimeError('bad dimensions of tensor:', hidden)
        hidden = hidden.squeeze(0)
        scores = self.h2o(hidden)
        if dropout:
            scores = self.dropout(scores)
        scores = self.softmax(scores)
        _max_score, idx = scores.max(1)
        return idx, scores

    def zero_grad(self):
        for optimizer in self.optims.values():
            optimizer.zero_grad()

    def update_params(self):
        for optimizer in self.optims.values():
            optimizer.step()

    def init_zeros(self, bsz=1):
        t = torch.zeros(self.num_layers, bsz, self.hidden_size)
        if self.use_cuda:
            t = t.cuda(async=True)
        return Variable(t)

    def init_rand(self, bsz=1):
        t = torch.FloatTensor(self.num_layers, bsz, self.hidden_size)
        t.uniform_(0.05)
        if self.use_cuda:
            t = t.cuda(async=True)
        return Variable(t)

    def observe(self, observation):
        observation = copy.deepcopy(observation)
        if not self.episode_done:
            # if the last example wasn't the end of an episode, then we need to
            # recall what was said in that example
            prev_dialogue = self.observation['text']
            observation['text'] = prev_dialogue + '\n' + observation['text']
        self.observation = observation
        self.episode_done = observation['episode_done']
        return observation

    def predict(self, xs, ys=None, cands=None):
        """Produce a prediction from our model. Update the model using the
        targets if available.
        """
        batchsize = len(xs)
        text_cand_inds = None

        # first encode context
        xes = self.lt(xs).t()
        h0 = self.init_zeros(batchsize)
        _output, hn = self.encoder(xes, h0)

        # next we use END as an input to kick off our decoder
        x = Variable(self.END_TENSOR)
        xe = self.lt(x).unsqueeze(1)
        xes = xe.expand(xe.size(0), batchsize, xe.size(2))

        # list of output tokens for each example in the batch
        output_lines = [[] for _ in range(batchsize)]

        if ys is not None:
            # update the model based on the labels
            self.zero_grad()
            loss = 0
            # keep track of longest label we've ever seen
            self.longest_label = max(self.longest_label, ys.size(1))
            for i in range(ys.size(1)):
                output, hn = self.decoder(xes, hn)
                preds, scores = self.hidden_to_idx(output, dropout=True)
                y = ys.select(1, i)
                loss += self.criterion(scores, y)
                # use the true token as the next input instead of predicted
                # this produces a biased prediction but better training
                xes = self.lt(y).unsqueeze(0)
                for b in range(batchsize):
                    # convert the output scores to tokens
                    token = self.v2t([preds.data[b][0]])
                    output_lines[b].append(token)

            loss.backward()
            self.update_params()

            if random.random() < 0.1:
                # sometimes output a prediction for debugging
                print('prediction:', ' '.join(output_lines[0]), '\nlabel:',
                      self.dict.vec2txt(ys.data[0]))
        else:
            # just produce a prediction without training the model
            done = [False for _ in range(batchsize)]
            total_done = 0
            max_len = 0

            if cands:
                # score each candidate separately

                # cands are exs_with_cands x cands_per_ex x words_per_cand
                # cview is total_cands x words_per_cand
                cview = cands.view(-1, cands.size(2))
                cands_xes = xe.expand(xe.size(0), cview.size(0), xe.size(2))
                sz = hn.size()
                cands_hn = (hn.view(sz[0], sz[1], 1, sz[2]).expand(
                    sz[0], sz[1], cands.size(1),
                    sz[2]).contiguous().view(sz[0], -1, sz[2]))

                cand_scores = torch.zeros(cview.size(0))
                cand_lengths = torch.LongTensor(cview.size(0)).fill_(0)
                if self.use_cuda:
                    cand_scores = cand_scores.cuda(async=True)
                    cand_lengths = cand_lengths.cuda(async=True)
                cand_scores = Variable(cand_scores)
                cand_lengths = Variable(cand_lengths)

                for i in range(cview.size(1)):
                    output, cands_hn = self.decoder(cands_xes, cands_hn)
                    preds, scores = self.hidden_to_idx(output, dropout=False)
                    cs = cview.select(1, i)
                    non_nulls = cs.ne(self.NULL_IDX)
                    cand_lengths += non_nulls.long()
                    score_per_cand = torch.gather(scores, 1, cs.unsqueeze(1))
                    cand_scores += score_per_cand.squeeze() * non_nulls.float()
                    cands_xes = self.lt(cs).unsqueeze(0)

                # set empty scores to -1, so when divided by 0 they become -inf
                cand_scores -= cand_lengths.eq(0).float()
                # average the scores per token
                cand_scores /= cand_lengths.float()

                cand_scores = cand_scores.view(cands.size(0), cands.size(1))
                srtd_scores, text_cand_inds = cand_scores.sort(1, True)
                text_cand_inds = text_cand_inds.data

            # now, generate a response from scratch
            while (total_done < batchsize) and max_len < self.longest_label:
                # keep producing tokens until we hit END or max length for each
                # example in the batch
                output, hn = self.decoder(xes, hn)
                preds, scores = self.hidden_to_idx(output, dropout=False)

                xes = self.lt(preds.t())
                max_len += 1
                for b in range(batchsize):
                    if not done[b]:
                        # only add more tokens for examples that aren't done yet
                        token = self.v2t(preds.data[b])
                        if token == self.END:
                            # if we produced END, we're done
                            done[b] = True
                            total_done += 1
                        else:
                            output_lines[b].append(token)

            if random.random() < 0.1:
                # sometimes output a prediction for debugging
                print('prediction:', ' '.join(output_lines[0]))

        return output_lines, text_cand_inds

    def batchify(self, observations):
        """Convert a list of observations into input & target tensors."""
        # valid examples
        exs = [ex for ex in observations if 'text' in ex]
        # the indices of the valid (non-empty) tensors
        valid_inds = [i for i, ex in enumerate(observations) if 'text' in ex]

        # set up the input tensors
        batchsize = len(exs)
        # tokenize the text
        xs = None
        if batchsize > 0:
            parsed = [self.parse(ex['text']) for ex in exs]
            max_x_len = max([len(x) for x in parsed])
            xs = torch.LongTensor(batchsize, max_x_len).fill_(0)
            # pack the data to the right side of the tensor for this model
            for i, x in enumerate(parsed):
                offset = max_x_len - len(x)
                for j, idx in enumerate(x):
                    xs[i][j + offset] = idx
            if self.use_cuda:
                xs = xs.cuda(async=True)
            xs = Variable(xs)

        # set up the target tensors
        ys = None
        if batchsize > 0 and any(['labels' in ex for ex in exs]):
            # randomly select one of the labels to update on, if multiple
            # append END to each label
            labels = [
                random.choice(ex.get('labels', [''])) + ' ' + self.END
                for ex in exs
            ]
            parsed = [self.parse(y) for y in labels]
            max_y_len = max(len(y) for y in parsed)
            ys = torch.LongTensor(batchsize, max_y_len).fill_(0)
            for i, y in enumerate(parsed):
                for j, idx in enumerate(y):
                    ys[i][j] = idx
            if self.use_cuda:
                ys = ys.cuda(async=True)
            ys = Variable(ys)

        # set up candidates
        cands = None
        valid_cands = None
        if ys is None and self.rank:
            # only do ranking when no targets available and ranking flag set
            parsed = []
            valid_cands = []
            for i in valid_inds:
                if 'label_candidates' in observations[i]:
                    # each candidate tuple is a pair of the parsed version and
                    # the original full string
                    cs = list(observations[i]['label_candidates'])
                    parsed.append([self.parse(c) for c in cs])
                    valid_cands.append((i, cs))
            if len(parsed) > 0:
                # TODO: store lengths of cands separately, so don't have zero
                # padding for varying number of cands per example
                # found cands, pack them into tensor
                max_c_len = max(max(len(c) for c in cs) for cs in parsed)
                max_c_cnt = max(len(cs) for cs in parsed)
                cands = torch.LongTensor(len(parsed), max_c_cnt,
                                         max_c_len).fill_(0)
                for i, cs in enumerate(parsed):
                    for j, c in enumerate(cs):
                        for k, idx in enumerate(c):
                            cands[i][j][k] = idx
                if self.use_cuda:
                    cands = cands.cuda(async=True)
                cands = Variable(cands)

        return xs, ys, valid_inds, cands, valid_cands

    def batch_act(self, observations):
        batchsize = len(observations)
        # initialize a table of replies with this agent's id
        batch_reply = [{'id': self.getID()} for _ in range(batchsize)]

        # convert the observations into batches of inputs and targets
        # valid_inds tells us the indices of all valid examples
        # e.g. for input [{}, {'text': 'hello'}, {}, {}], valid_inds is [1]
        # since the other three elements had no 'text' field
        xs, ys, valid_inds, cands, valid_cands = self.batchify(observations)

        if xs is None:
            # no valid examples, just return the empty responses we set up
            return batch_reply

        # produce predictions either way, but use the targets if available
        predictions, text_cand_inds = self.predict(xs, ys, cands)

        for i in range(len(predictions)):
            # map the predictions back to non-empty examples in the batch
            # we join with spaces since we produce tokens one at a time
            curr = batch_reply[valid_inds[i]]
            curr['text'] = ' '.join(
                c for c in predictions[i]
                if c != self.END and c != self.dict.null_token)

        if text_cand_inds is not None:
            for i in range(len(valid_cands)):
                order = text_cand_inds[i]
                batch_idx, curr_cands = valid_cands[i]
                curr = batch_reply[batch_idx]
                curr['text_candidates'] = [
                    curr_cands[idx] for idx in order if idx < len(curr_cands)
                ]

        return batch_reply

    def act(self):
        # call batch_act with this batch of one
        return self.batch_act([self.observation])[0]

    def save(self, path=None):
        path = self.opt.get('model_file', None) if path is None else path

        if path:
            model = {}
            model['lt'] = self.lt.state_dict()
            model['encoder'] = self.encoder.state_dict()
            model['decoder'] = self.decoder.state_dict()
            model['h2o'] = self.h2o.state_dict()
            model['longest_label'] = self.longest_label
            model['opt'] = self.opt

            with open(path, 'wb') as write:
                torch.save(model, write)

    def load(self, path):
        """Return opt and model states."""
        with open(path, 'rb') as read:
            model = torch.load(read)

        return model['opt'], model

    def set_states(self, states):
        """Set the state dicts of the modules from saved states."""
        self.lt.load_state_dict(states['lt'])
        self.encoder.load_state_dict(states['encoder'])
        self.decoder.load_state_dict(states['decoder'])
        self.h2o.load_state_dict(states['h2o'])
        self.longest_label = states['longest_label']
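
A small usage sketch of how observe() above accumulates dialogue context within an episode; 'agent' here is a hypothetical, already-constructed instance of this Seq2seqAgent:

# First turn of an episode: stored as-is and the episode is marked as ongoing.
agent.observe({'text': 'hello', 'episode_done': False})
# Second turn: joined to the previous text with a newline, exactly as in observe().
agent.observe({'text': 'how are you?', 'episode_done': True})
assert agent.observation['text'] == 'hello\nhow are you?'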
Code example #8
File: seq2seq.py  Project: zhaojunzuozjzfr/ParlAI
class Seq2seqAgent(Agent):
    """Agent which takes an input sequence and produces an output sequence.

    This model supports encoding the input and decoding the output via one of
    several flavors of RNN. It then uses a linear layer (whose weights can
    be shared with the embedding layer) to convert RNN output states into
    output tokens. This model currently uses greedy decoding, selecting the
    highest probability token at each time step.

    For more information, see Sequence to Sequence Learning with Neural
    Networks `(Sutskever et al. 2014) <https://arxiv.org/abs/1409.3215>`_.
    """

    OPTIM_OPTS = {
        'adadelta': optim.Adadelta,
        'adagrad': optim.Adagrad,
        'adam': optim.Adam,
        'adamax': optim.Adamax,
        'asgd': optim.ASGD,
        'lbfgs': optim.LBFGS,
        'rmsprop': optim.RMSprop,
        'rprop': optim.Rprop,
        'sgd': optim.SGD,
    }

    ENC_OPTS = {'rnn': nn.RNN, 'gru': nn.GRU, 'lstm': nn.LSTM}

    @staticmethod
    def add_cmdline_args(argparser):
        """Add command-line arguments specifically for this agent."""
        DictionaryAgent.add_cmdline_args(argparser)
        agent = argparser.add_argument_group('Seq2Seq Arguments')
        agent.add_argument('-hs',
                           '--hiddensize',
                           type=int,
                           default=128,
                           help='size of the hidden layers')
        agent.add_argument('-esz',
                           '--embeddingsize',
                           type=int,
                           default=128,
                           help='size of the token embeddings')
        agent.add_argument('-nl',
                           '--numlayers',
                           type=int,
                           default=2,
                           help='number of hidden layers')
        agent.add_argument('-lr',
                           '--learningrate',
                           type=float,
                           default=0.005,
                           help='learning rate')
        agent.add_argument('-dr',
                           '--dropout',
                           type=float,
                           default=0,
                           help='dropout rate')
        agent.add_argument('-bi',
                           '--bidirectional',
                           type='bool',
                           default=False,
                           help='whether to encode the context with a '
                           'bidirectional rnn')
        agent.add_argument('-att',
                           '--attention',
                           default='none',
                           choices=['none', 'concat', 'general', 'local'],
                           help='Choices: none, concat, general, local. '
                           'If set local, also set attention-length. '
                           'For more details see: '
                           'https://arxiv.org/pdf/1508.04025.pdf')
        agent.add_argument('-attl',
                           '--attention-length',
                           default=48,
                           type=int,
                           help='Length of local attention.')
        agent.add_argument('--no-cuda',
                           action='store_true',
                           default=False,
                           help='disable GPUs even if available')
        agent.add_argument('--gpu',
                           type=int,
                           default=-1,
                           help='which GPU device to use')
        agent.add_argument('-rc',
                           '--rank-candidates',
                           type='bool',
                           default=False,
                           help='rank candidates if available. this is done by'
                           ' computing the mean score per token for each '
                           'candidate and selecting the highest scoring.')
        agent.add_argument('-tr',
                           '--truncate',
                           type=int,
                           default=-1,
                           help='truncate input & output lengths to speed up '
                           'training (may reduce accuracy). This fixes all '
                           'input and output to have a maximum length and to '
                           'be similar in length to one another by throwing '
                           'away extra tokens. This reduces the total amount '
                           'of padding in the batches.')
        agent.add_argument('-enc',
                           '--encoder',
                           default='gru',
                           choices=Seq2seqAgent.ENC_OPTS.keys(),
                           help='Choose between different encoder modules.')
        agent.add_argument('-dec',
                           '--decoder',
                           default='same',
                           choices=['same', 'shared'] +
                           list(Seq2seqAgent.ENC_OPTS.keys()),
                           help='Choose between different decoder modules. '
                           'Default "same" uses same class as encoder, '
                           'while "shared" also uses the same weights. '
                           'Note that shared disables some encoder '
                           'options--in particular, bidirectionality.')
        agent.add_argument('-lt',
                           '--lookuptable',
                           default='all',
                           choices=['unique', 'enc_dec', 'dec_out', 'all'],
                           help='The encoder, decoder, and output modules can '
                           'share weights, or not. '
                           'Unique has independent embeddings for each. '
                           'Enc_dec shares the embedding for the encoder '
                           'and decoder. '
                           'Dec_out shares decoder embedding and output '
                           'weights. '
                           'All shares all three weights.')
        agent.add_argument('-opt',
                           '--optimizer',
                           default='adam',
                           choices=Seq2seqAgent.OPTIM_OPTS.keys(),
                           help='Choose between pytorch optimizers. '
                           'Any member of torch.optim is valid and will '
                           'be used with default params except learning '
                           'rate (as specified by -lr).')
        agent.add_argument('-emb',
                           '--embedding-init',
                           default='random',
                           choices=['random', 'glove'],
                           help='Choose between initialization strategies '
                           'for word embeddings. Default is random, '
                           'but can also preinitialize from Glove')
        agent.add_argument('-lm',
                           '--language-model',
                           type='bool',
                           default=False,
                           help='enable language modeling training on the '
                           'concatenated input and label data')

    def __init__(self, opt, shared=None):
        """Set up model if shared params not set, otherwise no work to do."""
        super().__init__(opt, shared)
        if shared:
            # set up shared properties
            # answers contains a batch_size list of the last answer produced
            self.answers = shared['answers']
            # start token
            self.START = shared['START']
            # end token
            self.END = shared['END']
        else:
            # this is not a shared instance of this class, so do full init

            # answers contains a batch_size list of the last answer produced
            self.answers = [None] * opt['batchsize']

            # check for cuda
            self.use_cuda = not opt.get('no_cuda') and torch.cuda.is_available(
            )
            if self.use_cuda:
                print('[ Using CUDA ]')
                torch.cuda.set_device(opt['gpu'])

            states = None
            if opt.get('model_file') and os.path.isfile(opt['model_file']):
                # load model parameters if available
                print('Loading existing model params from ' +
                      opt['model_file'])
                new_opt, states = self.load(opt['model_file'])
                # override model-specific options with stored ones
                opt = self.override_opt(new_opt)

            if opt['dict_file'] is None and opt.get('model_file'):
                # set default dict-file if not set
                opt['dict_file'] = opt['model_file'] + '.dict'

            # load dictionary and basic tokens & vectors
            self.dict = DictionaryAgent(opt)
            self.id = 'Seq2Seq'
            # we use START markers to start our output
            self.START = self.dict.start_token
            self.START_TENSOR = torch.LongTensor(self.dict.parse(self.START))
            # we use END markers to end our output
            self.END = self.dict.end_token
            self.END_TENSOR = torch.LongTensor(self.dict.parse(self.END))
            # get index of null token from dictionary (probably 0)
            self.NULL_IDX = self.dict.txt2vec(self.dict.null_token)[0]

            # store important params in self
            hsz = opt['hiddensize']
            emb = opt['embeddingsize']
            self.hidden_size = hsz
            self.emb_size = emb
            self.num_layers = opt['numlayers']
            self.learning_rate = opt['learningrate']
            self.rank = opt['rank_candidates']
            self.longest_label = 1
            self.truncate = opt['truncate']
            self.attention = opt['attention']
            self.bidirectional = opt['bidirectional']
            self.num_dirs = 2 if self.bidirectional else 1
            self.dropout = opt['dropout']
            self.lm = opt['language_model']

            # set up tensors once
            self.zeros = torch.zeros(self.num_layers * self.num_dirs, 1, hsz)
            self.xs = torch.LongTensor(1, 1)
            self.ys = torch.LongTensor(1, 1)
            if self.rank:
                self.cands = torch.LongTensor(1, 1, 1)
                self.cand_scores = torch.FloatTensor(1)
                self.cand_lengths = torch.LongTensor(1)

            # set up modules
            self.criterion = nn.NLLLoss()
            # lookup table stores word embeddings
            self.enc_lt = nn.Embedding(len(self.dict),
                                       emb,
                                       padding_idx=self.NULL_IDX,
                                       max_norm=10)

            if opt['lookuptable'] in ['enc_dec', 'all']:
                # share this with the encoder
                self.dec_lt = self.enc_lt
            else:
                self.dec_lt = nn.Embedding(len(self.dict),
                                           emb,
                                           padding_idx=self.NULL_IDX,
                                           max_norm=10)

            if not states and opt['embedding_init'] == 'glove':
                # set up pre-initialized vectors from GloVe
                try:
                    import torchtext.vocab as vocab
                except ImportError:
                    raise ImportError('Please install torchtext from '
                                      'github.com/pytorch/text.')
                Glove = vocab.GloVe(name='840B', dim=300)
                # do better than uniform random
                proj = torch.FloatTensor(emb, 300).uniform_(
                    -0.057735, 0.057735) if emb != 300 else None
                for w in self.dict.freq:
                    if w in Glove.stoi:
                        vec = Glove.vectors[Glove.stoi[w]]
                        if emb != 300:
                            vec = torch.mm(proj, vec.unsqueeze(1)).squeeze()
                        self.enc_lt.weight.data[self.dict[w]] = vec
                        self.dec_lt.weight.data[self.dict[w]] = vec

            # encoder captures the input text
            enc_class = Seq2seqAgent.ENC_OPTS[opt['encoder']]
            # decoder produces our output states
            if opt['decoder'] in ['same', 'shared']:
                # use same class as encoder
                self.decoder = enc_class(emb,
                                         hsz,
                                         opt['numlayers'],
                                         dropout=self.dropout,
                                         batch_first=True)
            else:
                # use set class
                dec_class = Seq2seqAgent.ENC_OPTS[opt['decoder']]
                self.decoder = dec_class(emb,
                                         hsz,
                                         opt['numlayers'],
                                         dropout=self.dropout,
                                         batch_first=True)
            if opt['decoder'] == 'shared':
                # shared weights: use the decoder to encode
                if self.bidirectional:
                    raise RuntimeError('Cannot share enc/dec and do '
                                       'bidirectional encoding.')
                self.encoder = self.decoder
            else:
                self.encoder = enc_class(emb,
                                         hsz,
                                         opt['numlayers'],
                                         dropout=self.dropout,
                                         batch_first=True,
                                         bidirectional=self.bidirectional)

            # linear layers help us produce outputs from final decoder state
            hszXdirs = hsz * self.num_dirs
            self.h2e = nn.Linear(hsz, emb)  # hidden to embedding
            self.e2o = nn.Linear(emb, len(self.dict))  # embedding to output
            if opt['lookuptable'] in ['dec_out', 'all']:
                # share these weights with the decoder lookup table
                self.e2o.weight = self.dec_lt.weight

            if self.attention == 'local':
                # local attention over fixed set of output states
                if opt['attention_length'] < 0:
                    raise RuntimeError('Set attention length to > 0.')
                self.max_length = opt['attention_length']
                # combines input and previous hidden output layer
                self.attn = nn.Linear(hsz + emb, self.max_length)
                # combines attention weights with encoder outputs
                self.attn_combine = nn.Linear(hszXdirs + emb, emb)
            elif self.attention == 'concat':
                self.attn = nn.Linear(hsz + hszXdirs, hsz)
                self.attn_v = nn.Linear(hsz, 1)
                self.attn_combine = nn.Linear(hszXdirs + emb, emb)
            elif self.attention == 'general':
                self.attn = nn.Linear(hsz, hszXdirs)
                self.attn_combine = nn.Linear(hszXdirs + emb, emb)

            # set up optims for each module
            lr = opt['learningrate']
            optim_class = Seq2seqAgent.OPTIM_OPTS[opt['optimizer']]
            kwargs = {'lr': lr}
            if opt['optimizer'] == 'sgd':
                kwargs['momentum'] = 0.95
                kwargs['nesterov'] = True
            self.optims = {
                'enc_lt': optim_class(self.enc_lt.parameters(), **kwargs),
                'decoder': optim_class(self.decoder.parameters(), **kwargs),
                'h2e': optim_class(self.h2e.parameters(), **kwargs),
                'e2o': optim_class(self.e2o.parameters(), **kwargs),
            }
            if opt['decoder'] != 'shared':
                self.optims['encoder'] = optim_class(self.encoder.parameters(),
                                                     **kwargs)
            if opt['lookuptable'] not in ['enc_dec', 'all']:
                # only add dec if it's separate from enc
                self.optims['dec_lt'] = optim_class(self.dec_lt.parameters(),
                                                    **kwargs)

            # add attention parameters into optims if available
            for attn_name in ['attn', 'attn_v', 'attn_combine']:
                if hasattr(self, attn_name):
                    self.optims[attn_name] = optim_class(
                        getattr(self, attn_name).parameters(), **kwargs)

            if states is not None:
                # set loaded states if applicable
                self.set_states(states)

            if self.use_cuda:
                self.cuda()

        self.reset()

    def override_opt(self, new_opt):
        """Set overridable opts from loaded opt file.

        Print out each added key and each overridden key.
        Only override args specific to the model.
        """
        model_args = {
            'hiddensize', 'embeddingsize', 'numlayers', 'optimizer', 'encoder',
            'decoder', 'lookuptable', 'attention', 'attention_length'
        }
        for k, v in new_opt.items():
            if k not in model_args:
                # skip non-model args
                continue
            if k not in self.opt:
                print('Adding new option [ {k}: {v} ]'.format(k=k, v=v))
            elif self.opt[k] != v:
                print('Overriding option [ {k}: {old} => {v}]'.format(
                    k=k, old=self.opt[k], v=v))
            self.opt[k] = v
        return self.opt

    def parse(self, text):
        """Convert string to token indices."""
        return self.dict.txt2vec(text)

    def v2t(self, vec):
        """Convert token indices to string of tokens."""
        return self.dict.vec2txt(vec)

    def cuda(self):
        """Push parameters to the GPU."""
        self.START_TENSOR = self.START_TENSOR.cuda(async=True)
        self.END_TENSOR = self.END_TENSOR.cuda(async=True)
        self.zeros = self.zeros.cuda(async=True)
        self.xs = self.xs.cuda(async=True)
        self.ys = self.ys.cuda(async=True)
        if self.rank:
            self.cands = self.cands.cuda(async=True)
            self.cand_scores = self.cand_scores.cuda(async=True)
            self.cand_lengths = self.cand_lengths.cuda(async=True)
        self.criterion.cuda()
        self.enc_lt.cuda()
        self.dec_lt.cuda()
        self.encoder.cuda()
        self.decoder.cuda()
        self.h2e.cuda()
        self.e2o.cuda()
        if self.attention != 'none':
            for attn_name in ['attn', 'attn_v', 'attn_combine']:
                if hasattr(self, attn_name):
                    getattr(self, attn_name).cuda()

    def hidden_to_idx(self, hidden, is_training=False):
        """Convert hidden state vectors into indices into the dictionary."""
        # dropout at each step
        e = F.dropout(self.h2e(hidden), p=self.dropout, training=is_training)
        out = F.dropout(self.e2o(e), p=self.dropout, training=is_training)

        # out is batch_size x sequence_length x dict_sz
        if out.size(1) == 1:
            # sequence length is one, just squeeze it so we don't need to cat
            scores = F.log_softmax(out.squeeze(1)).unsqueeze(1)
        else:
            # we need a softmax per token
            # iterate over the smaller of (batch_size, seq_length) so we do fewer cats over bigger slices
            dim = 0 if out.size(0) < out.size(1) else 1
            scores = torch.cat([
                F.log_softmax(out.select(dim, i)).unsqueeze(dim)
                for i in range(out.size(dim))
            ], dim)
        _max_score, idx = scores.max(2)
        return idx, scores

    def zero_grad(self):
        """Zero out optimizers."""
        for optimizer in self.optims.values():
            optimizer.zero_grad()

    def update_params(self):
        """Do one optimization step."""
        for optimizer in self.optims.values():
            optimizer.step()

    def reset(self):
        """Reset observation and episode_done."""
        self.observation = None
        self.episode_done = True

    def share(self):
        """Share internal states between parent and child instances."""
        shared = super().share()
        shared['answers'] = self.answers
        shared['START'] = self.START
        shared['END'] = self.END
        return shared

    def observe(self, observation):
        """Save observation for act.
        If multiple observations are from the same episode, concatenate them.
        """
        # shallow copy observation (deep copy can be expensive)
        observation = observation.copy()
        if 'text' in observation:
            # put START and END around text
            observation['text'] = '{s} {x} {e}'.format(s=self.START,
                                                       x=observation['text'],
                                                       e=self.END)
        if not self.episode_done:
            # if the last example wasn't the end of an episode, then we need to
            # recall what was said in that example
            prev_dialogue = self.observation['text']
            # get last y
            batch_idx = self.opt.get('batchindex', 0)
            if self.answers[batch_idx] is not None:
                # use our last answer, which is the label during training
                lastY = self.answers[batch_idx]
                prev_dialogue = '{p}\n{s} {y} {e}'.format(p=prev_dialogue,
                                                          s=self.START,
                                                          y=lastY,
                                                          e=self.END)
                self.answers[batch_idx] = None  # forget last y
            # add current observation back in
            observation['text'] = '{p}\n{x}'.format(p=prev_dialogue,
                                                    x=observation['text'])
            # final text: <s> lastx </s> \n <s> lasty </s> \n <s> currx </s>
        self.observation = observation
        self.episode_done = observation['episode_done']
        return observation

    def _encode(self, xs, is_training=False):
        """Call encoder and return output and hidden states."""
        self.lastxs = xs
        batchsize = len(xs)

        # first encode context
        xes = F.dropout(self.enc_lt(xs), p=self.dropout, training=is_training)
        # compute unpadded sequence lengths and pack for the RNN
        x_lens = [x for x in torch.sum((xs > 0).int(), dim=1).data]
        xes_packed = pack_padded_sequence(xes, x_lens, batch_first=True)

        if self.zeros.size(1) != batchsize:
            self.zeros.resize_(self.num_layers * self.num_dirs, batchsize,
                               self.hidden_size).fill_(0)

        h0 = Variable(self.zeros, requires_grad=False)
        if type(self.encoder) == nn.LSTM:
            encoder_output_packed, hidden = self.encoder(xes_packed, (h0, h0))
            # take elementwise max between forward and backward hidden states
            hidden = (hidden[0].view(-1, self.num_dirs, hidden[0].size(1),
                                     hidden[0].size(2)).max(1)[0],
                      hidden[1].view(-1, self.num_dirs, hidden[1].size(1),
                                     hidden[1].size(2)).max(1)[0])
            if type(self.decoder) != nn.LSTM:
                hidden = hidden[0]
        else:
            encoder_output_packed, hidden = self.encoder(xes_packed, h0)

            # take elementwise max between forward and backward hidden states
            hidden = hidden.view(-1, self.num_dirs, hidden.size(1),
                                 hidden.size(2)).max(1)[0]
            if type(self.decoder) == nn.LSTM:
                hidden = (hidden, h0.narrow(0, 0, 2))
        encoder_output, _ = pad_packed_sequence(encoder_output_packed,
                                                batch_first=True)

        if self.attention == 'local':
            # if using local attention, narrow encoder_output to max_length
            if encoder_output.size(1) > self.max_length:
                offset = encoder_output.size(1) - self.max_length
                encoder_output = encoder_output.narrow(1, offset,
                                                       self.max_length)

        return encoder_output, hidden

    def _apply_attention(self, xes, encoder_output, hidden, attn_mask=None):
        """Apply attention to encoder hidden layer."""
        last_hidden = hidden[-1]  # select hidden from last RNN layer
        if self.attention == 'concat':
            hidden_expand = last_hidden.unsqueeze(1).expand(
                last_hidden.size(0), encoder_output.size(1),
                last_hidden.size(1))
            attn_w_premask = self.attn_v(
                F.tanh(self.attn(torch.cat((encoder_output, hidden_expand),
                                           2)))).squeeze(2)
            attn_weights = F.softmax(attn_w_premask * attn_mask -
                                     (1 - attn_mask) * 1e20)

        elif self.attention == 'general':
            hidden_expand = last_hidden.unsqueeze(1)
            attn_w_premask = torch.bmm(self.attn(hidden_expand),
                                       encoder_output.transpose(1,
                                                                2)).squeeze(1)
            attn_weights = F.softmax(attn_w_premask * attn_mask -
                                     (1 - attn_mask) * 1e20)

        elif self.attention == 'local':
            attn_weights = F.softmax(
                self.attn(torch.cat((xes.squeeze(1), last_hidden), 1)))
            if attn_weights.size(1) > encoder_output.size(1):
                attn_weights = attn_weights.narrow(1, 0,
                                                   encoder_output.size(1))

        attn_applied = torch.bmm(attn_weights.unsqueeze(1),
                                 encoder_output).squeeze(1)

        output = torch.cat((xes.squeeze(1), attn_applied), 1)
        output = self.attn_combine(output).unsqueeze(1)
        output = F.tanh(output)

        return output

    def _decode_and_train(self,
                          batchsize,
                          xes,
                          ys,
                          encoder_output,
                          hidden,
                          attn_mask,
                          lm=False):
        """Update the model based on the labels."""
        self.zero_grad()
        loss = 0

        output_lines = [[] for _ in range(batchsize)]

        # keep track of longest label we've ever seen
        # we'll never produce longer ones than that during prediction
        if not lm:
            self.longest_label = max(self.longest_label, ys.size(1))
        if self.attention != 'none':
            # using attention, produce one token at a time
            for i in range(ys.size(1)):
                h_att = hidden[0] if type(self.decoder) == nn.LSTM else hidden
                output = self._apply_attention(xes, encoder_output, h_att,
                                               attn_mask)
                output, hidden = self.decoder(output, hidden)
                preds, scores = self.hidden_to_idx(output, is_training=True)
                y = ys.select(1, i)
                loss += self.criterion(scores.squeeze(1), y)
                # use the true token as the next input instead of predicted
                xes = self.dec_lt(y).unsqueeze(1)
                xes = F.dropout(xes, p=self.dropout, training=True)
                for b in range(batchsize):
                    # convert the output scores to tokens
                    token = self.v2t(preds.data[b])
                    output_lines[b].append(token)
        else:
            # force the entire sequence at once by feeding in START + y[:-1]
            y_in = ys.narrow(1, 0, ys.size(1) - 1)
            xes = torch.cat([xes, self.dec_lt(y_in)], 1)

            output, hidden = self.decoder(xes, hidden)
            preds, scores = self.hidden_to_idx(output, is_training=True)
            for i in range(ys.size(1)):
                # sum loss per-token
                score = scores.select(1, i)
                y = ys.select(1, i)
                loss += self.criterion(score, y)
            for b in range(batchsize):
                output_lines[b].extend(self.v2t(preds.data[b]).split(' '))
        loss.backward()
        self.update_params()

        if random.random() < 0.1:
            # sometimes output a prediction for debugging
            # print('prediction:', ' '.join(output_lines[0]))
            # print('label:', self.v2t(ys.data[0]))
            print('lm' if lm else '  ', 'loss:', loss.data[0])

        return output_lines

    def _decode_only(self, batchsize, xes, ys, encoder_output, hidden,
                     attn_mask):
        """Just produce a prediction without training the model."""
        done = [False for _ in range(batchsize)]
        total_done = 0
        max_len = 0

        output_lines = [[] for _ in range(batchsize)]

        # generate a response from scratch
        while (total_done < batchsize) and max_len < self.longest_label:
            # keep producing tokens until we hit END or max length for each
            # example in the batch
            if self.attention == 'none':
                output = xes
            else:
                h_att = hidden[0] if type(self.decoder) == nn.LSTM else hidden
                output = self._apply_attention(xes, encoder_output, h_att,
                                               attn_mask)
            output, hidden = self.decoder(output, hidden)
            preds, _scores = self.hidden_to_idx(output, is_training=False)

            xes = self.dec_lt(preds)
            max_len += 1
            for b in range(batchsize):
                if not done[b]:
                    # only add more tokens for examples that aren't done yet
                    token = self.v2t(preds.data[b])
                    if token == self.END:
                        # if we produced END, we're done
                        done[b] = True
                        total_done += 1
                    else:
                        output_lines[b].append(token)

        if random.random() < 0.2:
            # sometimes output a prediction for debugging
            print('\nprediction:', ' '.join(output_lines[0]))

        return output_lines

    def _score_candidates(self, cands, cand_inds, start, encoder_output,
                          hidden, attn_mask):
        """Rank candidates by their likelihood according to the decoder."""
        if type(self.decoder) == nn.LSTM:
            hidden, cell = hidden
        # score each candidate separately
        # cands are exs_with_cands x cands_per_ex x words_per_cand
        # cview is total_cands x words_per_cand
        cview = cands.view(-1, cands.size(2))
        c_xes = start.expand(cview.size(0), start.size(0), start.size(1))

        if len(cand_inds) != hidden.size(1):
            # only use hidden state from inputs with associated candidates
            cand_indices = torch.LongTensor([i for i, _, _ in cand_inds])
            if self.use_cuda:
                cand_indices = cand_indices.cuda()
            cand_indices = Variable(cand_indices)
            hidden = hidden.index_select(1, cand_indices)

        sz = hidden.size()
        cands_hn = (hidden.view(sz[0], sz[1], 1, sz[2]).expand(
            sz[0], sz[1], cands.size(1),
            sz[2]).contiguous().view(sz[0], -1, sz[2]))
        if type(self.decoder) == nn.LSTM:
            if len(cand_inds) != cell.size(1):
                # only use cell state from inputs with associated candidates
                cell = cell.index_select(1, cand_indices)
            cands_hn = (cands_hn, cell.view(sz[0], sz[1], 1, sz[2]).expand(
                sz[0], sz[1], cands.size(1),
                sz[2]).contiguous().view(sz[0], -1, sz[2]))

        cand_scores = Variable(
            self.cand_scores.resize_(cview.size(0)).fill_(0))
        cand_lengths = Variable(
            self.cand_lengths.resize_(cview.size(0)).fill_(0))

        if self.attention != 'none':
            # using attention
            sz = encoder_output.size()
            cands_encoder_output = (encoder_output.contiguous().view(
                sz[0], 1, sz[1],
                sz[2]).expand(sz[0], cands.size(1), sz[1],
                              sz[2]).contiguous().view(-1, sz[1], sz[2]))

            msz = attn_mask.size()
            cands_attn_mask = (attn_mask.contiguous().view(
                msz[0], 1,
                msz[1]).expand(msz[0], cands.size(1),
                               msz[1]).contiguous().view(-1, msz[1]))
            for i in range(cview.size(1)):
                # process one token at a time
                h_att = cands_hn[0] if type(
                    self.decoder) == nn.LSTM else cands_hn
                output = self._apply_attention(c_xes, cands_encoder_output,
                                               h_att, cands_attn_mask)
                output, cands_hn = self.decoder(output, cands_hn)
                _preds, scores = self.hidden_to_idx(output, is_training=False)
                cs = cview.select(1, i)
                non_nulls = cs.ne(self.NULL_IDX)
                cand_lengths += non_nulls.long()
                score_per_cand = torch.gather(scores.select(1, i), 1,
                                              cs.unsqueeze(1))
                cand_scores += score_per_cand.squeeze() * non_nulls.float()
                c_xes = self.dec_lt(cs).unsqueeze(1)
        else:
            # process entire sequence at once
            if cview.size(1) > 1:
                # feed in START + cands[:-1]
                cands_in = cview.narrow(1, 0, cview.size(1) - 1)
                c_xes = torch.cat([c_xes, self.dec_lt(cands_in)], 1)
            output, cands_hn = self.decoder(c_xes, cands_hn)
            _preds, scores = self.hidden_to_idx(output, is_training=False)

            for i in range(cview.size(1)):
                # calculate score at each token
                cs = cview.select(1, i)
                non_nulls = cs.ne(self.NULL_IDX)
                cand_lengths += non_nulls.long()
                score_per_cand = torch.gather(scores.select(1, i), 1,
                                              cs.unsqueeze(1))
                cand_scores += score_per_cand.squeeze() * non_nulls.float()

        # set empty scores to -1, so when divided by 0 they become -inf
        cand_scores -= cand_lengths.eq(0).float()
        # average the scores per token
        cand_scores /= cand_lengths.float()

        cand_scores = cand_scores.view(cands.size(0), cands.size(1))
        srtd_scores, text_cand_inds = cand_scores.sort(1, True)
        text_cand_inds = text_cand_inds.data

        return text_cand_inds

    def predict(self, xs, ys=None, cands=None, valid_cands=None, lm=False):
        """Produce a prediction from our model.

        Update the model using the targets if available; otherwise generate a
        prediction, and also rank candidates if they are provided and the
        ranking flag is set.
        """
        batchsize = len(xs)
        text_cand_inds = None
        is_training = ys is not None
        self.encoder.train(mode=is_training)
        self.decoder.train(mode=is_training)
        encoder_output, hidden = self._encode(xs, is_training)

        # next we use START as an input to kick off our decoder
        if not lm:
            x = Variable(self.START_TENSOR, requires_grad=False)
            xe = self.dec_lt(x)
            xe = F.dropout(xe, p=self.dropout, training=is_training)
            xes = xe.expand(batchsize, 1, xe.size(1))
        else:
            # during language_model mode, just start with zeros
            xes = Variable(self.zeros[0].narrow(1, 0,
                                                self.emb_size).unsqueeze(1),
                           requires_grad=False)

        # list of output tokens for each example in the batch
        output_lines = None

        if self.attention == 'none':
            attn_mask = None
        else:
            attn_mask = xs.ne(0).float()

        if is_training:
            output_lines = self._decode_and_train(batchsize,
                                                  xes,
                                                  ys,
                                                  encoder_output,
                                                  hidden,
                                                  attn_mask,
                                                  lm=lm)
        else:
            if cands is not None:
                text_cand_inds = self._score_candidates(
                    cands, valid_cands, xe, encoder_output, hidden, attn_mask)

            output_lines = self._decode_only(batchsize, xes, ys,
                                             encoder_output, hidden, attn_mask)

        return output_lines, text_cand_inds

    def batchify(self, observations):
        """Convert a list of observations into input & target tensors."""
        def valid(obs):
            # check if this is an example our model should actually process
            return 'text' in obs and ('labels' in obs or 'eval_labels' in obs)

        # valid examples and their indices
        valid_inds, exs = zip(*[(i, ex) for i, ex in enumerate(observations)
                                if valid(ex)])

        # set up the input tensors
        batchsize = len(exs)
        if batchsize == 0:
            return None, None, None, None, None, None

        # tokenize the text
        parsed = [self.parse(ex['text']) for ex in exs]
        x_lens = [len(x) for x in parsed]
        ind_sorted = sorted(range(len(x_lens)), key=lambda k: -x_lens[k])

        exs = [exs[k] for k in ind_sorted]
        valid_inds = [valid_inds[k] for k in ind_sorted]
        parsed = [parsed[k] for k in ind_sorted]

        max_x_len = max([len(x) for x in parsed])
        if self.truncate > 0:
            # shrink xs to limit batch computation
            max_x_len = min(max_x_len, self.truncate)
            parsed = [x[-max_x_len:] for x in parsed]
        xs = torch.LongTensor(batchsize, max_x_len).fill_(0)
        # right-padded with zeros
        for i, x in enumerate(parsed):
            for j, idx in enumerate(x):
                xs[i][j] = idx
        if self.use_cuda:
            # copy to gpu
            self.xs.resize_(xs.size())
            self.xs.copy_(xs, non_blocking=True)
            xs = Variable(self.xs)
        else:
            xs = Variable(xs)

        # set up the target tensors
        ys = None
        labels = None
        if any(['labels' in ex for ex in exs]):
            # randomly select one of the labels to update on (if multiple),
            # and append END to each label
            labels = [random.choice(ex.get('labels', [''])) for ex in exs]
            parsed = [self.parse(y + ' ' + self.END) for y in labels if y]
            max_y_len = max(len(y) for y in parsed)
            if self.truncate > 0:
                # shrink ys to limit batch computation
                max_y_len = min(max_y_len, self.truncate)
                parsed = [y[:max_y_len] for y in parsed]
            ys = torch.LongTensor(batchsize, max_y_len).fill_(0)
            for i, y in enumerate(parsed):
                for j, idx in enumerate(y):
                    ys[i][j] = idx
            if self.use_cuda:
                # copy to gpu
                self.ys.resize_(ys.size())
                self.ys.copy_(ys, non_blocking=True)
                ys = Variable(self.ys)
            else:
                ys = Variable(ys)

        # set up candidates
        cands = None
        valid_cands = None
        if ys is None and self.rank:
            # only do ranking when no targets available and ranking flag set
            parsed = []
            valid_cands = []
            for i, v in enumerate(valid_inds):
                if 'label_candidates' in observations[i]:
                    # each candidate tuple is a pair of the parsed version and
                    # the original full string
                    cs = list(observations[i]['label_candidates'])
                    parsed.append([self.parse(c) for c in cs])
                    valid_cands.append((i, v, cs))
            if len(parsed) > 0:
                # TODO: store lengths of cands separately, so don't have zero
                #       padding for varying number of cands per example
                # found cands, pack them into tensor
                max_c_len = max(max(len(c) for c in cs) for cs in parsed)
                max_c_cnt = max(len(cs) for cs in parsed)
                cands = torch.LongTensor(len(parsed), max_c_cnt,
                                         max_c_len).fill_(0)
                for i, cs in enumerate(parsed):
                    for j, c in enumerate(cs):
                        for k, idx in enumerate(c):
                            cands[i][j][k] = idx
                if self.use_cuda:
                    # copy to gpu
                    self.cands.resize_(cands.size())
                    self.cands.copy_(cands, non_blocking=True)
                    cands = Variable(self.cands)
                else:
                    cands = Variable(cands)

        return xs, ys, labels, valid_inds, cands, valid_cands

    def batch_act(self, observations):
        batchsize = len(observations)
        # initialize a table of replies with this agent's id
        batch_reply = [{'id': self.getID()} for _ in range(batchsize)]

        # convert the observations into batches of inputs and targets
        # valid_inds tells us the indices of all valid examples
        # e.g. for input [{}, {'text': 'hello', 'labels': ['hi']}, {}, {}],
        # valid_inds is [1] since the other three elements had no 'text'
        # or label fields
        xs, ys, labels, valid_inds, cands, valid_cands = self.batchify(
            observations)

        if xs is None:
            # no valid examples, just return empty responses
            return batch_reply

        # produce predictions, train on targets if available
        predictions, text_cand_inds = self.predict(xs, ys, cands, valid_cands)
        if self.lm and ys is not None:
            # also train on the lm task: given "START" as input, predict the
            # original context followed by "START" and one of its labels
            new_obs = [{
                'text':
                self.START,
                'labels': [
                    '{x} {s} {y}'.format(x=obs['text'].replace(self.START, ''),
                                         s=self.START,
                                         y=random.choice(
                                             obs.get('labels', [''])))
                ]
            } for obs in observations]
            xs, ys, _, _, _, _ = self.batchify(new_obs)
            _, _ = self.predict(xs, ys, lm=True)

        for i in range(len(predictions)):
            # map the predictions back to non-empty examples in the batch
            # we join with spaces since we produce tokens one at a time
            curr = batch_reply[valid_inds[i]]
            curr_pred = ' '.join(
                c for c in predictions[i]
                if c != self.END and c != self.dict.null_token)
            curr['text'] = curr_pred
            if labels is not None:
                self.answers[valid_inds[i]] = labels[i]
            else:
                self.answers[valid_inds[i]] = curr_pred

        if text_cand_inds is not None:
            for i in range(len(valid_cands)):
                order = text_cand_inds[i]
                _, batch_idx, curr_cands = valid_cands[i]
                curr = batch_reply[valid_inds[batch_idx]]
                curr['text_candidates'] = [
                    curr_cands[idx] for idx in order if idx < len(curr_cands)
                ]

        return batch_reply

    def act(self):
        # call batch_act with this batch of one
        return self.batch_act([self.observation])[0]

    def save(self, path=None):
        """Save model parameters if model_file is set."""
        path = self.opt.get('model_file', None) if path is None else path

        if path and hasattr(self, 'optims'):
            model = {}
            model['enc_lt'] = self.enc_lt.state_dict()
            if self.opt['lookuptable'] not in ['enc_dec', 'all']:
                # this agent only supports sharing dec_lt with enc_lt
                raise RuntimeError(
                    'dec_lt is expected to share weights with enc_lt; '
                    'a separate decoder lookup table is not supported.')
                # model['dec_lt'] = self.dec_lt.state_dict()
            if self.opt['decoder'] != 'shared':
                model['encoder'] = self.encoder.state_dict()
            model['decoder'] = self.decoder.state_dict()
            model['h2e'] = self.h2e.state_dict()
            model['e2o'] = self.e2o.state_dict()
            model['optims'] = {
                k: v.state_dict()
                for k, v in self.optims.items()
            }
            model['longest_label'] = self.longest_label
            model['opt'] = self.opt

            for attn_name in ['attn', 'attn_v', 'attn_combine']:
                if hasattr(self, attn_name):
                    model[attn_name] = getattr(self, attn_name).state_dict()

            with open(path, 'wb') as write:
                torch.save(model, write)

    def shutdown(self):
        """Save the state of the model when shutdown."""
        path = self.opt.get('model_file', None)
        if path is not None:
            self.save(path + '.shutdown_state')
        super().shutdown()

    def load(self, path):
        """Return opt and model states."""
        with open(path, 'rb') as read:
            model = torch.load(read)

        return model['opt'], model

    def set_states(self, states):
        """Set the state dicts of the modules from saved states."""
        self.enc_lt.load_state_dict(states['enc_lt'])
        if self.opt['lookuptable'] not in ['enc_dec', 'all']:
            # dec_lt is enc_lt
            raise RuntimeError(
                'dec_lt state should not exist--it is same as enc_lt.')
        if self.opt['decoder'] != 'shared':
            self.encoder.load_state_dict(states['encoder'])
        self.decoder.load_state_dict(states['decoder'])
        self.h2e.load_state_dict(states['h2e'])
        self.e2o.load_state_dict(states['e2o'])
        for attn_name in ['attn', 'attn_v', 'attn_combine']:
            if attn_name in states:
                getattr(self, attn_name).load_state_dict(states[attn_name])

        for k, v in states['optims'].items():
            self.optims[k].load_state_dict(v)
        self.longest_label = states['longest_label']
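
The _score_candidates routine above scores each padded candidate by gathering the decoder's per-token scores for the candidate's tokens, masking out NULL padding, and averaging over the candidate's true length, with empty candidates pushed to -inf. Below is a minimal standalone sketch of that masked averaging; the tensor shapes, the NULL index 0, and the example data are assumptions for illustration, not taken from the agent.

import torch

def average_candidate_scores(log_probs, cands, null_idx=0):
    # log_probs: (num_cands, seq_len, vocab_size) per-token scores
    # cands: (num_cands, seq_len) token ids, right-padded with null_idx
    token_scores = log_probs.gather(2, cands.unsqueeze(2)).squeeze(2)
    mask = cands.ne(null_idx).float()
    lengths = mask.sum(dim=1)
    summed = (token_scores * mask).sum(dim=1)
    # give empty candidates -1 so that dividing by a length of 0 yields -inf
    summed = summed - lengths.eq(0).float()
    return summed / lengths

log_probs = torch.randn(2, 3, 5).log_softmax(dim=2)
cands = torch.tensor([[1, 2, 0], [0, 0, 0]])  # second candidate is all padding
print(average_candidate_scores(log_probs, cands))

Dividing the -1 assigned to an empty candidate by its zero length yields -inf, so such candidates always rank last, which mirrors the comment in the agent code.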
Code example #9
class MemnnFeedbackAgent(Agent):
    """
    Memory Network agent for question answering that supports reward-based learning
    (RBI), forward prediction (FP), and imitation learning (IM).

    For more details on settings see: https://arxiv.org/abs/1604.06045.

    Model settings 'FP', 'RBI', 'RBI+FP', and 'IM_feedback' assume that
    feedback and reward for the current example immediately follow the query
    (add ':feedback' argument when specifying task name).

    python examples/train_model.py --setting 'FP'
    -m "projects.memnn_feedback.agent.memnn_feedback:MemnnFeedbackAgent"
    -t "projects.memnn_feedback.tasks.dbll_babi.agents:taskTeacher:3_p0.5:feedback"
    """
    @staticmethod
    def add_cmdline_args(argparser):
        DictionaryAgent.add_cmdline_args(argparser)
        arg_group = argparser.add_argument_group('MemNN Arguments')
        arg_group.add_argument('-lr',
                               '--learning-rate',
                               type=float,
                               default=0.01,
                               help='learning rate')
        arg_group.add_argument('--embedding-size',
                               type=int,
                               default=128,
                               help='size of token embeddings')
        arg_group.add_argument('--hops',
                               type=int,
                               default=3,
                               help='number of memory hops')
        arg_group.add_argument('--mem-size',
                               type=int,
                               default=100,
                               help='size of memory')
        arg_group.add_argument(
            '--time-features',
            type='bool',
            default=True,
            help='use time features for memory embeddings',
        )
        arg_group.add_argument(
            '--position-encoding',
            type='bool',
            default=False,
            help='use position encoding instead of bag of words embedding',
        )
        arg_group.add_argument(
            '-clip',
            '--gradient-clip',
            type=float,
            default=0.2,
            help='gradient clipping using l2 norm',
        )
        arg_group.add_argument('--output',
                               type=str,
                               default='rank',
                               help='type of output (rank|generate)')
        arg_group.add_argument(
            '--rnn-layers',
            type=int,
            default=2,
            help='number of hidden layers in RNN decoder for generative output',
        )
        arg_group.add_argument(
            '--dropout',
            type=float,
            default=0.1,
            help='dropout probability for RNN decoder training',
        )
        arg_group.add_argument('--optimizer',
                               default='sgd',
                               help='optimizer type (sgd|adam)')
        arg_group.add_argument(
            '--no-cuda',
            action='store_true',
            default=False,
            help='disable GPUs even if available',
        )
        arg_group.add_argument('--gpu',
                               type=int,
                               default=-1,
                               help='which GPU device to use')
        arg_group.add_argument(
            '--setting',
            type=str,
            default='IM',
            help='choose among IM, IM_feedback, RBI, FP, RBI+FP',
        )
        arg_group.add_argument(
            '--num-feedback-cands',
            type=int,
            default=6,
            help='number of feedback candidates',
        )
        arg_group.add_argument(
            '--single_embedder',
            type='bool',
            default=False,
            help='use a single embedding matrix for all embedders in the model',
        )

    def __init__(self, opt, shared=None):
        super().__init__(opt, shared)

        opt['cuda'] = not opt['no_cuda'] and torch.cuda.is_available()
        if opt['cuda']:
            print('[ Using CUDA ]')
            torch.cuda.set_device(opt['gpu'])

        if not shared:
            self.id = 'MemNN'
            self.dict = DictionaryAgent(opt)
            self.decoder = None
            if opt['output'] == 'generate' or opt['output'] == 'g':
                self.decoder = Decoder(
                    opt['embedding_size'],
                    opt['embedding_size'],
                    opt['rnn_layers'],
                    opt,
                    self.dict,
                )
            elif opt['output'] != 'rank' and opt['output'] != 'r':
                raise NotImplementedError('Output type not supported.')

            if 'FP' in opt['setting']:
                # add extra beta-word to indicate learner's answer
                self.beta_word = 'betaword'
                self.dict.add_to_dict([self.beta_word])

            self.model = MemNN(opt, self.dict)

            optim_params = [
                p for p in self.model.parameters() if p.requires_grad
            ]
            lr = opt['learning_rate']
            if opt['optimizer'] == 'sgd':
                self.optimizers = {'memnn': optim.SGD(optim_params, lr=lr)}
                if self.decoder is not None:
                    self.optimizers['decoder'] = optim.SGD(
                        self.decoder.parameters(), lr=lr)
            elif opt['optimizer'] == 'adam':
                self.optimizers = {'memnn': optim.Adam(optim_params, lr=lr)}
                if self.decoder is not None:
                    self.optimizers['decoder'] = optim.Adam(
                        self.decoder.parameters(), lr=lr)
            else:
                raise NotImplementedError('Optimizer not supported.')

            if opt['cuda']:
                self.model.share_memory()
                if self.decoder is not None:
                    self.decoder.cuda()

            if opt.get('model_file') and os.path.isfile(opt['model_file']):
                print('Loading existing model parameters from ' +
                      opt['model_file'])
                self.load(opt['model_file'])
        else:
            if 'model' in shared:
                # model is shared during hogwild
                self.model = shared['model']
                self.dict = shared['dict']
                self.decoder = shared['decoder']
                self.optimizers = shared['optimizer']
                if 'FP' in opt['setting']:
                    self.beta_word = shared['betaword']

        if hasattr(self, 'model'):
            self.opt = opt
            self.mem_size = opt['mem_size']
            self.loss_fn = CrossEntropyLoss()
            self.gradient_clip = opt.get('gradient_clip', 0.2)

            self.model_setting = opt['setting']
            if 'FP' in opt['setting']:
                self.feedback_cands = set([])
                self.num_feedback_cands = opt['num_feedback_cands']

            self.longest_label = 1
            self.END = self.dict.end_token
            self.END_TENSOR = torch.LongTensor(self.dict.parse(self.END))
            self.START = self.dict.start_token
            self.START_TENSOR = torch.LongTensor(self.dict.parse(self.START))

        self.reset()
        self.last_cands, self.last_cands_list = None, None

    def share(self):
        # Share internal states between parent and child instances
        shared = super().share()

        if self.opt.get('numthreads', 1) > 1:
            shared['model'] = self.model
            self.model.share_memory()
            shared['optimizer'] = self.optimizers
            shared['dict'] = self.dict
            shared['decoder'] = self.decoder
            if 'FP' in self.model_setting:
                shared['betaword'] = self.beta_word
        return shared

    def observe(self, observation):
        observation = copy.copy(observation)

        # extract feedback for forward prediction
        # IM setting - no feedback provided in the dataset
        if self.opt['setting'] != 'IM':
            if 'text' in observation:
                split = observation['text'].split('\n')
                feedback = split[-1]
                observation['feedback'] = feedback
                observation['text'] = '\n'.join(split[:-1])

        if not self.episode_done:
            # if the last example wasn't the end of an episode, then we need to
            # recall what was said in that example
            prev_dialogue = (self.observation['text']
                             if self.observation is not None else '')

            # append answer and feedback (if available) given in the previous example to the previous dialog
            if 'eval_labels' in self.observation:
                prev_dialogue += '\n' + random.choice(
                    self.observation['eval_labels'])
            elif 'labels' in self.observation:
                prev_dialogue += '\n' + random.choice(
                    self.observation['labels'])
            if 'feedback' in self.observation:
                prev_dialogue += '\n' + self.observation['feedback']

            observation['text'] = prev_dialogue + '\n' + observation['text']

        self.observation = observation
        self.episode_done = observation['episode_done']
        return observation

    def reset(self):
        # reset observation and episode_done
        self.observation = None
        self.episode_done = True

    def backward(self, loss, retain_graph=False):
        # zero out optimizer and take one optimization step
        for o in self.optimizers.values():
            o.zero_grad()
        loss.backward(retain_graph=retain_graph)

        torch.nn.utils.clip_grad_norm(self.model.parameters(),
                                      self.gradient_clip)
        for o in self.optimizers.values():
            o.step()

    def parse_cands(self, cand_answers):
        """Returns:
            cand_answers = tensor (vector) of token indices for answer candidates
            cand_answers_lengths = tensor (vector) with lengths of each answer candidate
        """
        parsed_cands = [to_tensors(c, self.dict) for c in cand_answers]
        cand_answers_tensor = torch.cat([x[1] for x in parsed_cands])
        max_cands_len = max([len(c) for c in cand_answers])
        cand_answers_lengths = torch.LongTensor(len(cand_answers),
                                                max_cands_len).zero_()
        for i in range(len(cand_answers)):
            if len(parsed_cands[i][0]) > 0:
                cand_answers_lengths[
                    i, -len(parsed_cands[i][0]):] = parsed_cands[i][0]
        cand_answers_tensor = Variable(cand_answers_tensor)
        cand_answers_lengths = Variable(cand_answers_lengths)
        return cand_answers_tensor, cand_answers_lengths

    def get_cand_embeddings_with_added_beta(self, cands, selected_answer_inds):
        # add beta_word to the candidate selected by the learner to indicate learner's answer
        cand_answers_with_beta = copy.deepcopy(cands)

        for i in range(len(cand_answers_with_beta)):
            cand_answers_with_beta[i][
                selected_answer_inds[i]] += ' ' + self.beta_word

        # get candidate embeddings after adding beta_word to the selected candidate
        (
            cand_answers_tensor_with_beta,
            cand_answers_lengths_with_beta,
        ) = self.parse_cands(cand_answers_with_beta)
        cands_embeddings_with_beta = self.model.answer_embedder(
            cand_answers_lengths_with_beta, cand_answers_tensor_with_beta)
        if self.opt['cuda']:
            cands_embeddings_with_beta = cands_embeddings_with_beta.cuda()
        return cands_embeddings_with_beta

    def predict(self, xs, answer_cands, ys=None, feedback_cands=None):
        is_training = ys is not None
        if is_training and 'FP' not in self.model_setting:
            # Subsample to reduce training time
            answer_cands = [
                list(set(random.sample(c, min(32, len(c))) + self.labels))
                for c in answer_cands
            ]
        else:
            # rank all cands to increase accuracy
            answer_cands = [list(set(c)) for c in answer_cands]

        self.model.train(mode=is_training)

        # Organize inputs for network (see contents of xs and ys in batchify method)
        # volatile disables autograd, so it should only be set at eval time
        inputs = [Variable(x, volatile=not is_training) for x in xs]

        if self.decoder:
            output_embeddings = self.model(*inputs)
            self.decoder.train(mode=is_training)
            output_lines, loss = self.decode(output_embeddings, ys)
            predictions = self.generated_predictions(output_lines)
            self.backward(loss)
            return predictions

        scores = None
        if is_training:
            label_inds = [
                cand_list.index(self.labels[i])
                for i, cand_list in enumerate(answer_cands)
            ]

            if 'FP' in self.model_setting:
                if len(feedback_cands) == 0:
                    print(
                        'FP is not training... waiting for negative feedback examples'
                    )
                else:
                    cand_answers_embs_with_beta = self.get_cand_embeddings_with_added_beta(
                        answer_cands, label_inds)
                    scores, forward_prediction_output = self.model(
                        *inputs, answer_cands, cand_answers_embs_with_beta)
                    fp_scores = self.model.get_score(feedback_cands,
                                                     forward_prediction_output,
                                                     forward_predict=True)
                    feedback_label_inds = [
                        cand_list.index(self.feedback_labels[i])
                        for i, cand_list in enumerate(feedback_cands)
                    ]
                    if self.opt['cuda']:
                        feedback_label_inds = Variable(
                            torch.cuda.LongTensor(feedback_label_inds))
                    else:
                        feedback_label_inds = Variable(
                            torch.LongTensor(feedback_label_inds))
                    loss_fp = self.loss_fn(fp_scores, feedback_label_inds)
                    if loss_fp.data[0] > 100000:
                        raise Exception("Loss might be diverging. Loss:",
                                        loss_fp.data[0])
                    self.backward(loss_fp, retain_graph=True)

            if self.opt['cuda']:
                label_inds = Variable(torch.cuda.LongTensor(label_inds))
            else:
                label_inds = Variable(torch.LongTensor(label_inds))

        if scores is None:
            output_embeddings = self.model(*inputs)
            scores = self.model.get_score(answer_cands, output_embeddings)

        predictions = self.ranked_predictions(answer_cands, scores)

        if is_training:
            update_params = True
            # don't perform regular training if in FP mode
            if self.model_setting == 'FP':
                update_params = False
            elif 'RBI' in self.model_setting:
                if len(self.rewarded_examples_inds) == 0:
                    update_params = False
                else:
                    self.rewarded_examples_inds = torch.LongTensor(
                        self.rewarded_examples_inds)
                    if self.opt['cuda']:
                        self.rewarded_examples_inds = self.rewarded_examples_inds.cuda(
                        )
                    # use only rewarded examples for training
                    loss = self.loss_fn(
                        scores[self.rewarded_examples_inds, :],
                        label_inds[self.rewarded_examples_inds],
                    )
            else:
                # regular IM training
                loss = self.loss_fn(scores, label_inds)

            if update_params:
                self.backward(loss)
        return predictions

    def ranked_predictions(self, cands, scores):
        _, inds = scores.data.sort(descending=True, dim=1)
        return [[cands[i][j] for j in r if j < len(cands[i])]
                for i, r in enumerate(inds)]

    def decode(self, output_embeddings, ys=None):
        batchsize = output_embeddings.size(0)
        hn = output_embeddings.unsqueeze(0).expand(self.opt['rnn_layers'],
                                                   batchsize,
                                                   output_embeddings.size(1))
        x = self.model.answer_embedder(Variable(torch.LongTensor([1])),
                                       Variable(self.START_TENSOR))
        xes = x.unsqueeze(1).expand(x.size(0), batchsize, x.size(1))

        loss = 0
        output_lines = [[] for _ in range(batchsize)]
        done = [False for _ in range(batchsize)]
        total_done = 0
        idx = 0
        while (total_done < batchsize) and idx < self.longest_label:
            # keep producing tokens until we hit END or max length for each ex
            if self.opt['cuda']:
                xes = xes.cuda()
                hn = hn.contiguous()
            preds, scores = self.decoder(xes, hn)
            if ys is not None:
                y = Variable(ys[0][:, idx])
                temp_y = y.cuda() if self.opt['cuda'] else y
                loss += self.loss_fn(scores, temp_y)
            else:
                y = preds
            # use the true token as the next input for better training
            xes = self.model.answer_embedder(
                Variable(torch.LongTensor(preds.numel()).fill_(1)),
                y).unsqueeze(0)

            for b in range(batchsize):
                if not done[b]:
                    token = self.dict.vec2txt(preds.data[b])
                    if token == self.END:
                        done[b] = True
                        total_done += 1
                    else:
                        output_lines[b].append(token)
            idx += 1
        return output_lines, loss

    def generated_predictions(self, output_lines):
        return [[
            ' '.join(c for c in o
                     if c != self.END and c != self.dict.null_token)
        ] for o in output_lines]

    def parse(self, text):
        """Returns:
            query = tensor (vector) of token indices for query
            query_length = length of query
            memory = tensor (matrix) where each row contains token indices for a memory
            memory_lengths = tensor (vector) with lengths of each memory
        """
        sp = text.split('\n')
        query_sentence = sp[-1]
        query = self.dict.txt2vec(query_sentence)
        query = torch.LongTensor(query)
        query_length = torch.LongTensor([len(query)])

        sp = sp[:-1]
        sentences = []
        for s in sp:
            sentences.extend(s.split('\t'))
        if len(sentences) == 0:
            sentences.append(self.dict.null_token)

        num_mems = min(self.mem_size, len(sentences))
        memory_sentences = sentences[-num_mems:]
        memory = [self.dict.txt2vec(s) for s in memory_sentences]
        memory = [torch.LongTensor(m) for m in memory]
        memory_lengths = torch.LongTensor([len(m) for m in memory])
        memory = torch.cat(memory)
        return (query, memory, query_length, memory_lengths)

    def batchify(self, obs):
        """Returns:
            xs = [memories, queries, memory_lengths, query_lengths]
            ys = [labels, label_lengths] (if available, else None)
            cands = list of candidates for each example in batch
            valid_inds = list of indices for examples with valid observations
        """
        exs = [ex for ex in obs if 'text' in ex]
        valid_inds = [i for i, ex in enumerate(obs) if 'text' in ex]
        if not exs:
            return [None] * 5

        if 'RBI' in self.model_setting:
            self.rewarded_examples_inds = [
                i for i, ex in enumerate(obs)
                if 'text' in ex and ex.get('reward', 0) > 0
            ]

        parsed = [self.parse(ex['text']) for ex in exs]
        queries = torch.cat([x[0] for x in parsed])
        memories = torch.cat([x[1] for x in parsed])
        query_lengths = torch.cat([x[2] for x in parsed])
        memory_lengths = torch.LongTensor(len(exs), self.mem_size).zero_()
        for i in range(len(exs)):
            if len(parsed[i][3]) > 0:
                memory_lengths[i, -len(parsed[i][3]):] = parsed[i][3]
        xs = [memories, queries, memory_lengths, query_lengths]

        ys = None
        self.labels = [
            random.choice(ex['labels']) for ex in exs if 'labels' in ex
        ]

        if len(self.labels) == len(exs):
            parsed = [self.dict.txt2vec(l) for l in self.labels]
            parsed = [torch.LongTensor(p) for p in parsed]
            label_lengths = torch.LongTensor([len(p)
                                              for p in parsed]).unsqueeze(1)
            self.longest_label = max(self.longest_label, label_lengths.max())
            padded = [
                torch.cat((
                    p,
                    torch.LongTensor(self.longest_label - len(p)).fill_(
                        self.END_TENSOR[0]),
                )) for p in parsed
            ]
            labels = torch.stack(padded)
            ys = [labels, label_lengths]

        feedback_cands = []
        if 'FP' in self.model_setting:
            self.feedback_labels = [
                ex['feedback'] for ex in exs
                if 'feedback' in ex and ex['feedback'] is not None
            ]
            self.feedback_cands = self.feedback_cands | set(
                self.feedback_labels)

            if (len(self.feedback_labels) == len(exs)
                    and len(self.feedback_cands) > self.num_feedback_cands):
                feedback_cands = [
                    list(
                        set(
                            random.sample(self.feedback_cands,
                                          self.num_feedback_cands) +
                            [feedback])) for feedback in self.feedback_labels
                ]

        cands = [
            ex['label_candidates'] for ex in exs if 'label_candidates' in ex
        ]
        # Use words in dict as candidates if no candidates are provided
        if len(cands) < len(exs):
            cands = build_cands(exs, self.dict)
        # Avoid rebuilding candidate list every batch if it's the same
        if self.last_cands != cands:
            self.last_cands = cands
            self.last_cands_list = [list(c) for c in cands]
        cands = self.last_cands_list
        return xs, ys, cands, valid_inds, feedback_cands

    def batch_act(self, observations):
        batchsize = len(observations)
        batch_reply = [{'id': self.getID()} for _ in range(batchsize)]

        xs, ys, cands, valid_inds, feedback_cands = self.batchify(observations)

        if xs is None or len(xs[1]) == 0:
            return batch_reply

        # Either train or predict
        predictions = self.predict(xs, cands, ys, feedback_cands)

        for i in range(len(valid_inds)):
            batch_reply[valid_inds[i]]['text'] = predictions[i][0]
            batch_reply[valid_inds[i]]['text_candidates'] = predictions[i]
        return batch_reply

    def act(self):
        return self.batch_act([self.observation])[0]

    def save(self, path=None):
        path = self.opt.get('model_file', None) if path is None else path

        if path:
            checkpoint = {}
            checkpoint['memnn'] = self.model.state_dict()
            checkpoint['memnn_optim'] = self.optimizers['memnn'].state_dict()
            if self.decoder is not None:
                checkpoint['decoder'] = self.decoder.state_dict()
                checkpoint['decoder_optim'] = self.optimizers[
                    'decoder'].state_dict()
                checkpoint['longest_label'] = self.longest_label
            with open(path, 'wb') as write:
                torch.save(checkpoint, write)

    def load(self, path):
        with open(path, 'rb') as read:
            checkpoint = torch.load(read)
        self.model.load_state_dict(checkpoint['memnn'])
        self.optimizers['memnn'].load_state_dict(checkpoint['memnn_optim'])
        if self.decoder is not None:
            self.decoder.load_state_dict(checkpoint['decoder'])
            self.optimizers['decoder'].load_state_dict(
                checkpoint['decoder_optim'])
            self.longest_label = checkpoint['longest_label']
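
In the feedback settings, observe() above peels the teacher's feedback off the observation as the last newline-separated line before the rest of the text is stored. A small illustration of that split on a made-up observation (the text here is invented, not from any dataset):

obs = {
    'text': 'Mary went to the kitchen.\nWhere is Mary?\nNo, that is wrong.',
    'episode_done': True,
}
*context, feedback = obs['text'].split('\n')
obs['feedback'] = feedback
obs['text'] = '\n'.join(context)
print(obs['feedback'])  # 'No, that is wrong.'
print(obs['text'])      # remaining context and query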
Code example #10
File: seq2seq.py Project: analyticlaks/ParlAI
class Seq2seqAgent(Agent):
    """Simple agent which uses an LSTM to process incoming text observations."""

    @staticmethod
    def add_cmdline_args(argparser):
        DictionaryAgent.add_cmdline_args(argparser)
        agent = argparser.add_argument_group('Seq2Seq Arguments')
        agent.add_argument('-hs', '--hiddensize', type=int, default=64,
            help='size of the hidden layers and embeddings')
        agent.add_argument('-nl', '--numlayers', type=int, default=2,
            help='number of hidden layers')
        agent.add_argument('-lr', '--learningrate', type=float, default=0.5,
            help='learning rate')
        agent.add_argument('-dr', '--dropout', type=float, default=0.1,
            help='dropout rate')
        agent.add_argument('--no-cuda', action='store_true', default=False,
            help='disable GPUs even if available')
        agent.add_argument('--gpu', type=int, default=-1,
            help='which GPU device to use')

    def __init__(self, opt, shared=None):
        super().__init__(opt, shared)
        opt['cuda'] = not opt['no_cuda'] and torch.cuda.is_available()
        if opt['cuda']:
            print('[ Using CUDA ]')
            torch.cuda.set_device(opt['gpu'])
        if not shared:
            # don't enter this block for shared (i.e. batch) instantiations
            self.dict = DictionaryAgent(opt)
            self.id = 'Seq2Seq'
            hsz = opt['hiddensize']
            self.EOS = self.dict.eos_token
            self.observation = {'text': self.EOS, 'episode_done': True}
            self.EOS_TENSOR = torch.LongTensor(self.dict.parse(self.EOS))
            self.hidden_size = hsz
            self.num_layers = opt['numlayers']
            self.learning_rate = opt['learningrate']
            self.use_cuda = opt.get('cuda', False)
            self.longest_label = 1

            self.criterion = nn.NLLLoss()
            self.lt = nn.Embedding(len(self.dict), hsz, padding_idx=0,
                                   scale_grad_by_freq=True)
            self.encoder = nn.GRU(hsz, hsz, opt['numlayers'])
            self.decoder = nn.GRU(hsz, hsz, opt['numlayers'])
            self.d2o = nn.Linear(hsz, len(self.dict))
            self.dropout = nn.Dropout(opt['dropout'])
            self.softmax = nn.LogSoftmax()

            lr = opt['learningrate']
            self.optims = {
                'lt': optim.SGD(self.lt.parameters(), lr=lr),
                'encoder': optim.SGD(self.encoder.parameters(), lr=lr),
                'decoder': optim.SGD(self.decoder.parameters(), lr=lr),
                'd2o': optim.SGD(self.d2o.parameters(), lr=lr),
            }
            if self.use_cuda:
                self.cuda()
            if opt.get('model_file') and os.path.isfile(opt['model_file']):
                print('Loading existing model parameters from ' + opt['model_file'])
                self.load(opt['model_file'])

        self.episode_done = True

    def parse(self, text):
        return torch.LongTensor(self.dict.txt2vec(text))

    def v2t(self, vec):
        return self.dict.vec2txt(vec)

    def cuda(self):
        self.criterion.cuda()
        self.lt.cuda()
        self.encoder.cuda()
        self.decoder.cuda()
        self.d2o.cuda()
        self.dropout.cuda()
        self.softmax.cuda()

    def hidden_to_idx(self, hidden, drop=False):
        if hidden.size(0) > 1:
            raise RuntimeError('bad dimensions of tensor:', hidden)
        hidden = hidden.squeeze(0)
        scores = self.d2o(hidden)
        if drop:
            scores = self.dropout(scores)
        scores = self.softmax(scores)
        _max_score, idx = scores.max(1)
        return idx, scores

    def zero_grad(self):
        for optimizer in self.optims.values():
            optimizer.zero_grad()

    def update_params(self):
        for optimizer in self.optims.values():
            optimizer.step()

    def init_zeros(self, bsz=1):
        t = torch.zeros(self.num_layers, bsz, self.hidden_size)
        if self.use_cuda:
            t = t.cuda(non_blocking=True)
        return Variable(t)

    def init_rand(self, bsz=1):
        t = torch.FloatTensor(self.num_layers, bsz, self.hidden_size)
        t.uniform_(0.05)
        if self.use_cuda:
            t = t.cuda(non_blocking=True)
        return Variable(t)

    def observe(self, observation):
        observation = copy.deepcopy(observation)
        if not self.episode_done:
            # if the last example wasn't the end of an episode, then we need to
            # recall what was said in that example
            prev_dialogue = self.observation['text']
            observation['text'] = prev_dialogue + '\n' + observation['text']
        self.observation = observation
        self.episode_done = observation['episode_done']
        return observation

    def update(self, xs, ys):
        batchsize = len(xs)

        # first encode context
        xes = self.lt(xs).t()
        h0 = self.init_zeros(batchsize)
        _output, hn = self.encoder(xes, h0)

        # start with EOS tensor for all
        x = self.EOS_TENSOR
        if self.use_cuda:
            x = x.cuda(non_blocking=True)
        x = Variable(x)
        xe = self.lt(x).unsqueeze(1)
        xes = xe.expand(xe.size(0), batchsize, xe.size(2))

        output_lines = [[] for _ in range(batchsize)]

        self.zero_grad()
        # update model
        loss = 0
        self.longest_label = max(self.longest_label, ys.size(1))
        for i in range(ys.size(1)):
            output, hn = self.decoder(xes, hn)
            preds, scores = self.hidden_to_idx(output, drop=True)
            y = ys.select(1, i)
            loss += self.criterion(scores, y)
            # use the true token as the next input
            xes = self.lt(y).unsqueeze(0)
            # hn = self.dropout(hn)
            for j in range(preds.size(0)):
                token = self.v2t([preds.data[j][0]])
                output_lines[j].append(token)

        loss.backward()
        self.update_params()

        if random.random() < 0.1:
            true = self.v2t(ys.data[0])
            #print('loss:', round(loss.data[0], 2),
            #      ' '.join(output_lines[0]), '(true: {})'.format(true))
        return output_lines

    def predict(self, xs):
        batchsize = len(xs)

        # first encode context
        xes = self.lt(xs).t()
        h0 = self.init_zeros(batchsize)
        _output, hn = self.encoder(xes, h0)

        # start with EOS tensor for all
        x = self.EOS_TENSOR
        if self.use_cuda:
            x = x.cuda(non_blocking=True)
        x = Variable(x)
        xe = self.lt(x).unsqueeze(1)
        xes = xe.expand(xe.size(0), batchsize, xe.size(2))

        done = [False for _ in range(batchsize)]
        total_done = 0
        max_len = 0
        output_lines = [[] for _ in range(batchsize)]

        while (total_done < batchsize) and max_len < self.longest_label:
            output, hn = self.decoder(xes, hn)
            preds, scores = self.hidden_to_idx(output, drop=False)
            xes = self.lt(preds.t())
            max_len += 1
            for i in range(preds.size(0)):
                if not done[i]:
                    token = self.v2t(preds.data[i])
                    if token == self.EOS:
                        done[i] = True
                        total_done += 1
                    else:
                        output_lines[i].append(token)
        if random.random() < 0.1:
            print('prediction:', ' '.join(output_lines[0]))
        return output_lines

    def batchify(self, obs):
        exs = [ex for ex in obs if 'text' in ex]
        valid_inds = [i for i, ex in enumerate(obs) if 'text' in ex]

        batchsize = len(exs)
        parsed = [self.parse(ex['text']) for ex in exs]
        max_x_len = max([len(x) for x in parsed])
        xs = torch.LongTensor(batchsize, max_x_len).fill_(0)
        for i, x in enumerate(parsed):
            offset = max_x_len - len(x)
            for j, idx in enumerate(x):
                xs[i][j + offset] = idx
        if self.use_cuda:
            xs = xs.cuda(non_blocking=True)
        xs = Variable(xs)

        ys = None
        if 'labels' in exs[0]:
            labels = [random.choice(ex['labels']) + ' ' + self.EOS for ex in exs]
            parsed = [self.parse(y) for y in labels]
            max_y_len = max(len(y) for y in parsed)
            ys = torch.LongTensor(batchsize, max_y_len).fill_(0)
            for i, y in enumerate(parsed):
                for j, idx in enumerate(y):
                    ys[i][j] = idx
            if self.use_cuda:
                ys = ys.cuda(non_blocking=True)
            ys = Variable(ys)
        return xs, ys, valid_inds

    def batch_act(self, observations):
        batchsize = len(observations)
        batch_reply = [{'id': self.getID()} for _ in range(batchsize)]

        xs, ys, valid_inds = self.batchify(observations)

        if len(xs) == 0:
            return batch_reply

        # Either train or predict
        if ys is not None:
            predictions = self.update(xs, ys)
        else:
            predictions = self.predict(xs)

        for i in range(len(predictions)):
            batch_reply[valid_inds[i]]['text'] = ' '.join(
                c for c in predictions[i] if c != self.EOS)

        return batch_reply

    def act(self):
        return self.batch_act([self.observation])[0]

    def save(self, path=None):
        path = self.opt.get('model_file', None) if path is None else path

        if path:
            model = {}
            model['lt'] = self.lt.state_dict()
            model['encoder'] = self.encoder.state_dict()
            model['decoder'] = self.decoder.state_dict()
            model['d2o'] = self.d2o.state_dict()
            model['longest_label'] = self.longest_label

            with open(path, 'wb') as write:
                torch.save(model, write)

    def load(self, path):
        with open(path, 'rb') as read:
            model = torch.load(read)

        self.lt.load_state_dict(model['lt'])
        self.encoder.load_state_dict(model['encoder'])
        self.decoder.load_state_dict(model['decoder'])
        self.d2o.load_state_dict(model['d2o'])
        self.longest_label = model['longest_label']
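
Unlike the earlier agent, batchify() here left-pads each tokenized input (the offset computation above) so the most recent tokens sit at the end of each row. A standalone sketch of that padding scheme with arbitrary token ids chosen only for illustration:

import torch

def left_pad(seqs, pad_idx=0):
    # pack variable-length token id lists into one LongTensor, padding on the left
    max_len = max(len(s) for s in seqs)
    out = torch.full((len(seqs), max_len), pad_idx, dtype=torch.long)
    for i, s in enumerate(seqs):
        out[i, max_len - len(s):] = torch.tensor(s, dtype=torch.long)
    return out

print(left_pad([[4, 7, 2], [9]]))
# tensor([[4, 7, 2],
#         [0, 0, 9]])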
Code example #11
File: hciae.py Project: wolegechu/ParlAI
class HCIAEAgent(Agent):
    """ HCIAEAgent.
    """

    @staticmethod
    def add_cmdline_args(argparser):
        DictionaryAgent.add_cmdline_args(argparser)
        arg_group = argparser.add_argument_group('HCIAE Arguments')
        arg_group.add_argument('--dropout', type=float, default=.1, help='')
        arg_group.add_argument('--embedding-size', type=int, default=512, help='')
        arg_group.add_argument('--hidden-size', type=int, default=512, help='')
        arg_group.add_argument('--no-cuda', action='store_true', default=False,
                               help='disable GPUs even if available')
        arg_group.add_argument('--gpu', type=int, default=-1,
                               help='which GPU device to use')
        arg_group.add_argument('--rnn-layers', type=int, default=2,
            help='number of hidden layers in RNN decoder for generative output')
        arg_group.add_argument('--optimizer', default='adam',
            help='optimizer type (sgd|adam)')
        arg_group.add_argument('-lr', '--learning-rate', type=float, default=0.01,
                               help='learning rate')

    def __init__(self, opt, shared=None):
        super().__init__(opt, shared)
        opt['cuda'] = not opt['no_cuda'] and torch.cuda.is_available()
        if opt['cuda']:
            print('[Using CUDA]')
            torch.cuda.set_device(opt['gpu'])
        
        if not shared:
            self.opt = opt
            self.id = 'HCIAE'
            self.dict = DictionaryAgent(opt)
            self.answers = [None] * opt['batchsize']

            self.END = self.dict.end_token
            self.END_TENSOR = torch.LongTensor(self.dict.parse(self.END))
            self.START = self.dict.start_token
            self.START_TENSOR = torch.LongTensor(self.dict.parse(self.START))
            self.mem_size = 10
            self.longest_label = 1
            self.writer = SummaryWriter()
            self.writer_idx = 0

            lr = opt['learning_rate']

            self.loss_fn = CrossEntropyLoss()

            self.model = HCIAE(opt, self.dict)
            self.decoder = Decoder(opt['hidden_size'], opt['hidden_size'], opt['rnn_layers'], opt, self.dict)

            optim_params = [p for p in self.model.parameters() if p.requires_grad]
            if opt['optimizer'] == 'sgd':
                self.optimizers = {'hciae': optim.SGD(optim_params, lr=lr)}
                if self.decoder is not None:
                    self.optimizers['decoder'] = optim.SGD(self.decoder.parameters(), lr=lr)
            elif opt['optimizer'] == 'adam':
                self.optimizers = {'hciae': optim.Adam(optim_params, lr=lr)}
                if self.decoder is not None:
                    self.optimizers['decoder'] = optim.Adam(self.decoder.parameters(), lr=lr)
            else:
                raise NotImplementedError('Optimizer not supported.')


            if opt['cuda']:
                self.decoder.cuda()
            
            if opt.get('model_file') and os.path.isfile(opt['model_file']):
                print('Loading existing model parameters from ' + opt['model_file'])
        else:
            self.answers = shared['answers']
        
        self.episode_done = True
        self.img_feature = None
        self.last_cands, self.last_cands_list = None, None

    def share(self):
        shared = super().share()
        shared['answers'] = self.answers
        return shared

    def observe(self, observation):
        observation = copy.copy(observation)
        if not self.episode_done:
            # if the last example wasn't the end of an episode, then we need to
            # recall what was said in that example
            prev_dialogue = self.observation['text'] if self.observation is not None else ''
            prev_dialogue = prev_dialogue + ' __END__ ' + self.observation['labels'][0]
            observation['text'] = prev_dialogue + '\n' + observation['text']
        # else:
        #     self.img_feature = torch.from_numpy(observation['image'].items()[0][1])
        self.observation = observation
        self.episode_done = observation['episode_done']
        return observation

    def parse(self, text):
        """Returns:
            query = tensor (vector) of token indices for query
            query_length = length of query
            memory = tensor (matrix) where each row contains token indices for a memory
            memory_lengths = tensor (vector) with lengths of each memory
        """
        sp = text.split('\n')
        query_sentence = sp[-1]
        query = self.dict.txt2vec(query_sentence)
        query = torch.LongTensor(query)
        query_length = torch.LongTensor([len(query)])

        sp = sp[:-1]
        sentences = []
        for s in sp:
            sentences.extend(s.split('\t'))
        if len(sentences) == 0:
            sentences.append(self.dict.null_token)

        num_mems = min(self.mem_size, len(sentences))
        memory_sentences = sentences[-num_mems:]
        memory = [self.dict.txt2vec(s) for s in memory_sentences]
        memory = [torch.LongTensor(m) for m in memory]
        memory_lengths = torch.LongTensor([len(m) for m in memory])
        memory = torch.cat(memory)
        return (query, memory, query_length, memory_lengths)

    def batchify(self, obs):
        """Returns:
            xs = [memories, queries, memory_lengths, query_lengths, images]
            ys = [labels, label_lengths] (if available, else None)
            valid_inds = list of indices for examples with valid observations
        """
        exs = [ex for ex in obs if ('text' in ex and 'image' in ex)]
        valid_inds = [i for i, ex in enumerate(obs) if ('text' in ex and 'image' in ex)]
        if not exs:
            return [None] * 3
        images = torch.cat(
            [torch.from_numpy(ex['image']['x']).unsqueeze(0) for ex in exs])
        parsed = [self.parse(ex['text']) for ex in exs]
        queries = torch.cat([x[0] for x in parsed])
        memories = torch.cat([x[1] for x in parsed])
        query_lengths = torch.cat([x[2] for x in parsed])
        memory_lengths = torch.LongTensor(len(exs), self.mem_size).zero_()
        for i in range(len(exs)):
            if len(parsed[i][3]) > 0:
                memory_lengths[i, -len(parsed[i][3]):] = parsed[i][3]

        # batchify memories (batch_size * memory_length * max_sentence_length)
        batch_size = len(valid_inds)
        start_idx = memory_lengths.numpy()[0].nonzero()[0][0]
        idx = 0
        memories_tensor = []
        max_len = torch.max(memory_lengths)
        for i in range(batch_size):
            memory = []
            for j in range(start_idx, self.mem_size):
                temp = []
                length = memory_lengths[i][j]
                temp = [memories[idx + i] for i in range(length)]
                temp.extend([0] * (max_len - length))
                idx += length
                memory.append(temp)
            memories_tensor.append(memory)
        memories_tensor = torch.from_numpy(np.array(memories_tensor))

        # batchify queries (batch_size * max_query_length)
        idx = 0
        max_length = max(query_lengths)

        queries_tensor = []
        for i in range(batch_size):
            temp = []
            length = query_lengths[i]
            temp = [queries[idx+i] for i in range(length)]
            temp.extend([0] * (max_length - length))
            idx += length
            queries_tensor.append(temp)
        queries_tensor = torch.from_numpy(np.array(queries_tensor))
        xs = [memories_tensor, queries_tensor, memory_lengths, query_lengths, images]
        ys = None
        self.labels = [random.choice(ex['labels']) for ex in exs if 'labels' in ex]
        if len(self.labels) == len(exs):
            parsed = [self.dict.txt2vec(l) for l in self.labels]
            parsed = [torch.LongTensor(p) for p in parsed]
            label_lengths = torch.LongTensor([len(p) for p in parsed]).unsqueeze(1)
            self.longest_label = max(self.longest_label, label_lengths.max())
            padded = [torch.cat((p, torch.LongTensor(self.longest_label - len(p))
                        .fill_(self.END_TENSOR[0]))) for p in parsed]
            labels = torch.stack(padded)
            ys = [labels, label_lengths]

        return xs, ys, valid_inds

    def predict(self, xs, ys=None):
        is_training = ys is not None
        self.model.train(mode=is_training)
        inputs = [Variable(x) for x in xs]
        output = self.model(*inputs)

        self.decoder.train(mode=is_training)

        output_lines, loss = self.decode(output, ys)
        predictions = self.generated_predictions(output_lines)

        if is_training:
            for o in self.optimizers.values():
                o.zero_grad()
            loss.backward()
            for o in self.optimizers.values():
                o.step()
            self.writer_idx += 1
            #print('Loss: ', loss.data[0])

            #if self.writer_idx == 1:
            #    self.writer.add_graph(self.model, output)

            self.writer.add_scalar('loss', loss.data[0], self.writer_idx)
            self.writer.add_embedding(output.data, tag='output', global_step=self.writer_idx)
            if random.random() < 0.25:
                label = self.dict.vec2txt(ys[0][0].tolist())
                self.writer.add_text('prediction - label', ' '.join(predictions[0]) + ' --- ' + label, self.writer_idx)

        return predictions
    

    def decode(self, output_embeddings, ys=None):
        # output_embeddings: [batch_size, hidden_size]
        batchsize = output_embeddings.size(0)
        hn = output_embeddings.unsqueeze(0).expand(
            self.opt['rnn_layers'], batchsize, output_embeddings.size(1))
        x = self.model.answer_embedder(Variable(self.START_TENSOR))
        xes = x.unsqueeze(1).expand(x.size(0), batchsize, x.size(1))

        loss = 0
        output_lines = [[] for _ in range(batchsize)]
        done = [False for _ in range(batchsize)]
        total_done = 0
        idx = 0

        while (total_done < batchsize) and idx < self.longest_label:
            # keep producing tokens until we hit END or max length for each ex
            if self.opt['cuda']:
                xes = xes.cuda()
                hn = hn.contiguous()
            #print('Before Decoder size - xes, hn', xes.size(), hn.size())
            preds, scores = self.decoder(xes, hn)
            if ys is not None:
                y = Variable(ys[0][:, idx])
                temp_y = y.cuda() if self.opt['cuda'] else y
                loss += self.loss_fn(scores, temp_y)
            else:
                y = preds
            # use the true token as the next input for better training
            xes = self.model.answer_embedder(y).unsqueeze(0)

            for b in range(batchsize):
                if not done[b]:
                    token = self.dict.vec2txt([preds.data[b]])
                    if token == self.END:
                        done[b] = True
                        total_done += 1
                    else:
                        output_lines[b].append(token)
            idx += 1

        return output_lines, loss

    def batch_act(self, observations):
        batchsize = len(observations)
        batch_reply = [{'id': self.getID()} for _ in range(batchsize)]
        xs, ys, valid_inds = self.batchify(observations)

        if xs is None or len(xs[1]) == 0:
            return batch_reply

        # Either train or predict
        predictions = self.predict(xs, ys)

        for i in range(len(valid_inds)):
            #self.answers[valid_inds[i]] = predictions[i][0]
            batch_reply[valid_inds[i]]['text'] = predictions[i][0]
            #batch_reply[valid_inds[i]]['text_candidates'] = predictions[i]
        return batch_reply

    def generated_predictions(self, output_lines):
        return [[' '.join(c for c in o if c != self.END
                        and c != self.dict.null_token)] for o in output_lines]

    def save(self, path=None):
        path = self.opt.get('model_file', None) if path is None else path
        if path:
            # add a timestamp so successive saves do not overwrite each other
            now_time = time.strftime('%m-%d %H:%M', time.localtime(time.time()))
            path = path + now_time + '.model'
            checkpoint = {}
            checkpoint['hciae'] = self.model.state_dict()
            checkpoint['hciae_optim'] = self.optimizers['memnn'].state_dict()
            if self.decoder is not None:
                checkpoint['decoder'] = self.decoder.state_dict()
                checkpoint['decoder_optim'] = self.optimizers['decoder'].state_dict()
                checkpoint['longest_label'] = self.longest_label
            with open(path, 'wb') as write:
                torch.save(checkpoint, write)
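The save() method above writes a timestamped checkpoint dict, but the matching load step is not shown in this excerpt. A minimal sketch of what it might look like, assuming the same checkpoint keys ('hciae', 'hciae_optim', 'decoder', 'decoder_optim', 'longest_label') and the same module/optimizer attributes used in save():

    def load(self, path):
        # load a checkpoint written by save(); map_location keeps it CPU-friendly
        checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
        self.model.load_state_dict(checkpoint['hciae'])
        self.optimizers['memnn'].load_state_dict(checkpoint['hciae_optim'])
        if 'decoder' in checkpoint:
            self.decoder.load_state_dict(checkpoint['decoder'])
            self.optimizers['decoder'].load_state_dict(checkpoint['decoder_optim'])
            self.longest_label = checkpoint['longest_label']
        return checkpoint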
Code Example #12
0
File: seq2seq_v2.py Project: zhongyunuestc/convai
class Seq2seqV2Agent(Agent):
    """Agent which takes an input sequence and produces an output sequence.

    For more information, see Sequence to Sequence Learning with Neural
    Networks `(Sutskever et al. 2014) <https://arxiv.org/abs/1409.3215>`_.
    """

    OPTIM_OPTS = {
        'adadelta': optim.Adadelta,
        'adagrad': optim.Adagrad,
        'adam': optim.Adam,
        'adamax': optim.Adamax,
        'asgd': optim.ASGD,
        'lbfgs': optim.LBFGS,
        'rmsprop': optim.RMSprop,
        'rprop': optim.Rprop,
        'sgd': optim.SGD,
    }

    ENC_OPTS = {'rnn': nn.RNN, 'gru': nn.GRU, 'lstm': nn.LSTM}

    @staticmethod
    def add_cmdline_args(argparser):
        """Add command-line arguments specifically for this agent."""
        DictionaryAgent.add_cmdline_args(argparser)
        agent = argparser.add_argument_group('Seq2Seq Arguments')
        agent.add_argument('-hs',
                           '--hiddensize',
                           type=int,
                           default=128,
                           help='size of the hidden layers')
        agent.add_argument('-emb',
                           '--embeddingsize',
                           type=int,
                           default=128,
                           help='size of the token embeddings')
        agent.add_argument('-nl',
                           '--numlayers',
                           type=int,
                           default=2,
                           help='number of hidden layers')
        agent.add_argument('-lr',
                           '--learning_rate',
                           type=float,
                           default=0.5,
                           help='learning rate')
        agent.add_argument('-wd',
                           '--weight_decay',
                           type=float,
                           default=0,
                           help='weight decay')
        agent.add_argument('-dr',
                           '--dropout',
                           type=float,
                           default=0.2,
                           help='dropout rate')
        agent.add_argument('-att',
                           '--attention',
                           default=False,
                           type='bool',
                           help='if True, use attention')
        agent.add_argument(
            '-attType',
            '--attn-type',
            default='general',
            choices=['general', 'concat', 'dot'],
            help='general=bilinear dotproduct, concat=Bahdanau\'s implementation'
        )
        agent.add_argument('--no-cuda',
                           action='store_true',
                           default=False,
                           help='disable GPUs even if available')
        agent.add_argument('--gpu',
                           type=int,
                           default=-1,
                           help='which GPU device to use')
        agent.add_argument('-rc',
                           '--rank-candidates',
                           type='bool',
                           default=False,
                           help='rank candidates if available. this is done by'
                           ' computing the mean score per token for each '
                           'candidate and selecting the highest scoring.')
        agent.add_argument('-tr',
                           '--truncate',
                           type='bool',
                           default=True,
                           help='truncate input & output lengths to speed up '
                           'training (may reduce accuracy). This fixes all '
                           'input and output to have a maximum length and to '
                           'be similar in length to one another by throwing '
                           'away extra tokens. This reduces the total amount '
                           'of padding in the batches.')
        agent.add_argument('-enc',
                           '--encoder',
                           default='gru',
                           choices=Seq2seqV2Agent.ENC_OPTS.keys(),
                           help='Choose between different encoder modules.')
        agent.add_argument('-bi',
                           '--bi-encoder',
                           default=True,
                           type='bool',
                           help='whether to use a bidirectional encoder')
        agent.add_argument('-dec',
                           '--decoder',
                           default='same',
                           choices=['same', 'shared'] +
                           list(Seq2seqV2Agent.ENC_OPTS.keys()),
                           help='Choose between different decoder modules. '
                           'Default "same" uses same class as encoder, '
                           'while "shared" also uses the same weights.')
        agent.add_argument('-opt',
                           '--optimizer',
                           default='sgd',
                           choices=Seq2seqV2Agent.OPTIM_OPTS.keys(),
                           help='Choose between pytorch optimizers. '
                           'Any member of torch.optim is valid and will '
                           'be used with default params except learning '
                           'rate (as specified by -lr).')
        agent.add_argument('-gradClip',
                           '--grad-clip',
                           type=float,
                           default=-1,
                           help='gradient clip, default = -1 (no clipping)')
        agent.add_argument(
            '-epi',
            '--episode-concat',
            type='bool',
            default=False,
            help=
            'If multiple observations are from the same episode, concatenate them.'
        )
        agent.add_argument(
            '--beam_size',
            type=int,
            default=0,
            help=
            'Beam size for beam search (generation mode only); set 0 for greedy search'
        )
        agent.add_argument('--max_seq_len',
                           type=int,
                           default=50,
                           help='The maximum sequence length, default = 50')

    def __init__(self, opt, shared=None):
        """Set up model if shared params not set, otherwise no work to do."""
        super().__init__(opt, shared)
        if not shared:
            # this is not a shared instance of this class, so do full
            # initialization. if shared is set, only set up shared members.

            # check for cuda
            self.use_cuda = not opt.get('no_cuda') and torch.cuda.is_available(
            )
            if self.use_cuda:
                print('[ Using CUDA ]')
                torch.cuda.set_device(opt['gpu'])

            if opt.get('model_file') and os.path.isfile(opt['model_file']):
                # load model parameters if available
                print('Loading existing model params from ' +
                      opt['model_file'])
                new_opt, self.states = self.load(opt['model_file'])
                # override options with stored ones
                opt = self.override_opt(new_opt)

            self.dict = DictionaryAgent(opt)
            self.id = 'Seq2Seq'
            # we use START markers to start our output
            self.START = self.dict.start_token
            self.START_TENSOR = torch.LongTensor(self.dict.parse(self.START))
            # we use END markers to end our output
            self.END = self.dict.end_token
            self.END_TENSOR = torch.LongTensor(self.dict.parse(self.END))
            # get index of null token from dictionary (probably 0)
            self.NULL_IDX = self.dict.txt2vec(self.dict.null_token)[0]

            # store important params directly
            hsz = opt['hiddensize']
            emb = opt['embeddingsize']
            self.hidden_size = hsz
            self.emb_size = emb
            self.num_layers = opt['numlayers']
            self.learning_rate = opt['learning_rate']
            self.rank = opt['rank_candidates']
            self.longest_label = 1
            self.truncate = opt['truncate']
            self.attention = opt['attention']

            # set up tensors
            if self.opt['bi_encoder']:
                self.zeros = torch.zeros(2 * self.num_layers, 1, hsz)
            else:
                self.zeros = torch.zeros(self.num_layers, 1, hsz)

            self.zeros_dec = torch.zeros(self.num_layers, 1, hsz)

            self.xs = torch.LongTensor(1, 1)
            self.ys = torch.LongTensor(1, 1)
            self.cands = torch.LongTensor(1, 1, 1)
            self.cand_scores = torch.FloatTensor(1)
            self.cand_lengths = torch.LongTensor(1)

            # set up modules
            self.criterion = nn.NLLLoss(size_average=False, ignore_index=0)

            # lookup table stores word embeddings
            self.lt = nn.Embedding(len(self.dict),
                                   emb,
                                   padding_idx=self.NULL_IDX)
            #scale_grad_by_freq=True)
            # encoder captures the input text
            enc_class = Seq2seqV2Agent.ENC_OPTS[opt['encoder']]
            self.encoder = enc_class(emb,
                                     hsz,
                                     opt['numlayers'],
                                     bidirectional=opt['bi_encoder'],
                                     dropout=opt['dropout'])
            # decoder produces our output states

            #if opt['decoder'] == 'shared':
            #    self.decoder = self.encoder
            dec_isz = emb + hsz
            if opt['bi_encoder']:
                dec_isz += hsz

            if opt['decoder'] == 'same':
                self.decoder = enc_class(dec_isz,
                                         hsz,
                                         opt['numlayers'],
                                         dropout=opt['dropout'])
            else:
                dec_class = Seq2seqV2Agent.ENC_OPTS[opt['decoder']]
                self.decoder = dec_class(dec_isz,
                                         hsz,
                                         opt['numlayers'],
                                         dropout=opt['dropout'])

            # linear layer helps us produce outputs from final decoder state
            self.h2o = nn.Linear(hsz, len(self.dict))
            # dropout on the linear layer helps us generalize
            self.dropout = nn.Dropout(opt['dropout'])

            self.use_attention = False
            self.attn = None
            # if attention is enabled, set up additional members
            if self.attention:
                self.use_attention = True
                self.att_type = opt['attn_type']
                input_size = hsz
                if opt['bi_encoder']:
                    input_size += hsz

                if self.att_type == 'concat':
                    self.attn = nn.Linear(input_size + hsz, 1, bias=False)
                elif self.att_type == 'dot':
                    assert not opt['bi_encoder']
                elif self.att_type == 'general':
                    self.attn = nn.Linear(hsz, input_size, bias=False)

            # initialization
            """
                getattr(self, 'lt').weight.data.uniform_(-0.1, 0.1)
                for module in {'encoder', 'decoder'}:
                    for weight in getattr(self, module).parameters():
                        weight.data.normal_(0, 0.05)
                    #for bias in getattr(self, module).parameters():
                    #    bias.data.fill_(0)
                        
                for module in {'h2o', 'attn'}:
                    if hasattr(self, module):
                        getattr(self, module).weight.data.normal_(0, 0.01)
                        #getattr(self, module).bias.data.fill_(0)
            """

            # set up optims for each module
            self.lr = opt['learning_rate']
            self.wd = opt['weight_decay']

            optim_class = Seq2seqV2Agent.OPTIM_OPTS[opt['optimizer']]
            self.optims = {
                'lt':
                optim_class(self.lt.parameters(), lr=self.lr),
                'encoder':
                optim_class(self.encoder.parameters(), lr=self.lr),
                'decoder':
                optim_class(self.decoder.parameters(), lr=self.lr),
                'h2o':
                optim_class(self.h2o.parameters(),
                            lr=self.lr,
                            weight_decay=self.wd),
            }
            if self.attention and self.attn is not None:
                self.optims.update({
                    'attn':
                    optim_class(self.attn.parameters(),
                                lr=self.lr,
                                weight_decay=self.wd)
                })

            if hasattr(self, 'states'):
                # set loaded states if applicable
                self.set_states(self.states)

            if self.use_cuda:
                self.cuda()

            self.loss = 0
            self.ndata = 0
            self.loss_valid = 0
            self.ndata_valid = 0

            if opt['beam_size'] > 0:
                self.beamsize = opt['beam_size']

        self.episode_concat = opt['episode_concat']
        self.training = True
        self.generating = False
        self.local_human = False

        if opt.get('max_seq_len') is not None:
            self.max_seq_len = opt['max_seq_len']
        else:
            self.max_seq_len = opt['max_seq_len'] = 50
        self.reset()

    def set_lrate(self, lr):
        self.lr = lr
        for key in self.optims:
            self.optims[key].param_groups[0]['lr'] = self.lr

    def override_opt(self, new_opt):
        """Set overridable opts from loaded opt file.

        Print out each added key and each overridden key.
        Only override args specific to the model.
        """
        model_args = {
            'hiddensize', 'embeddingsize', 'numlayers', 'optimizer', 'encoder',
            'decoder'
        }
        for k, v in new_opt.items():
            if k not in model_args:
                # skip non-model args
                continue
            if k not in self.opt:
                print('Adding new option [ {k}: {v} ]'.format(k=k, v=v))
            elif self.opt[k] != v:
                print('Overriding option [ {k}: {old} => {v}]'.format(
                    k=k, old=self.opt[k], v=v))
            self.opt[k] = v
        return self.opt

    def parse(self, text):
        """Convert string to token indices."""
        return self.dict.txt2vec(text)

    def v2t(self, vec):
        """Convert token indices to string of tokens."""
        return self.dict.vec2txt(vec)

    def cuda(self):
        """Push parameters to the GPU."""
        self.START_TENSOR = self.START_TENSOR.cuda(async=True)
        self.END_TENSOR = self.END_TENSOR.cuda(async=True)
        self.zeros = self.zeros.cuda(async=True)
        self.zeros_dec = self.zeros_dec.cuda(async=True)
        self.xs = self.xs.cuda(async=True)
        self.ys = self.ys.cuda(async=True)
        self.cands = self.cands.cuda(async=True)
        self.cand_scores = self.cand_scores.cuda(async=True)
        self.cand_lengths = self.cand_lengths.cuda(async=True)
        self.criterion.cuda()
        self.lt.cuda()
        self.encoder.cuda()
        self.decoder.cuda()
        self.h2o.cuda()
        self.dropout.cuda()
        if self.use_attention:
            self.attn.cuda()

    def hidden_to_idx(self, hidden, dropout=False):
        """Convert hidden state vectors into indices into the dictionary."""
        if hidden.size(0) > 1:
            raise RuntimeError('bad dimensions of tensor:', hidden)
        hidden = hidden.squeeze(0)
        if dropout:
            hidden = self.dropout(hidden)  # dropout over the last hidden
        scores = self.h2o(hidden)
        scores = F.log_softmax(scores)
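        # scores: [batch_size, vocab_size] log-probabilities; idx holds the argmax token id per example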
        _max_score, idx = scores.max(1)
        return idx, scores

    def zero_grad(self):
        """Zero out optimizers."""
        for optimizer in self.optims.values():
            optimizer.zero_grad()

    def update_params(self):
        """Do one optimization step."""
        for optimizer in self.optims.values():
            optimizer.step()

    def reset(self):
        """Reset observation and episode_done."""
        self.observation = None
        self.episode_done = True

    def preprocess(self, reply_text):
        # preprocess for opensub
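        # e.g. "I'm sure you've seen it" -> "i 'm sure you 've seen it"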
        reply_text = reply_text.replace('\\n', '\n')  ## TODO: pre-processing
        reply_text = reply_text.replace("'m", " 'm")
        reply_text = reply_text.replace("'ve", " 've")
        reply_text = reply_text.replace("'s", " 's")
        reply_text = reply_text.replace("'t", " 't")
        reply_text = reply_text.replace("'il", " 'il")
        reply_text = reply_text.replace("'d", " 'd")
        reply_text = reply_text.replace("'re", " 're")
        reply_text = reply_text.lower().strip()

        return reply_text

    def observe(self, observation):
        """Save observation for act.
        If multiple observations are from the same episode, concatenate them.
        """
        if self.local_human:
            observation = {}
            observation['id'] = self.getID()
            reply_text = input("Enter Your Message: ")
            reply_text = self.preprocess(reply_text)
            observation['episode_done'] = True  ### TODO: for history
            """
            if '[DONE]' in reply_text:
                reply['episode_done'] = True
                self.episodeDone = True
                reply_text = reply_text.replace('[DONE]', '')
            """
            observation['text'] = reply_text

        else:
            # shallow copy observation (deep copy can be expensive)
            observation = observation.copy()
            if not self.episode_done and self.episode_concat:
                # if the last example wasn't the end of an episode, then we need to
                # recall what was said in that example
                prev_dialogue = self.observation['text']
                observation['text'] = prev_dialogue + '\n' + observation[
                    'text']  #### TODO!!!! # DATA is concatenated!!

        self.observation = observation
        self.episode_done = observation['episode_done']

        return observation

    def _encode(self, xs, xlen, dropout=False, packed=True):
        """Call encoder and return output and hidden states."""
        batchsize = len(xs)

        # first encode context
        xes = self.lt(xs).transpose(0, 1)
        #if dropout:
        #    xes = self.dropout(xes)

        # initial hidden
        if self.zeros.size(1) != batchsize:
            if self.opt['bi_encoder']:
                self.zeros.resize_(2 * self.num_layers, batchsize,
                                   self.hidden_size).fill_(0)
            else:
                self.zeros.resize_(self.num_layers, batchsize,
                                   self.hidden_size).fill_(0)

        h0 = Variable(self.zeros.fill_(0))

        # forward
        if packed:
            xes = torch.nn.utils.rnn.pack_padded_sequence(xes, xlen)

        if type(self.encoder) == nn.LSTM:
            encoder_output, _ = self.encoder(
                xes, (h0, h0))  ## Note : we can put None instead of (h0, h0)
        else:
            encoder_output, _ = self.encoder(xes, h0)

        if packed:
            encoder_output, _ = torch.nn.utils.rnn.pad_packed_sequence(
                encoder_output)

        encoder_output = encoder_output.transpose(0, 1)  #batch first
        """
        if self.use_attention:
            if encoder_output.size(1) > self.max_length:
                offset = encoder_output.size(1) - self.max_length
                encoder_output = encoder_output.narrow(1, offset, self.max_length)
        """

        return encoder_output

    def _apply_attention(self, word_input, encoder_output, last_hidden, xs):
        """Apply attention to encoder hidden layer."""
        batch_size = encoder_output.size(0)
        enc_length = encoder_output.size(1)
        mask = Variable(xs.data.eq(0).eq(0).float())
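        # mask: [batch_size, enc_length], 1.0 where xs holds a real token and 0.0 at padding (id 0)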

        #pdb.set_trace()
        # encoder_output # B x T x 2H
        # last_hidden  B x H

        if self.att_type == 'concat':
            last_hidden = last_hidden.unsqueeze(1).expand(
                batch_size, encoder_output.size(1),
                self.hidden_size)  # B x T x H
            attn_weights = F.tanh(
                self.attn(
                    torch.cat((encoder_output, last_hidden),
                              2).view(batch_size * enc_length,
                                      -1)).view(batch_size, enc_length))
        elif self.att_type == 'dot':
            attn_weights = F.tanh(
                torch.bmm(encoder_output, last_hidden.unsqueeze(2)).squeeze())
        elif self.att_type == 'general':
            attn_weights = F.tanh(
                torch.bmm(encoder_output,
                          self.attn(last_hidden).unsqueeze(2)).squeeze())

        #attn_weights = F.softmax(attn_weights.view(batch_size, enc_length))

        attn_weights = attn_weights.exp().mul(mask)
        denom = attn_weights.sum(1).unsqueeze(1).expand_as(attn_weights)
        attn_weights = attn_weights.div(denom)
        context = torch.bmm(attn_weights.unsqueeze(1),
                            encoder_output).squeeze(1)

        output = torch.cat((word_input, context.unsqueeze(0)), 2)
        return output

    def _get_context(self, batchsize, xlen_t, encoder_output):
        " return initial hidden of decoder and encoder context (last_state)"

        ## The initial of decoder is the hidden (last states) of encoder --> put zero!
        if self.zeros_dec.size(1) != batchsize:
            self.zeros_dec.resize_(self.num_layers, batchsize,
                                   self.hidden_size).fill_(0)
        hidden = Variable(self.zeros_dec.fill_(0))

        last_state = None
        if not self.use_attention:
            last_state = torch.gather(
                encoder_output, 1,
                xlen_t.view(-1, 1, 1).expand(encoder_output.size(0), 1,
                                             encoder_output.size(2)))
            if self.opt['bi_encoder']:
                #                last_state = torch.cat((encoder_output[:,0,:self.hidden_size], last_state[:,0,self.hidden_size:]),1)
                last_state = torch.cat(
                    (encoder_output[:, 0, self.hidden_size:],
                     last_state[:, 0, :self.hidden_size]), 1)
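                # last_state: [batch_size, 2*hidden_size] -- backward final state (t=0)
                # concatenated with the forward state gathered at the last real token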

        return hidden, last_state

    def _decode_and_train(self, batchsize, dec_xes, xlen_t, xs, ys, ylen,
                          encoder_output):
        # update the model based on the labels
        self.zero_grad()
        loss = 0

        output_lines = [[] for _ in range(batchsize)]

        # keep track of longest label we've ever seen
        self.longest_label = max(self.longest_label, ys.size(1))

        hidden, last_state = self._get_context(batchsize, xlen_t,
                                               encoder_output)

        for i in range(ys.size(1)):
            if self.use_attention:
                output = self._apply_attention(dec_xes, encoder_output,
                                               hidden[-1], xs)
            else:
                output = torch.cat((dec_xes, last_state.unsqueeze(0)), 2)

            output, hidden = self.decoder(output, hidden)
            preds, scores = self.hidden_to_idx(output, dropout=self.training)
            y = ys.select(1, i)
            loss += self.criterion(scores, y)  #not averaged
            # use the true token as the next input instead of predicted
            # this produces a biased prediction but better training
            dec_xes = self.lt(y).unsqueeze(0)

            # TODO: overhead!
            for b in range(batchsize):
                # convert the output scores to tokens
                token = self.v2t([preds.data[b]])
                output_lines[b].append(token)

        if self.training:
            self.loss = loss.data[0] / sum(ylen)  # consider non-NULL
            self.ndata += batchsize
        else:
            self.loss_valid += loss.data[0]  # consider non-NULL / accumulate!
            self.ndata_valid += sum(ylen)

        return loss, output_lines

    def _decode_only(self, batchsize, dec_xes, xlen_t, xs, encoder_output):
        # just produce a prediction without training the model
        done = [False for _ in range(batchsize)]
        total_done = 0
        max_len = 0

        output_lines = [[] for _ in range(batchsize)]

        hidden, last_state = self._get_context(batchsize, xlen_t,
                                               encoder_output)

        # now, generate a response from scratch
        while (total_done < batchsize) and max_len < self.longest_label:
            # keep producing tokens until we hit END or max length for each
            if self.use_attention:
                output = self._apply_attention(dec_xes, encoder_output,
                                               hidden[-1], xs)
            else:
                output = torch.cat((dec_xes, last_state.unsqueeze(0)), 2)

            output, hidden = self.decoder(output, hidden)
            preds, scores = self.hidden_to_idx(output, dropout=False)

            #dec_xes = self.lt2dec(self.lt(preds.unsqueeze(0)))
            dec_xes = self.lt(preds).unsqueeze(0)

            max_len += 1
            for b in range(batchsize):
                if not done[b]:
                    # only add more tokens for examples that aren't done yet
                    token = self.v2t([preds.data[b]])
                    if token == self.END:
                        # if we produced END, we're done
                        done[b] = True
                        total_done += 1
                    else:
                        output_lines[b].append(token)

        return output_lines

    def _beam_search(self,
                     batchsize,
                     dec_xes,
                     xlen_t,
                     xs,
                     encoder_output,
                     n_best=20):
        # Code borrowed from PyTorch OpenNMT example
        # https://github.com/MaximumEntropy/Seq2Seq-PyTorch/blob/master/decode.py

        print('(beam search {})'.format(self.beamsize))

        # just produce a prediction without training the model
        done = [False for _ in range(batchsize)]
        total_done = 0
        max_len = 0
        output_lines = [[] for _ in range(batchsize)]

        hidden, last_state = self._get_context(
            batchsize, xlen_t,
            encoder_output)  ## hidden = 2(#layer)x1x2048 / last_state = 1x4096

        # expand tensors for each beam
        beamsize = self.beamsize
        if not self.use_attention:
            context = Variable(last_state.data.repeat(1, beamsize, 1))

        dec_states = [
            Variable(hidden.data.repeat(1, beamsize, 1))  # 2x3x2048
            #Variable(context_c_t.data.repeat(1, self.beamsize, 1)) ## TODO : GRU OK. check LSTM ?
        ]

        beam = [
            Beam(beamsize, self.dict.tok2ind, cuda=self.use_cuda)
            for k in range(batchsize)
        ]

        batch_idx = list(range(batchsize))
        remaining_sents = batchsize

        input = Variable(dec_xes.data.repeat(1, beamsize, 1))
        encoder_output = Variable(encoder_output.data.repeat(beamsize, 1, 1))

        while max_len < self.max_seq_len:

            # keep producing tokens until we hit END or max length for each
            if self.use_attention:
                output = self._apply_attention(input, encoder_output,
                                               dec_states[0][-1], xs)
            else:
                output = torch.cat((input, context), 2)

            output, hidden = self.decoder(output, dec_states[0])
            preds, scores = self.hidden_to_idx(output, dropout=False)

            dec_states = [hidden]
            word_lk = scores.view(beamsize, remaining_sents,
                                  -1).transpose(0, 1).contiguous()
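            # word_lk: [remaining_sents, beamsize, vocab] -- per-beam log-probs for each active example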

            active = []
            for b in range(batchsize):
                if beam[b].done:
                    continue

                idx = batch_idx[b]
                #if not beam[b].advance(word_lk.data[idx]):
                #if not beam[b].advance_end(word_lk.data[idx]):
                if not beam[b].advance_diverse(word_lk.data[idx]):
                    active += [b]

                for dec_state in dec_states:  # iterate over h, c
                    # layers x beam*sent x dim
                    sent_states = dec_state.view(-1, beamsize, remaining_sents,
                                                 dec_state.size(2))[:, :, idx]
                    sent_states.data.copy_(
                        sent_states.data.index_select(
                            1, beam[b].get_current_origin()))

            if not active:
                break
            """
            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            active_idx = torch.cuda.LongTensor([batch_idx[k] for k in active])
            batch_idx = {beam: idx for idx, beam in enumerate(active)}

            def update_active(t):
                # select only the remaining active sentences
                view = t.data.view(-1, remaining_sents, self.decoder.hidden_size)
                new_size = list(t.size())
                new_size[-2] = new_size[-2] * len(active_idx) \
                    // remaining_sents
                return Variable(view.index_select( 1, active_idx).view(*new_size))

            pdb.set_trace()
            dec_states = (
                update_active(dec_states[0])#, 2x3x2048  #layer x batch*beam * 2048
                #update_active(dec_states[1])
            )
            #dec_out = update_active(dec_out)
            context = update_active(context) # 1x3x4096

            remaining_sents = len(active)
            pdb.set_trace()            
            """

            input = torch.stack([
                b.get_current_state() for b in beam if not b.done
            ]).t().contiguous().view(1, -1)
            input = self.lt(Variable(input))

            max_len += 1

        all_preds, allScores = [], []
        for b in range(batchsize):  ## TODO :: does it provide batchsize > 1 ?
            hyps = []
            scores, ks = beam[b].sort_best()
            #scores, ks = beam[b].sort_best_normlen()

            allScores += [scores[:self.beamsize]]
            hyps += [beam[b].get_hyp(k) for k in ks[:self.beamsize]]

            all_preds += [
                ' '.join([self.dict.ind2tok[y] for y in x if y != 0])
                for x in hyps
            ]  # self.dict.null_token = 0

            if n_best == 1:
                print(
                    '\n    input:',
                    self.dict.vec2txt(xs[0].data.cpu()).replace(
                        self.dict.null_token + ' ', ''), '\n    pred :',
                    ''.join(all_preds[b]), '\n')
            else:
                print(
                    '\n    input:',
                    self.dict.vec2txt(xs[0].data.cpu()).replace(
                        self.dict.null_token + ' ', '\n'))
                for h in range(len(hyps)):
                    print('   {:3f} '.format(scores[h]),
                          ''.join(all_preds[h]))

            print('the first: ' +
                  ' '.join([self.dict.ind2tok[y] for y in beam[0].nextYs[1]]))
        return [all_preds[0]], all_preds  # 1-best

    def _score_candidates(self, cands, xe, encoder_output, hidden):
        # score each candidate separately

        # cands are exs_with_cands x cands_per_ex x words_per_cand
        # cview is total_cands x words_per_cand
        cview = cands.view(-1, cands.size(2))
        cands_xes = xe.expand(xe.size(0), cview.size(0), xe.size(2))
        sz = hidden.size()
        cands_hn = (hidden.view(sz[0], sz[1], 1, sz[2]).expand(
            sz[0], sz[1], cands.size(1),
            sz[2]).contiguous().view(sz[0], -1, sz[2]))

        sz = encoder_output.size()
        cands_encoder_output = (encoder_output.contiguous().view(
            sz[0], 1, sz[1],
            sz[2]).expand(sz[0], cands.size(1), sz[1],
                          sz[2]).contiguous().view(-1, sz[1], sz[2]))

        cand_scores = Variable(
            self.cand_scores.resize_(cview.size(0)).fill_(0))
        cand_lengths = Variable(
            self.cand_lengths.resize_(cview.size(0)).fill_(0))

        for i in range(cview.size(1)):
            output = self._apply_attention(cands_xes, cands_encoder_output, cands_hn) \
                    if self.use_attention else cands_xes

            output, cands_hn = self.decoder(output, cands_hn)
            preds, scores = self.hidden_to_idx(output, dropout=False)
            cs = cview.select(1, i)
            non_nulls = cs.ne(self.NULL_IDX)
            cand_lengths += non_nulls.long()
            score_per_cand = torch.gather(scores, 1, cs.unsqueeze(1))
            cand_scores += score_per_cand.squeeze() * non_nulls.float()
            #cands_xes = self.lt2dec(self.lt(cs).unsqueeze(0))
            cands_xes = self.lt(cs).unsqueeze(0)

        # set empty scores to -1, so when divided by 0 they become -inf
        cand_scores -= cand_lengths.eq(0).float()
        # average the scores per token
        cand_scores /= cand_lengths.float()

        cand_scores = cand_scores.view(cands.size(0), cands.size(1))
        srtd_scores, text_cand_inds = cand_scores.sort(1, True)
        text_cand_inds = text_cand_inds.data

        return text_cand_inds

    def predict(self, xs, xlen, ylen=None, ys=None, cands=None):
        """Produce a prediction from our model.

        Update the model using the targets if available, otherwise rank
        candidates as well if they are available.
        """

        self._training(self.training)

        batchsize = len(xs)
        text_cand_inds = None
        target_exist = ys is not None

        xlen_t = torch.LongTensor(xlen) - 1
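        # xlen_t: index of the last real token in each (length-sorted) example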
        if self.use_cuda:
            xlen_t = xlen_t.cuda()
        xlen_t = Variable(xlen_t)

        # Encoding
        encoder_output = self._encode(xs, xlen, dropout=self.training)

        # next we use START as an input to kick off our decoder
        x = Variable(self.START_TENSOR)
        xe = self.lt(x).unsqueeze(1)
        dec_xes = xe.expand(xe.size(0), batchsize, xe.size(2))

        # list of output tokens for each example in the batch
        output_lines = None
        beam_cands = None

        # Decoding
        if not self.generating:
            #if (target_exist is not None) and (self.generating is False):
            loss, output_lines = self._decode_and_train(
                batchsize, dec_xes, xlen_t, xs, ys, ylen, encoder_output)
            if self.training:
                loss.backward()
                if self.opt['grad_clip'] > 0:
                    torch.nn.utils.clip_grad_norm(self.lt.parameters(),
                                                  self.opt['grad_clip'])
                    torch.nn.utils.clip_grad_norm(self.h2o.parameters(),
                                                  self.opt['grad_clip'])
                    torch.nn.utils.clip_grad_norm(self.encoder.parameters(),
                                                  self.opt['grad_clip'])
                    torch.nn.utils.clip_grad_norm(self.decoder.parameters(),
                                                  self.opt['grad_clip'])
                self.update_params()
            self.display_predict(xs, ys, output_lines)

        else:
            #elif not target_exists or self.generating:
            assert (not self.training)
            if cands is not None:
                text_cand_inds = self._score_candidates(
                    cands, xe, encoder_output)

            if self.opt['beam_size'] > 0:
                output_lines, beam_cands = self._beam_search(
                    batchsize, dec_xes, xlen_t, xs, encoder_output)
            else:
                output_lines = self._decode_only(batchsize, dec_xes, xlen_t,
                                                 xs, encoder_output)
                self.display_predict(xs, ys, output_lines, 1)

        return output_lines, text_cand_inds, beam_cands

    def display_predict(self, xs, ys, output_lines, freq=0.01):
        if random.random() < freq:
            # sometimes output a prediction for debugging
            print(
                '\n    input:',
                self.dict.vec2txt(xs[0].data.cpu()).replace(
                    self.dict.null_token + ' ', ''), '\n    pred :',
                ' '.join(output_lines[0]), '\n')
            if ys is not None:
                print(
                    '    label:',
                    self.dict.vec2txt(ys[0].data.cpu()).replace(
                        self.dict.null_token + ' ', ''), '\n')

    def batchify(self, observations):
        """Convert a list of observations into input & target tensors."""
        # valid examples
        exs = [ex for ex in observations if 'text' in ex]
        # the indices of the valid (non-empty) tensors
        valid_inds = [i for i, ex in enumerate(observations) if 'text' in ex]

        # set up the input tensors
        batchsize = len(exs)
        # tokenize the text
        xs = None
        xlen = None
        if batchsize > 0:
            parsed = [
                self.dict.parse(self.START) + self.parse(ex['text']) +
                self.dict.parse(self.END) for ex in exs
            ]
            max_x_len = max([len(x) for x in parsed])
            if self.truncate:
                # shrink xs to limit batch computation
                max_x_len = min(max_x_len, self.max_seq_len)
                parsed = [x[-max_x_len:] for x in parsed]

            # sorting for unpack in encoder
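            # pack_padded_sequence in _encode expects examples sorted by length, longest first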
            parsed_x = sorted(parsed, key=lambda p: len(p), reverse=True)
            xlen = [len(x) for x in parsed_x]
            xs = torch.LongTensor(batchsize, max_x_len).fill_(0)
            """
            # pack the data to the right side of the tensor for this model
            for i, x in enumerate(parsed):
                offset = max_x_len - len(x)
                for j, idx in enumerate(x):
                    xs[i][j + offset] = idx
                    """
            for i, x in enumerate(parsed_x):
                for j, idx in enumerate(x):
                    xs[i][j] = idx
            if self.use_cuda:
                # copy to gpu
                self.xs.resize_(xs.size())
                self.xs.copy_(xs, async=True)
                xs = Variable(self.xs)
            else:
                xs = Variable(xs)

        # set up the target tensors
        ys = None
        ylen = None

        if batchsize > 0 and (any(['labels' in ex for ex in exs])
                              or any(['eval_labels' in ex for ex in exs])):
            # randomly select one of the labels to update on, if multiple
            # append END to each label
            if any(['labels' in ex for ex in exs]):
                labels = [
                    random.choice(ex.get('labels', [''])) + ' ' + self.END
                    for ex in exs
                ]
            else:
                labels = [
                    random.choice(ex.get('eval_labels', [''])) + ' ' + self.END
                    for ex in exs
                ]

            parsed_y = [self.parse(y) for y in labels]
            max_y_len = max(len(y) for y in parsed_y)
            if self.truncate:
                # shrink ys to limit batch computation
                max_y_len = min(max_y_len, self.max_seq_len)
                parsed_y = [y[:max_y_len] for y in parsed_y]

            seq_pairs = sorted(zip(parsed, parsed_y),
                               key=lambda p: len(p[0]),
                               reverse=True)
            _, parsed_y = zip(*seq_pairs)

            ylen = [len(x) for x in parsed_y]
            ys = torch.LongTensor(batchsize, max_y_len).fill_(0)
            for i, y in enumerate(parsed_y):
                for j, idx in enumerate(y):
                    ys[i][j] = idx
            if self.use_cuda:
                # copy to gpu
                self.ys.resize_(ys.size())
                self.ys.copy_(ys, async=True)
                ys = Variable(self.ys)
            else:
                ys = Variable(ys)

        # set up candidates
        cands = None
        valid_cands = None
        if ys is None and self.rank:
            # only do ranking when no targets available and ranking flag set
            parsed = []
            valid_cands = []
            for i in valid_inds:
                if 'label_candidates' in observations[i]:
                    # each candidate tuple is a pair of the parsed version and
                    # the original full string
                    cs = list(observations[i]['label_candidates'])
                    parsed.append([self.parse(c) for c in cs])
                    valid_cands.append((i, cs))
            if len(parsed) > 0:
                # TODO: store lengths of cands separately, so don't have zero
                # padding for varying number of cands per example
                # found cands, pack them into tensor
                max_c_len = max(max(len(c) for c in cs) for cs in parsed)
                max_c_cnt = max(len(cs) for cs in parsed)
                cands = torch.LongTensor(len(parsed), max_c_cnt,
                                         max_c_len).fill_(0)
                for i, cs in enumerate(parsed):
                    for j, c in enumerate(cs):
                        for k, idx in enumerate(c):
                            cands[i][j][k] = idx
                if self.use_cuda:
                    # copy to gpu
                    self.cands.resize_(cands.size())
                    self.cands.copy_(cands, async=True)
                    cands = Variable(self.cands)
                else:
                    cands = Variable(cands)

        return xs, ys, valid_inds, cands, valid_cands, xlen, ylen

    def batch_act(self, observations):
        batchsize = len(observations)
        # initialize a table of replies with this agent's id
        batch_reply = [{'id': self.getID()} for _ in range(batchsize)]

        # convert the observations into batches of inputs and targets
        # valid_inds tells us the indices of all valid examples
        # e.g. for input [{}, {'text': 'hello'}, {}, {}], valid_inds is [1]
        # since the other three elements had no 'text' field
        xs, ys, valid_inds, cands, valid_cands, xlen, ylen = self.batchify(
            observations)

        if xs is None:
            # no valid examples, just return the empty responses we set up
            return batch_reply

        # produce predictions either way, but use the targets if available

        predictions, text_cand_inds, beam_cands = self.predict(
            xs, xlen, ylen, ys, cands)

        for i in range(len(predictions)):
            # map the predictions back to non-empty examples in the batch
            # we join with spaces since we produce tokens one at a time
            curr = batch_reply[valid_inds[i]]
            #curr['text'] = ' '.join(c for c in predictions[i] if c != self.END and c != self.dict.null_token) ## TODO: check!!
            curr['text'] = ''.join(
                c for c in predictions[i] if c != self.END
                and c != self.dict.null_token)  ## TODO: check!!

        if text_cand_inds is not None:
            for i in range(len(valid_cands)):
                order = text_cand_inds[i]
                batch_idx, curr_cands = valid_cands[i]
                curr = batch_reply[batch_idx]
                curr['text_candidates'] = [
                    curr_cands[idx] for idx in order if idx < len(curr_cands)
                ]

        return batch_reply, beam_cands

    def act(self):
        # call batch_act with this batch of one
        return self.batch_act([self.observation])[0]

    def act_beam_cands(self):
        return self.batch_act([self.observation])[1]

    def save(self, path=None):
        path = self.opt.get('model_file', None) if path is None else path

        if path and hasattr(self, 'lt'):
            model = {}
            model['lt'] = self.lt.state_dict()
            #model['lt2enc'] = self.lt2enc.state_dict()
            #model['lt2dec'] = self.lt2dec.state_dict()
            model['encoder'] = self.encoder.state_dict()
            model['decoder'] = self.decoder.state_dict()
            model['h2o'] = self.h2o.state_dict()
            if self.use_attention:
                model['attn'] = self.attn.state_dict()
            model['optims'] = {
                k: v.state_dict()
                for k, v in self.optims.items()
            }
            model['longest_label'] = self.longest_label
            model['opt'] = self.opt

            with open(path, 'wb') as write:
                torch.save(model, write)

    def shutdown(self):
        """Save the state of the model when shutdown."""
        path = self.opt.get('model_file', None)
        if path is not None:
            self.save(path + '.shutdown_state')
        super().shutdown()

    def load(self, path):
        """Return opt and model states."""
        with open(path, 'rb') as read:
            if (self.use_cuda):
                model = torch.load(read)
            else:
                model = torch.load(read,
                                   map_location=lambda storage, loc: storage)
        return model['opt'], model

    def set_states(self, states):
        """Set the state dicts of the modules from saved states."""
        self.lt.load_state_dict(states['lt'])
        self.encoder.load_state_dict(states['encoder'])
        self.decoder.load_state_dict(states['decoder'])
        self.h2o.load_state_dict(states['h2o'])
        if self.use_attention:
            self.attn.load_state_dict(states['attn'])
        for k, v in states['optims'].items():
            self.optims[k].load_state_dict(v)
        self.longest_label = states['longest_label']

    def report(self):
        m = {}
        if not self.generating:
            if self.training:
                m['nll'] = self.loss
                m['ppl'] = math.exp(self.loss)
                m['ndata'] = self.ndata
            else:
                m['nll'] = self.loss_valid / self.ndata_valid
                m['ppl'] = math.exp(self.loss_valid / self.ndata_valid)
                m['ndata'] = self.ndata_valid

            m['lr'] = self.lr
            self.print_weight_state()

        return m

    def reset_valid_report(self):
        self.ndata_valid = 0
        self.loss_valid = 0

    def print_weight_state(self):
        self._print_grad_weight(getattr(self, 'lt').weight, 'lookup')
        for module in {'encoder', 'decoder'}:
            layer = getattr(self, module)
            for weights in layer._all_weights:
                for weight_name in weights:
                    self._print_grad_weight(getattr(layer, weight_name),
                                            module + ' ' + weight_name)
        self._print_grad_weight(getattr(self, 'h2o').weight, 'h2o')
        if self.use_attention:
            self._print_grad_weight(getattr(self, 'attn').weight, 'attn')

    def _print_grad_weight(self, weight, module_name):
        if weight.dim() == 2:
            nparam = weight.size(0) * weight.size(1)
            norm_w = weight.norm(2).pow(2)
            norm_dw = weight.grad.norm(2).pow(2)
            print('{:30}'.format(module_name) +
                  ' {:5} x{:5}'.format(weight.size(0), weight.size(1)) +
                  ' : w {0:.2e} | '.format((norm_w / nparam).sqrt().data[0]) +
                  'dw {0:.2e}'.format((norm_dw / nparam).sqrt().data[0]))

    def _training(self, training=True):
        for module in {'encoder', 'decoder', 'lt', 'h2o', 'attn'}:
            layer = getattr(self, module)
            if layer is not None:
                layer.training = training
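
A minimal usage sketch for the agent above, assuming opt already contains the options registered by Seq2seqV2Agent.add_cmdline_args plus the usual ParlAI dictionary options, and that the input text is already whitespace-tokenized the way the dictionary expects:

agent = Seq2seqV2Agent(opt)
agent.training = True

obs = {
    'text': 'hello , how are you ?',
    'labels': ["i 'm fine , thanks ."],
    'episode_done': True,
}
agent.observe(obs)
# batch_act returns (batch_reply, beam_cands); with labels present, predict() runs one update
batch_reply, beam_cands = agent.batch_act([agent.observation])
print(batch_reply[0].get('text'))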
Code Example #13
0
class ConvS2SAgent(Agent):

    @staticmethod
    def add_cmdline_args(argparser):
        DictionaryAgent.add_cmdline_args(argparser)
        argparser.add_arg('-hs', '--embedding_size', type=int, default=constants.EMBEDDING_SIZE,
            help='size of the embeddings')
        argparser.add_arg('-nel', '--num_encoder_layers', type=int, default=constants.NUM_ENCODER_LAYERS,
            help='number of encoder layers')
        argparser.add_arg('-ndl', '--num_decoder_layers', type=int, default=constants.NUM_DECODER_LAYERS,
                          help='number of decoder layers')
        argparser.add_arg('-ks', '--kernel_size', type=int, default=constants.KERNEL_SIZE,
            help='size of the convolution kernel')
        argparser.add_arg('-lr', '--learning_rate', type=float, default=constants.LEARNING_RATE,
            help='learning rate')
        argparser.add_arg('-dr', '--dropout', type=float, default=0.1,
            help='dropout rate')
        argparser.add_arg('--cuda', action='store_true', default=constants.USE_CUDA,
            help='use CUDA (GPUs) if available')
        argparser.add_arg('--gpu', type=int, default=-1,
            help='which GPU device to use')


    def __init__(self, opt, shared=None):
        super().__init__(opt, shared)

        if opt['cuda']:
            print('[ Using CUDA ]')
            torch.cuda.set_device(opt['gpu'])

        if not shared:
            self.dict = DictionaryAgent(opt)
            self.id = 'ConvS2S'
            self.EOS = self.dict.end_token
            self.SOS = self.dict.start_token
            self.use_cuda = opt['cuda']

            self.EOS_TENSOR = torch.LongTensor(self.dict.parse(self.EOS))
            self.SOS_TENSOR = torch.LongTensor(self.dict.parse(self.SOS))

            self.kernel_size = opt['kernel_size']
            self.embedding_size = opt['embedding_size']
            self.num_enc_layers = opt['num_encoder_layers']
            self.num_dec_layers = opt['num_decoder_layers']

            self.longest_label = 2
            self.encoder_pad = (self.kernel_size - 1) // 2
            self.decoder_pad = self.kernel_size - 1
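            # (kernel_size - 1) // 2 roughly preserves sequence length in the encoder convolutions;
            # kernel_size - 1 gives the decoder left-heavy (causal-style) padding, as in ConvS2S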

            self.criterion = nn.NLLLoss()
            self.embeder = layers.WordEmbeddingGenerator(self.dict.tok2ind,
                                                         embedding_dim=self.embedding_size)
            self.encoder = layers.EncoderStack(self.embedding_size,
                                               2*self.embedding_size,
                                          self.kernel_size,
                                          self.encoder_pad,
                                          self.num_enc_layers)

            self.decoder = layers.DecoderStack(self.embedding_size,
                                               2 * self.embedding_size,
                                               self.kernel_size,
                                               self.decoder_pad,
                                               self.num_dec_layers)

            self.h2o = layers.HiddenToProb(self.embedding_size, len(self.dict))

            lr = opt['learning_rate']
            self.optims = {
                'embeds': optim.Adam(self.embeder.parameters(), lr=lr),
                'encoder': optim.Adam(self.encoder.parameters(), lr=lr),
                'decoder': optim.Adam(self.decoder.parameters(), lr=lr),
                'd2o': optim.Adam(self.h2o.parameters(), lr=lr),
            }
            if self.use_cuda:
                self.cuda()
            if 'model_file' in opt and os.path.isfile(opt['model_file']):
                print('Loading existing model parameters from ' + opt['model_file'])
                self.load(opt['model_file'])
        self.episode_done = True


    def parse(self, text):
        if self.use_cuda:
            return torch.cuda.LongTensor(self.dict.txt2vec(text))
        else:
            return torch.LongTensor(self.dict.txt2vec(text))

    def v2t(self, vec):
        return self.dict.vec2txt(vec)

    def cuda(self):
        self.EOS_TENSOR = self.EOS_TENSOR.cuda(async=True)
        self.SOS_TENSOR = self.SOS_TENSOR.cuda(async=True)
        self.criterion.cuda()
        self.embeder.cuda()
        self.encoder.cuda()
        self.decoder.cuda()
        # self.attention.cuda()
        self.h2o.cuda()

    def zero_grads(self):
        for optimizer in self.optims.values():
            optimizer.zero_grad()

    def update_params(self):
        for optimizer in self.optims.values():
            optimizer.step()

    def save(self, path=None):
        path = self.opt.get('model_file', None) if path is None else path
        model = {}
        model['embeds'] = self.embeder.state_dict()
        model['encoder'] = self.encoder.state_dict()
        model['decoder'] = self.decoder.state_dict()
        model['d2o'] = self.h2o.state_dict()
        model['longest_label'] = self.longest_label

        with open(path, 'wb') as write:
            torch.save(model, write)

    def load(self, path):
        with open(path, 'rb') as read:
            model = torch.load(read)

        self.embeder.load_state_dict(model['embeds'])
        self.encoder.load_state_dict(model['encoder'])
        self.decoder.load_state_dict(model['decoder'])
        self.h2o.load_state_dict(model['d2o'])
        self.longest_label = model['longest_label']

    def update(self, xs, ys):
        #NOTE: Batchsize is always 1. i.e. One turn.
        # Seq len is the number of words in that turn
        batchsize, seq_len = xs.size()
        encoder_input = self.embeder.forward(xs).permute(0,2,1)

        # first encode context
        encoder_out = self.encoder.forward(encoder_input)
        if len(encoder_out.size()) == 2:  # if there is only one word in the input
            encoder_out = encoder_out.unsqueeze(2)

        targets_embedded = self.embeder.forward(ys).permute(0,2,1)
        # start with EOS tensor for all
        x = Variable(self.EOS_TENSOR)
        xe = self.embeder.forward(x).unsqueeze(2)
        decoder_input = targets_embedded
        prev_target = torch.cat((xe, decoder_input), 2)
        prev_target = prev_target.narrow(2, 0, prev_target.size(2) - 1)  # remove the last target step

        output_lines = [[] for _ in range(batchsize)]
        self.zero_grads()

        # update model
        loss = 0
        self.longest_label = max(self.longest_label, ys.size(1))

        out = self.decoder.forward(decoder_input, prev_target,
                                   encoder_out, encoder_input, batchsize)

        #NOTE: For the linear layer below and loss calculations,
        #      a 'batch' is the number of words in the sentence.
        #This is a hack job. Need to fix #FIXME
        preds, scores = self.h2o.forward(out.squeeze(dim=0).t())
        loss += self.criterion(scores, ys.squeeze())

        for i in range(batchsize):
            for j in range(preds.size(0)):
                token = self.v2t([preds.data[j]])
                output_lines[i].append(token)

        loss.backward()
        self.update_params()

        if random.random() < 0.1:
            true = self.v2t(ys.data[0])
            print('loss:', round(loss.data[0], 2),
                 ' '.join(output_lines[0]), '(true: {})'.format(true))

        return output_lines

    def predict(self, xs):
        batchsize, seq_len = xs.size()
        encoder_input = self.embeder.forward(xs)
        encoder_input = encoder_input.permute(0,2,1)

        # first encode context
        encoder_out = self.encoder.forward(encoder_input)
        if len(encoder_out.size()) == 2:  # if there is only one word in the input
            encoder_out = encoder_out.unsqueeze(2)

        # start with EOS tensor for all
        x = Variable(self.EOS_TENSOR)
        if self.use_cuda:
            x = x.cuda(non_blocking=True)
        xe = self.embeder.forward(x).unsqueeze(2)
        decoder_input = xe
        prev_target = Variable(torch.zeros(decoder_input.size()))
        if self.use_cuda:
            prev_target = prev_target.cuda(non_blocking=True)
        # prev_target = xe
        output_lines = [[] for _ in range(batchsize)]
        done = [False for _ in range(batchsize)]
        total_done = 0
        max_len = 0
        token_count = 0

        while (total_done < batchsize) and max_len < self.longest_label:
            out = self.decoder.forward(decoder_input, prev_target,
                                       encoder_out, encoder_input, batchsize,
                                       predict=True)
            preds, scores = self.h2o.forward(out.squeeze(dim=0).t())
            prev_target = self.embeder.forward(preds).unsqueeze(2)

            decoder_input = torch.cat((decoder_input, prev_target), dim=2)
            token_count += 1
            if token_count > 1: #To ignore the first generated string
                max_len += 1

                for i in range(batchsize):
                    eos_count = 0
                    if not done[i]:
                        token = self.v2t(preds.data)
                        # print('eos_count' , eos_count)
                        if token == self.EOS:
                            done[i] = True
                            total_done += 1
                            token_count = 0
                        else:
                            output_lines[i].append(token)
                            # eos_count+=1

        # if random.random() < 0.1:
        print('prediction:', ' '.join(output_lines[0]))
        return output_lines


    def batchify(self, obs):
        exs = [ex for ex in obs if 'text' in ex]
        valid_inds = [i for i, ex in enumerate(obs) if 'text' in ex]
        batchsize = len(exs)

        parsed = [self.parse(ex['text']) for ex in exs]
        max_x_len = max([len(x) for x in parsed])
        if self.use_cuda:
            xs = torch.cuda.LongTensor(batchsize, max_x_len).fill_(0)
        else:
            xs = torch.LongTensor(batchsize, max_x_len).fill_(0)
        for i, x in enumerate(parsed):
            offset = max_x_len - len(x)
            for j, idx in enumerate(x):
                xs[i][j + offset] = idx
        if self.use_cuda:
            xs = xs.cuda(non_blocking=True)
        xs = Variable(xs)
        ys = None
        if 'labels' in exs[0]:
            labels = [random.choice(ex['labels']) + ' ' + self.EOS for ex in exs]
            parsed = [self.parse(y) for y in labels]
            max_y_len = max(len(y) for y in parsed)
            if self.use_cuda:
                ys = torch.cuda.LongTensor(batchsize, max_y_len).fill_(0)
            else:
                ys = torch.LongTensor(batchsize, max_y_len).fill_(0)
            for i, y in enumerate(parsed):
                for j, idx in enumerate(y):
                    ys[i][j] = idx
            if self.use_cuda:
                ys = ys.cuda(non_blocking=True)
            ys = Variable(ys)
        return xs, ys, valid_inds


    def batch_act(self, observations):
        batchsize = len(observations)
        batch_reply = [{'id': self.getID()} for _ in range(batchsize)]

        xs, ys, valid_inds = self.batchify(observations)
        if len(xs) == 0:
            return batch_reply

        # Either train or predict
        if ys is not None:
            # predictions = self.predict(xs)

            predictions = self.update(xs, ys)
        else:
            predictions = self.predict(xs)

        for i in range(len(predictions)):
            batch_reply[valid_inds[i]]['text'] = ' '.join(
                c for c in predictions[i] if c != self.EOS)

        return batch_reply


    def act(self):
        return self.batch_act([self.observation])[0]


    def observe(self, observation):
        observation = copy.deepcopy(observation)
        if not self.episode_done:
            # if the last example wasn't the end of an episode, then we need to
            # recall what was said in that example
            prev_dialogue = self.observation['text']
            observation['text'] = prev_dialogue + '\n' + observation['text']
        self.observation = observation
        self.episode_done = observation['episode_done']
        return observation
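
A note on the teacher-forced decoder input built in update() above: the embedded EOS marker is prepended to the embedded targets and the final target step is dropped, so the decoder sees each target shifted right by one step. A minimal standalone sketch of that shift in plain PyTorch, with hypothetical sizes:

import torch

# hypothetical sizes: batch of 1, embedding dim 4, 3 target tokens
emb_dim, tgt_len = 4, 3
eos_emb = torch.zeros(1, emb_dim, 1)            # embedded EOS marker, shape (B, E, 1)
targets = torch.randn(1, emb_dim, tgt_len)      # embedded targets, shape (B, E, T)

shifted = torch.cat((eos_emb, targets), dim=2)  # prepend EOS -> (B, E, T + 1)
shifted = shifted.narrow(2, 0, tgt_len)         # drop the final step -> (B, E, T)
assert shifted.shape == targets.shape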
Code example #14
class ScoringNetAgent(Agent):
    """Agent which takes an input sequence and produces an output sequence.

    For more information, see Sequence to Sequence Learning with Neural
    Networks `(Sutskever et al. 2014) <https://arxiv.org/abs/1409.3215>`_.
    """

    OPTIM_OPTS = {
        'adadelta': optim.Adadelta,
        'adagrad': optim.Adagrad,
        'adam': optim.Adam,
        'adamax': optim.Adamax,
        'asgd': optim.ASGD,
        'lbfgs': optim.LBFGS,
        'rmsprop': optim.RMSprop,
        'rprop': optim.Rprop,
        'sgd': optim.SGD,
    }

    ENC_OPTS = {'rnn': nn.RNN, 'gru': nn.GRU, 'lstm': nn.LSTM}

    @staticmethod
    def add_cmdline_args(argparser):
        """Add command-line arguments specifically for this agent."""
        DictionaryAgent.add_cmdline_args(argparser)
        agent = argparser.add_argument_group('Seq2Seq Arguments')
        agent.add_argument('-hs',
                           '--hiddensize',
                           type=int,
                           default=128,
                           help='size of the hidden layers')
        agent.add_argument('-emb',
                           '--embeddingsize',
                           type=int,
                           default=128,
                           help='size of the token embeddings')
        agent.add_argument('-nl',
                           '--numlayers',
                           type=int,
                           default=2,
                           help='number of hidden layers')
        agent.add_argument('-lr',
                           '--learning_rate',
                           type=float,
                           default=0.5,
                           help='learning rate')
        agent.add_argument('-wd',
                           '--weight_decay',
                           type=float,
                           default=0,
                           help='weight decay')
        agent.add_argument('-dr',
                           '--dropout',
                           type=float,
                           default=0.2,
                           help='dropout rate')
        agent.add_argument('-att',
                           '--attention',
                           default=False,
                           type='bool',
                           help='if True, use attention')
        agent.add_argument(
            '-attType',
            '--attn-type',
            default='general',
            choices=['general', 'concat', 'dot'],
            help='general=bilinear dot product, concat=Bahdanau\'s implementation'
        )
        agent.add_argument('--no-cuda',
                           action='store_true',
                           default=False,
                           help='disable GPUs even if available')
        agent.add_argument('--gpu',
                           type=int,
                           default=-1,
                           help='which GPU device to use')
        agent.add_argument('-rc',
                           '--rank-candidates',
                           type='bool',
                           default=False,
                           help='rank candidates if available. this is done by'
                           ' computing the mean score per token for each '
                           'candidate and selecting the highest scoring.')
        agent.add_argument('-tr',
                           '--truncate',
                           type='bool',
                           default=True,
                           help='truncate input & output lengths to speed up '
                           'training (may reduce accuracy). This fixes all '
                           'input and output to have a maximum length and to '
                           'be similar in length to one another by throwing '
                           'away extra tokens. This reduces the total amount '
                           'of padding in the batches.')
        agent.add_argument('-enc',
                           '--encoder',
                           default='gru',
                           choices=ScoringNetAgent.ENC_OPTS.keys(),
                           help='Choose between different encoder modules.')
        agent.add_argument('-bi',
                           '--bi-encoder',
                           default=True,
                           type='bool',
                           help='whether to use a bidirectional encoder')
        agent.add_argument('-dec',
                           '--decoder',
                           default='same',
                           choices=['same', 'shared'] +
                           list(ScoringNetAgent.ENC_OPTS.keys()),
                           help='Choose between different decoder modules. '
                           'Default "same" uses same class as encoder, '
                           'while "shared" also uses the same weights.')
        agent.add_argument('-opt',
                           '--optimizer',
                           default='sgd',
                           choices=ScoringNetAgent.OPTIM_OPTS.keys(),
                           help='Choose between pytorch optimizers. '
                           'Any member of torch.optim is valid and will '
                           'be used with default params except learning '
                           'rate (as specified by -lr).')
        agent.add_argument('-gradClip',
                           '--grad-clip',
                           type=float,
                           default=-1,
                           help='gradient clip, default = -1 (no clipping)')
        agent.add_argument(
            '-epi',
            '--episode-concat',
            type='bool',
            default=False,
            help=
            'If multiple observations are from the same episode, concatenate them.'
        )
        agent.add_argument(
            '--beam_size',
            type=int,
            default=0,
            help=
            'Beam size for beam search (generation mode only); set 0 for greedy search'
        )
        agent.add_argument('--max_seq_len',
                           type=int,
                           default=50,
                           help='The maximum sequence length, default = 50')
        agent.add_argument('-ptrmodel',
                           '--ptr_model',
                           default='',
                           help='The pretrained model directory')

    def __init__(self, opt, shared=None):
        """Set up model if shared params not set, otherwise no work to do."""
        super().__init__(opt, shared)
        if not shared:
            # this is not a shared instance of this class, so do full
            # initialization. if shared is set, only set up shared members.

            # check for cuda
            self.use_cuda = not opt.get('no_cuda') and torch.cuda.is_available()
            if self.use_cuda:
                print('[ Using CUDA ]')
                torch.cuda.set_device(opt['gpu'])
            """
            if opt.get('model_file') and os.path.isfile(opt['model_file']):
                # load model parameters if available
                print('Loading existing model params from ' + opt['model_file'])
                new_opt, self.states = self.load(opt['model_file'])
                # override options with stored ones
                opt = self.override_opt(new_opt)
            """
            if opt.get('ptr_model') and os.path.isfile(opt['ptr_model']):
                # load model parameters if available
                print('Loading existing model params from ' + opt['ptr_model'])
                new_opt, self.states = self.load(
                    opt['ptr_model'])  ## TODO:: load what?
                # override options with stored ones
                #opt = self.override_opt(new_opt)

            self.dict = DictionaryAgent(opt)
            self.id = 'ScoringNet'
            # we use START markers to start our output
            self.START = self.dict.start_token
            self.START_TENSOR = torch.LongTensor(self.dict.parse(self.START))
            # we use END markers to end our output
            self.END = self.dict.end_token
            self.END_TENSOR = torch.LongTensor(self.dict.parse(self.END))
            # get index of null token from dictionary (probably 0)
            self.NULL_IDX = self.dict.txt2vec(self.dict.null_token)[0]

            # store important params directly
            hsz = opt['hiddensize']
            emb = opt['embeddingsize']
            self.hidden_size = hsz
            self.emb_size = emb
            self.num_layers = opt['numlayers']
            self.learning_rate = opt['learning_rate']
            self.rank = opt['rank_candidates']
            self.longest_label = 1
            self.truncate = opt['truncate']
            self.attention = opt['attention']

            # set up tensors
            if self.opt['bi_encoder']:
                self.zeros = torch.zeros(2 * self.num_layers, 1, hsz)
            else:
                self.zeros = torch.zeros(self.num_layers, 1, hsz)

            self.zeros_dec = torch.zeros(self.num_layers, 1, hsz)

            self.xs = torch.LongTensor(1, 1)
            self.ys = torch.LongTensor(1, 1)
            self.neg_ys = torch.LongTensor(1, 1)

            # set up modules
            #self.criterion = nn.NLLLoss(size_average = False, ignore_index = 0)
            self.criterion = nn.BCELoss()

            # lookup table stores word embeddings
            self.lt = nn.Embedding(len(self.dict),
                                   emb,
                                   padding_idx=self.NULL_IDX)
            #scale_grad_by_freq=True)
            # encoder captures the input text
            enc_class = ScoringNetAgent.ENC_OPTS[opt['encoder']]
            self.encoder = enc_class(emb,
                                     hsz,
                                     opt['numlayers'],
                                     bidirectional=opt['bi_encoder'],
                                     dropout=opt['dropout'])
            # decoder produces our output states

            dec_isz = hsz
            if opt['bi_encoder']:
                dec_isz += hsz

            # linear layer implements the bilinear score between context and response encodings
            self.h2o = nn.Linear(dec_isz, dec_isz, bias=False)

            # dropout on the linear layer helps us generalize
            self.dropout = nn.Dropout(opt['dropout'])

            self.use_attention = False
            self.attn = None
            # if attention is enabled, set up additional members
            if self.attention:
                self.use_attention = True
                self.att_type = opt['attn_type']
                input_size = hsz
                if opt['bi_encoder']:
                    input_size += hsz

                if self.att_type == 'concat':
                    self.attn = nn.Linear(input_size + hsz, 1, bias=False)
                elif self.att_type == 'dot':
                    assert not opt['bi_encoder']
                elif self.att_type == 'general':
                    self.attn = nn.Linear(hsz, input_size, bias=False)

            # set up optims for each module
            self.lr = opt['learning_rate']
            self.wd = opt['weight_decay']

            optim_class = ScoringNetAgent.OPTIM_OPTS[opt['optimizer']]
            self.optims = {
                'lt':
                optim_class(self.lt.parameters(), lr=self.lr),
                'encoder':
                optim_class(self.encoder.parameters(), lr=self.lr),
                'h2o':
                optim_class(self.h2o.parameters(),
                            lr=self.lr,
                            weight_decay=self.wd),
            }
            if self.attention and self.attn is not None:
                self.optims.update({
                    'attn':
                    optim_class(self.attn.parameters(),
                                lr=self.lr,
                                weight_decay=self.wd)
                })

            if hasattr(self, 'states'):
                # set loaded states if applicable
                if opt.get('ptr_model'):
                    self.init_pretrain(self.states)
                else:
                    self.set_states(self.states)

            if self.use_cuda:
                self.cuda()

            self.loss = 0
            self.ndata = 0
            self.loss_valid = 0
            self.ndata_valid = 0

            if opt['beam_size'] > 0:
                self.beamsize = opt['beam_size']

        self.episode_concat = opt['episode_concat']
        self.training = True
        self.generating = False
        self.local_human = False
        self.max_seq_len = opt['max_seq_len']
        self.reset()

    def set_lrate(self, lr):
        self.lr = lr
        for key in self.optims:
            self.optims[key].param_groups[0]['lr'] = self.lr

    def override_opt(self, new_opt):
        """Set overridable opts from loaded opt file.

        Print out each added key and each overridden key.
        Only override args specific to the model.
        """
        model_args = {
            'hiddensize', 'embeddingsize', 'numlayers', 'optimizer', 'encoder'
        }
        for k, v in new_opt.items():
            if k not in model_args:
                # skip non-model args
                continue
            if k not in self.opt:
                print('Adding new option [ {k}: {v} ]'.format(k=k, v=v))
            elif self.opt[k] != v:
                print('Overriding option [ {k}: {old} => {v}]'.format(
                    k=k, old=self.opt[k], v=v))
            self.opt[k] = v
        return self.opt

    def parse(self, text):
        """Convert string to token indices."""
        return self.dict.txt2vec(text)

    def v2t(self, vec):
        """Convert token indices to string of tokens."""
        return self.dict.vec2txt(vec)

    def cuda(self):
        """Push parameters to the GPU."""
        self.START_TENSOR = self.START_TENSOR.cuda(non_blocking=True)
        self.END_TENSOR = self.END_TENSOR.cuda(non_blocking=True)
        self.zeros = self.zeros.cuda(non_blocking=True)
        self.zeros_dec = self.zeros_dec.cuda(non_blocking=True)
        self.xs = self.xs.cuda(non_blocking=True)
        self.ys = self.ys.cuda(non_blocking=True)
        self.neg_ys = self.neg_ys.cuda(non_blocking=True)
        self.criterion.cuda()
        self.lt.cuda()
        self.encoder.cuda()
        self.h2o.cuda()
        self.dropout.cuda()
        if self.use_attention:
            self.attn.cuda()

    def hidden_to_idx(self, hidden, dropout=False):
        """Convert hidden state vectors into indices into the dictionary."""
        if hidden.size(0) > 1:
            raise RuntimeError('bad dimensions of tensor:', hidden)
        hidden = hidden.squeeze(0)
        if dropout:
            hidden = self.dropout(hidden)  # dropout over the last hidden
        scores = self.h2o(hidden)
        scores = F.log_softmax(scores, dim=1)
        _max_score, idx = scores.max(1)
        return idx, scores

    def zero_grad(self):
        """Zero out optimizers."""
        for optimizer in self.optims.values():
            optimizer.zero_grad()

    def update_params(self):
        """Do one optimization step."""
        for optimizer in self.optims.values():
            optimizer.step()

    def reset(self):
        """Reset observation and episode_done."""
        self.observation = None
        self.episode_done = True

    def preprocess(self, reply_text):
        # preprocess for opensub
        reply_text = reply_text.replace('\\n', '\n')  ## TODO: pre-processing
        reply_text = reply_text.replace("'m", " 'm")
        reply_text = reply_text.replace("'ve", " 've")
        reply_text = reply_text.replace("'s", " 's")
        reply_text = reply_text.replace("'t", " 't")
        reply_text = reply_text.replace("'il", " 'il")
        reply_text = reply_text.replace("'d", " 'd")
        reply_text = reply_text.replace("'re", " 're")
        reply_text = reply_text.lower().strip()

        return reply_text

    def observe(self, observation):
        """Save observation for act.
        If multiple observations are from the same episode, concatenate them.
        """
        if self.local_human:
            observation = {}
            observation['id'] = self.getID()
            reply_text = input("Enter Your Message: ")
            reply_text = self.preprocess(reply_text)
            observation['episode_done'] = True  ### TODO: for history
            observation['text'] = reply_text

            reply_text = input("Enter a lable: ")
            observation['labels'] = self.preprocess(reply_text)

            reply_text = input("Enter a candidate: ")
            observation['cands'] = self.preprocess(reply_text)

        else:
            # shallow copy observation (deep copy can be expensive)
            observation = observation.copy()
            if not self.episode_done and self.episode_concat:
                # if the last example wasn't the end of an episode, then we need to
                # recall what was said in that example
                prev_dialogue = self.observation['text']
                observation['text'] = prev_dialogue + '\n' + observation[
                    'text']  #### TODO!!!! # DATA is concatenated!!

        self.observation = observation
        self.episode_done = observation['episode_done']

        return observation

    def _encode(self, xs, xlen, dropout=False, packed=True):
        """Call encoder and return output and hidden states."""
        batchsize = len(xs)

        # first encode context
        xes = self.lt(xs).transpose(0, 1)
        #if dropout:
        #    xes = self.dropout(xes)

        # initial hidden
        if self.zeros.size(1) != batchsize:
            if self.opt['bi_encoder']:
                self.zeros.resize_(2 * self.num_layers, batchsize,
                                   self.hidden_size).fill_(0)
            else:
                self.zeros.resize_(self.num_layers, batchsize,
                                   self.hidden_size).fill_(0)

        h0 = Variable(self.zeros.fill_(0))

        # forward
        if packed:
            xes = torch.nn.utils.rnn.pack_padded_sequence(xes, xlen)

        if type(self.encoder) == nn.LSTM:
            encoder_output, _ = self.encoder(
                xes, (h0, h0))  ## Note : we can put None instead of (h0, h0)
        else:
            encoder_output, _ = self.encoder(xes, h0)

        if packed:
            encoder_output, _ = torch.nn.utils.rnn.pad_packed_sequence(
                encoder_output)

        encoder_output = encoder_output.transpose(0, 1)  #batch first
        """
        if self.use_attention:
            if encoder_output.size(1) > self.max_length:
                offset = encoder_output.size(1) - self.max_length
                encoder_output = encoder_output.narrow(1, offset, self.max_length)
        """

        return encoder_output

    def _apply_attention(self, word_input, encoder_output, last_hidden, xs):
        """Apply attention to encoder hidden layer."""
        batch_size = encoder_output.size(0)
        enc_length = encoder_output.size(1)
        mask = Variable(xs.data.eq(0).eq(0).float())
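        # mask is 1.0 at real tokens and 0.0 at padding (index 0), so padded positions receive no attention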

        #pdb.set_trace()
        # encoder_output # B x T x 2H
        # last_hidden  B x H

        if self.att_type == 'concat':
            last_hidden = last_hidden.unsqueeze(1).expand(
                batch_size, encoder_output.size(1),
                self.hidden_size)  # B x T x H
            attn_weights = F.tanh(
                self.attn(
                    torch.cat((encoder_output, last_hidden),
                              2).view(batch_size * enc_length,
                                      -1)).view(batch_size, enc_length))
        elif self.att_type == 'dot':
            attn_weights = F.tanh(
                torch.bmm(encoder_output, last_hidden.unsqueeze(2)).squeeze())
        elif self.att_type == 'general':
            attn_weights = F.tanh(
                torch.bmm(encoder_output,
                          self.attn(last_hidden).unsqueeze(2)).squeeze())

        #attn_weights = F.softmax(attn_weights.view(batch_size, enc_length))

        attn_weights = attn_weights.exp().mul(mask)
        denom = attn_weights.sum(1).unsqueeze(1).expand_as(attn_weights)
        attn_weights = attn_weights.div(denom)
        context = torch.bmm(attn_weights.unsqueeze(1),
                            encoder_output).squeeze(1)

        output = torch.cat((word_input, context.unsqueeze(0)), 2)
        return output

    def _get_context(self, batchsize, xlen_t, encoder_output):
        " return initial hidden of decoder and encoder context (last_state)"

        ## The initial of decoder is the hidden (last states) of encoder --> put zero!
        if self.zeros_dec.size(1) != batchsize:
            self.zeros_dec.resize_(self.num_layers, batchsize,
                                   self.hidden_size).fill_(0)
        hidden = Variable(self.zeros_dec.fill_(0))

        last_state = None
        if not self.use_attention:
            last_state = torch.gather(
                encoder_output, 1,
                xlen_t.view(-1, 1, 1).expand(encoder_output.size(0), 1,
                                             encoder_output.size(2)))
            if self.opt['bi_encoder']:
                last_state = torch.cat(
                    (encoder_output[:, 0, self.hidden_size:],
                     last_state[:, 0, :self.hidden_size]), 1)
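                # concatenate the backward direction at t=0 with the forward direction at the last real step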

        return hidden, last_state

    def predict(self,
                xs,
                xlen,
                x_idx,
                ys,
                ylen,
                y_idx,
                nys=None,
                nylen=None,
                ny_idx=None):
        """Produce a prediction from our model.

        Update the model using the targets if available, otherwise rank
        candidates as well if they are available.
        """

        self._training(self.training)
        self.zero_grad()

        batchsize = len(xs)
        #text_cand_inds = None
        #target_exist = ys is not None

        xlen_t = Variable(torch.LongTensor(xlen) - 1)
        ylen_t = Variable(torch.LongTensor(ylen) - 1)
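        # lengths become zero-based indices of each example's last real encoder step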
        if self.use_cuda:
            xlen_t = xlen_t.cuda()
            ylen_t = ylen_t.cuda()

        _, x_idx_t = torch.LongTensor(x_idx).sort(0)
        _, y_idx_t = torch.LongTensor(y_idx).sort(0)

        if self.use_cuda:
            x_idx_t = x_idx_t.cuda()
            y_idx_t = y_idx_t.cuda()
        if ny_idx is not None:
            nylen_t = Variable(torch.LongTensor(nylen) - 1)
            _, ny_idx_t = torch.LongTensor(ny_idx).sort(0)

            if self.use_cuda:
                nylen_t = nylen_t.cuda()
                ny_idx_t = ny_idx_t.cuda()

        # Encoding
        _, enc_x = self._get_context(
            batchsize, xlen_t, self._encode(xs, xlen,
                                            dropout=self.training))  # encode x
        _, enc_y = self._get_context(
            batchsize, ylen_t, self._encode(ys, ylen,
                                            dropout=self.training))  # encode y

        # Permute
        enc_x = enc_x[x_idx_t, :]
        enc_y = enc_y[y_idx_t, :]
        target = Variable(torch.Tensor(batchsize).zero_())

        if ny_idx is not None:
            _, enc_ny = self._get_context(
                batchsize, nylen_t,
                self._encode(nys, nylen, dropout=self.training))  # encode the negative candidate
            enc_ny = enc_ny[ny_idx_t, :]

            # make batch
            enc_x = torch.cat((enc_x, enc_x), 0)
            enc_y = torch.cat((enc_y, enc_ny), 0)
            target = torch.cat((target, target + 1), 0)

        if self.use_cuda:
            target = target.cuda()

        # calculate the score
        output = F.sigmoid(
            torch.bmm(enc_y.unsqueeze(1),
                      self.h2o(enc_x).unsqueeze(1).transpose(1, 2)))

        # loss
        loss = self.criterion(output.squeeze(), target)

        if self.training:
            self.ndata += batchsize
            self.loss = loss
        else:
            self.ndata_valid += batchsize
            self.loss_valid += loss.data[0] * batchsize

        # list of output tokens for each example in the batch
        if self.training:
            self.loss.backward()
            if self.opt['grad_clip'] > 0:
                torch.nn.utils.clip_grad_norm(self.lt.parameters(),
                                              self.opt['grad_clip'])
                torch.nn.utils.clip_grad_norm(self.h2o.parameters(),
                                              self.opt['grad_clip'])
                torch.nn.utils.clip_grad_norm(self.encoder.parameters(),
                                              self.opt['grad_clip'])
            self.update_params()

            self.display_predict(xs[x_idx_t[0], :],
                                 ys[y_idx_t[0], :],
                                 nys[ny_idx_t[0], :],
                                 target,
                                 output,
                                 batchsize,
                                 freq=0.05)

        return self.loss, output.squeeze()

    def display_predict(self,
                        xs,
                        ys,
                        nys,
                        target,
                        output,
                        batchsize,
                        freq=0.01):
        if random.random() < freq:
            # sometimes output a prediction for debugging
            print(
                '\n    input:',
                self.dict.vec2txt(xs.data.cpu()).replace(
                    self.dict.null_token + ' ', ''), '\n    positive:',
                ' {0:.2e} '.format(output[0].data.cpu()[0, 0]),
                self.dict.vec2txt(ys.data.cpu()).replace(
                    self.dict.null_token + ' ', ''), '\n    negative:',
                ' {0:.2e} '.format(output[batchsize].data.cpu()[0, 0]),
                self.dict.vec2txt(nys.data.cpu()).replace(
                    self.dict.null_token + ' ', ''), '\n')

    def txt2tensor(self, parsed, batchsize):
        max_x_len = max([len(x) for x in parsed])
        if self.truncate:
            # shrink xs to limit batch computation
            max_x_len = min(max_x_len, self.max_seq_len)
            parsed = [x[-max_x_len:] for x in parsed]

        # sorting for unpack in encoder
        parsed_x = sorted(enumerate(parsed),
                          key=lambda p: len(p[1]),
                          reverse=True)
        x_idx, parsed_x = zip(*parsed_x)
        x_idx = list(x_idx)
        xlen = [len(x) for x in parsed_x]
        xs = torch.LongTensor(batchsize, max_x_len).fill_(0)
        for i, x in enumerate(parsed_x):
            for j, idx in enumerate(x):
                xs[i][j] = idx
        if self.use_cuda:
            # copy to gpu
            self.xs.resize_(xs.size())
            self.xs.copy_(xs, non_blocking=True)
            xs = Variable(self.xs)
        else:
            xs = Variable(xs)

        return xs, xlen, x_idx

    def batchify(self, observations):
        """Convert a list of observations into input & target tensors."""
        # valid examples
        exs = [ex for ex in observations if 'text' in ex]
        # the indices of the valid (non-empty) tensors
        valid_inds = [i for i, ex in enumerate(observations) if 'text' in ex]

        # set up the input tensors
        batchsize = len(exs)

        # tokenize the text
        xs = None
        xlen = None
        x_idx = None
        if batchsize > 0:
            parsed = [
                self.dict.parse(self.START) + self.parse(ex['text']) +
                self.dict.parse(self.END) for ex in exs
            ]
            xs, xlen, x_idx = self.txt2tensor(parsed, batchsize)

        # set up the target tensors (positive examples)
        ys = None
        ylen = None
        y_idx = None

        if batchsize > 0 and (any(['labels' in ex for ex in exs])
                              or any(['eval_labels' in ex for ex in exs])):
            # randomly select one of the labels to update on, if multiple
            # append END to each label
            if any(['labels' in ex for ex in exs]):
                labels = [
                    self.START + ' ' + random.choice(ex.get('labels', [''])) +
                    ' ' + self.END for ex in exs
                ]
            else:
                labels = [
                    self.START + ' ' +
                    random.choice(ex.get('eval_labels', [''])) + ' ' + self.END
                    for ex in exs
                ]

            parsed_y = [self.parse(y) for y in labels]
            ys, ylen, y_idx = self.txt2tensor(parsed_y, batchsize)

        # set up candidates (negative samples, randomly select!!)
        neg_ys = None
        neg_ylen = None
        ny_idx = None

        if batchsize > 0:
            cands = None
            for i in range(len(exs)):
                if exs[i].get('label_candidates') is not None:
                    cands = list(exs[i]['label_candidates'])
                    break
            if cands is None:
                if any(['labels' in ex for ex in exs]):
                    cands = [ex['labels'][0] for ex in exs
                             ]  ## TODO: the same index should not be selected
                else:
                    cands = [ex['eval_labels'][0] for ex in exs
                             ]  ## TODO: the same index should not be selected

            # randomly select one of the labels to update on, if multiple
            # append END to each label
            parsed_ny = [
                self.dict.parse(self.START) +
                self.parse(random.choice(cands)) + self.dict.parse(self.END)
                for ex in exs
            ]
            neg_ys, neg_ylen, ny_idx = self.txt2tensor(parsed_ny, batchsize)

        return xs, xlen, x_idx, ys, ylen, y_idx, valid_inds, neg_ys, neg_ylen, ny_idx

    def batch_act(self, observations):
        batchsize = len(observations)
        # initialize a table of replies with this agent's id
        batch_reply = [{'id': self.getID()} for _ in range(batchsize)]

        # convert the observations into batches of inputs and targets
        # valid_inds tells us the indices of all valid examples
        # e.g. for input [{}, {'text': 'hello'}, {}, {}], valid_inds is [1]
        # since the other three elements had no 'text' field
        xs, xlen, x_idx, ys, ylen, y_idx, valid_inds, neg_ys, neg_ylen, ny_idx = self.batchify(
            observations)

        if xs is None:
            # no valid examples, just return the empty responses we set up
            return batch_reply

        ## separate: test code / train code
        loss = self.predict(xs, xlen, x_idx, ys, ylen, y_idx, neg_ys, neg_ylen,
                            ny_idx)

        return batch_reply

    def act(self):
        # call batch_act with this batch of one
        return self.batch_act([self.observation])[0]

    def act_scoring_test(self):  ## see ../../bot_code/CC_scoring.py

        x = self.observation['text']
        y = self.observation['labels']
        batchsize = len(x)

        parsed = [
            self.dict.parse(self.START) + self.parse(ex) +
            self.dict.parse(self.END) for ex in x
        ]
        xs, xlen, x_idx = self.txt2tensor(parsed, batchsize)

        labels = [
            self.dict.parse(self.START) + self.parse(ex) +
            self.dict.parse(self.END) for ex in y
        ]
        ys, ylen, y_idx = self.txt2tensor(labels, batchsize)

        loss, output = self.predict(xs, xlen, x_idx, ys, ylen, y_idx)
        return output.data

    def save(self, path=None):
        path = self.opt.get('model_file', None) if path is None else path

        if path and hasattr(self, 'lt'):
            model = {}
            model['lt'] = self.lt.state_dict()
            model['encoder'] = self.encoder.state_dict()
            model['h2o'] = self.h2o.state_dict()
            if self.use_attention:
                model['attn'] = self.attn.state_dict()
            model['optims'] = {
                k: v.state_dict()
                for k, v in self.optims.items()
            }
            model['longest_label'] = self.longest_label
            model['opt'] = self.opt

            with open(path, 'wb') as write:
                torch.save(model, write)

    def shutdown(self):
        """Save the state of the model when shutdown."""
        path = self.opt.get('model_file', None)
        if path is not None:
            self.save(path + '.shutdown_state')
        super().shutdown()

    def load(self, path):
        """Return opt and model states."""
        with open(path, 'rb') as read:
            model = torch.load(read)

        return model['opt'], model

    def set_states(self, states):
        """Set the state dicts of the modules from saved states."""
        self.lt.load_state_dict(states['lt'])
        self.encoder.load_state_dict(states['encoder'])
        #self.h2o.load_state_dict(states['h2o'])
        if self.use_attention:
            self.attn.load_state_dict(states['attn'])
        for k, v in states['optims'].items():
            self.optims[k].load_state_dict(v)
        self.longest_label = states['longest_label']

    def init_pretrain(self, states):
        """Set the state dicts of the modules from saved states."""
        self.lt.load_state_dict(states['lt'])
        self.encoder.load_state_dict(states['encoder'])
        #self.h2o.load_state_dict(states['h2o'])
        """
        if self.use_attention:
            self.attn.load_state_dict(states['attn'])
        for k, v in states['optims'].items():
            self.optims[k].load_state_dict(v)
        self.longest_label = states['longest_label']
        """

    def report(self):
        m = {}
        if not self.generating:
            if self.training:
                m['loss'] = self.loss.data[0]
                m['ndata'] = self.ndata
            else:
                m['loss'] = self.loss_valid / self.ndata_valid
                m['ndata'] = self.ndata_valid

            m['lr'] = self.lr
            self.print_weight_state()

        return m

    def reset_valid_report(self):
        self.ndata_valid = 0
        self.loss_valid = 0

    def print_weight_state(self):
        self._print_grad_weight(getattr(self, 'lt').weight, 'lookup')
        for module in {'encoder'}:
            layer = getattr(self, module)
            for weights in layer._all_weights:
                for weight_name in weights:
                    self._print_grad_weight(getattr(layer, weight_name),
                                            module + ' ' + weight_name)
        self._print_grad_weight(getattr(self, 'h2o').weight, 'h2o')
        if self.use_attention:
            self._print_grad_weight(getattr(self, 'attn').weight, 'attn')

    def _print_grad_weight(self, weight, module_name):
        if weight.dim() == 2:
            nparam = weight.size(0) * weight.size(1)
            norm_w = weight.norm(2).pow(2)
            norm_dw = weight.grad.norm(2).pow(2)
            print('{:30}'.format(module_name) +
                  ' {:5} x{:5}'.format(weight.size(0), weight.size(1)) +
                  ' : w {0:.2e} | '.format((norm_w / nparam).sqrt().data[0]) +
                  'dw {0:.2e}'.format((norm_dw / nparam).sqrt().data[0]))

    def _training(self, training=True):
        for module in {'encoder', 'lt', 'h2o', 'attn'}:
            layer = getattr(self, module)
            if layer is not None:
                layer.training = training
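
The score computed in predict() above is a sigmoid over a bilinear product between the encoded context and the encoded response, trained with binary cross-entropy where true responses receive target 0 and randomly sampled negatives receive target 1. A minimal standalone sketch with hypothetical sizes (the h2o below stands in for the agent's linear layer):

import torch
import torch.nn as nn
import torch.nn.functional as F

batch, hsz = 2, 8                      # hypothetical batch size and encoder output size
enc_x = torch.randn(batch, hsz)        # encoded contexts
enc_y = torch.randn(batch, hsz)        # encoded true responses
enc_ny = torch.randn(batch, hsz)       # encoded random (negative) responses

h2o = nn.Linear(hsz, hsz, bias=False)  # bilinear map between context and response spaces

# stack positives then negatives; targets follow the convention above (0 = positive, 1 = negative)
ctx = torch.cat((enc_x, enc_x), 0)
resp = torch.cat((enc_y, enc_ny), 0)
target = torch.cat((torch.zeros(batch), torch.ones(batch)), 0)

scores = torch.sigmoid(torch.bmm(resp.unsqueeze(1),
                                 h2o(ctx).unsqueeze(2))).view(-1)
loss = F.binary_cross_entropy(scores, target)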
Code example #15
File: memnn.py Project: jojonki/ParlAI
class MemnnAgent(Agent):
    """ Memory Network agent.
    """

    @staticmethod
    def add_cmdline_args(argparser):
        DictionaryAgent.add_cmdline_args(argparser)
        arg_group = argparser.add_argument_group('MemNN Arguments')
        arg_group.add_argument('-lr', '--learning-rate', type=float, default=0.01,
            help='learning rate')
        arg_group.add_argument('--embedding-size', type=int, default=128,
            help='size of token embeddings')
        arg_group.add_argument('--hops', type=int, default=3,
            help='number of memory hops')
        arg_group.add_argument('--mem-size', type=int, default=100,
            help='size of memory')
        arg_group.add_argument('--time-features', type='bool', default=True,
            help='use time features for memory embeddings')
        arg_group.add_argument('--position-encoding', type='bool', default=False,
            help='use position encoding instead of bag of words embedding')
        arg_group.add_argument('--output', type=str, default='rank',
            help='type of output (rank|generate)')
        arg_group.add_argument('--rnn-layers', type=int, default=2,
            help='number of hidden layers in RNN decoder for generative output')
        arg_group.add_argument('--dropout', type=float, default=0.1,
            help='dropout probability for RNN decoder training')
        arg_group.add_argument('--optimizer', default='adam',
            help='optimizer type (sgd|adam)')
        arg_group.add_argument('--no-cuda', action='store_true', default=False,
            help='disable GPUs even if available')
        arg_group.add_argument('--gpu', type=int, default=-1,
            help='which GPU device to use')

    def __init__(self, opt, shared=None):
        opt['cuda'] = not opt['no_cuda'] and torch.cuda.is_available()
        if opt['cuda']:
            print('[ Using CUDA ]')
            torch.cuda.set_device(opt['gpu'])

        if not shared:
            self.opt = opt
            self.id = 'MemNN'
            self.dict = DictionaryAgent(opt)
            self.answers = [None] * opt['batchsize']

            self.model = MemNN(opt, self.dict)
            self.mem_size = opt['mem_size']
            self.loss_fn = CrossEntropyLoss()

            self.decoder = None
            self.longest_label = 1
            self.END = self.dict.end_token
            self.END_TENSOR = torch.LongTensor(self.dict.parse(self.END))
            self.START = self.dict.start_token
            self.START_TENSOR = torch.LongTensor(self.dict.parse(self.START))
            if opt['output'] == 'generate' or opt['output'] == 'g':
                self.decoder = Decoder(opt['embedding_size'], opt['embedding_size'],
                                        opt['rnn_layers'], opt, self.dict)
            elif opt['output'] != 'rank' and opt['output'] != 'r':
                raise NotImplementedError('Output type not supported.')

            optim_params = [p for p in self.model.parameters() if p.requires_grad]
            lr = opt['learning_rate']
            if opt['optimizer'] == 'sgd':
                self.optimizers = {'memnn': optim.SGD(optim_params, lr=lr)}
                if self.decoder is not None:
                    self.optimizers['decoder'] = optim.SGD(self.decoder.parameters(), lr=lr)
            elif opt['optimizer'] == 'adam':
                self.optimizers = {'memnn': optim.Adam(optim_params, lr=lr)}
                if self.decoder is not None:
                    self.optimizers['decoder'] = optim.Adam(self.decoder.parameters(), lr=lr)
            else:
                raise NotImplementedError('Optimizer not supported.')

            if opt['cuda']:
                self.model.share_memory()
                if self.decoder is not None:
                    self.decoder.cuda()

            if opt.get('model_file') and os.path.isfile(opt['model_file']):
                print('Loading existing model parameters from ' + opt['model_file'])
                self.load(opt['model_file'])
        else:
            self.answers = shared['answers']

        self.episode_done = True
        self.last_cands, self.last_cands_list = None, None
        super().__init__(opt, shared)

    def share(self):
        shared = super().share()
        shared['answers'] = self.answers
        return shared

    def observe(self, observation):
        observation = copy.copy(observation)
        if not self.episode_done:
            # if the last example wasn't the end of an episode, then we need to
            # recall what was said in that example
            prev_dialogue = self.observation['text'] if self.observation is not None else ''
            batch_idx = self.opt.get('batchindex', 0)
            if self.answers[batch_idx] is not None:
                prev_dialogue += '\n' + self.answers[batch_idx]
                self.answers[batch_idx] = None
            observation['text'] = prev_dialogue + '\n' + observation['text']
        self.observation = observation
        self.episode_done = observation['episode_done']
        return observation

    def predict(self, xs, cands, ys=None):
        is_training = ys is not None
        self.model.train(mode=is_training)
        # Organize inputs for network (see contents of xs and ys in batchify method)
        inputs = [Variable(x, volatile=not is_training) for x in xs]
        output_embeddings = self.model(*inputs)

        if self.decoder is None:
            scores = self.score(cands, output_embeddings)
            if is_training:
                label_inds = [cand_list.index(self.labels[i]) for i, cand_list in enumerate(cands)]
                if self.opt['cuda']:
                    label_inds = Variable(torch.cuda.LongTensor(label_inds))
                else:
                    label_inds = Variable(torch.LongTensor(label_inds))
                loss = self.loss_fn(scores, label_inds)
            predictions = self.ranked_predictions(cands, scores)
        else:
            self.decoder.train(mode=is_training)

            output_lines, loss = self.decode(output_embeddings, ys)
            predictions = self.generated_predictions(output_lines)

        if is_training:
            for o in self.optimizers.values():
                o.zero_grad()
            loss.backward()
            for o in self.optimizers.values():
                o.step()
        return predictions

    def score(self, cands, output_embeddings):
        last_cand = None
        max_len = max([len(c) for c in cands])
        scores = Variable(output_embeddings.data.new(len(cands), max_len))
        for i, cand_list in enumerate(cands):
            if last_cand != cand_list:
                candidate_lengths, candidate_indices = to_tensors(cand_list, self.dict)
                candidate_lengths, candidate_indices = Variable(candidate_lengths), Variable(candidate_indices)
                candidate_embeddings = self.model.answer_embedder(candidate_lengths, candidate_indices)
                if self.opt['cuda']:
                    candidate_embeddings = candidate_embeddings.cuda()
                last_cand = cand_list
            scores[i, :len(cand_list)] = self.model.score.one_to_many(output_embeddings[i].unsqueeze(0), candidate_embeddings)
        return scores

    def ranked_predictions(self, cands, scores):
        _, inds = scores.data.sort(descending=True, dim=1)
        return [[cands[i][j] for j in r if j < len(cands[i])]
                    for i, r in enumerate(inds)]

    def decode(self, output_embeddings, ys=None):
        batchsize = output_embeddings.size(0)
        hn = output_embeddings.unsqueeze(0).expand(self.opt['rnn_layers'], batchsize, output_embeddings.size(1))
        x = self.model.answer_embedder(Variable(torch.LongTensor([1])), Variable(self.START_TENSOR))
        xes = x.unsqueeze(1).expand(x.size(0), batchsize, x.size(1))

        loss = 0
        output_lines = [[] for _ in range(batchsize)]
        done = [False for _ in range(batchsize)]
        total_done = 0
        idx = 0
        while (total_done < batchsize) and idx < self.longest_label:
            # keep producing tokens until we hit END or max length for each ex
            if self.opt['cuda']:
                xes = xes.cuda()
                hn = hn.contiguous()
            preds, scores = self.decoder(xes, hn)
            if ys is not None:
                y = Variable(ys[0][:, idx])
                temp_y = y.cuda() if self.opt['cuda'] else y
                loss += self.loss_fn(scores, temp_y)
            else:
                y = preds
            # use the true token as the next input for better training
            xes = self.model.answer_embedder(Variable(torch.LongTensor(preds.numel()).fill_(1)), y).unsqueeze(0)

            for b in range(batchsize):
                if not done[b]:
                    token = self.dict.vec2txt(preds.data[b])
                    if token == self.END:
                        done[b] = True
                        total_done += 1
                    else:
                        output_lines[b].append(token)
            idx += 1
        return output_lines, loss

    def generated_predictions(self, output_lines):
        return [[' '.join(c for c in o if c != self.END
                        and c != self.dict.null_token)] for o in output_lines]

    def parse(self, text):
        """Returns:
            query = tensor (vector) of token indices for query
            query_length = length of query
            memory = tensor (matrix) where each row contains token indices for a memory
            memory_lengths = tensor (vector) with lengths of each memory
        """
        sp = text.split('\n')
        query_sentence = sp[-1]
        query = self.dict.txt2vec(query_sentence)
        query = torch.LongTensor(query)
        query_length = torch.LongTensor([len(query)])

        sp = sp[:-1]
        sentences = []
        for s in sp:
            sentences.extend(s.split('\t'))
        if len(sentences) == 0:
            sentences.append(self.dict.null_token)

        num_mems = min(self.mem_size, len(sentences))
        memory_sentences = sentences[-num_mems:]
        memory = [self.dict.txt2vec(s) for s in memory_sentences]
        memory = [torch.LongTensor(m) for m in memory]
        memory_lengths = torch.LongTensor([len(m) for m in memory])
        memory = torch.cat(memory)
        return (query, memory, query_length, memory_lengths)

    def batchify(self, obs):
        """Returns:
            xs = [memories, queries, memory_lengths, query_lengths]
            ys = [labels, label_lengths] (if available, else None)
            cands = list of candidates for each example in batch
            valid_inds = list of indices for examples with valid observations
        """
        exs = [ex for ex in obs if 'text' in ex]
        valid_inds = [i for i, ex in enumerate(obs) if 'text' in ex]
        if not exs:
            return [None] * 4

        parsed = [self.parse(ex['text']) for ex in exs]
        queries = torch.cat([x[0] for x in parsed])
        memories = torch.cat([x[1] for x in parsed])
        query_lengths = torch.cat([x[2] for x in parsed])
        memory_lengths = torch.LongTensor(len(exs), self.mem_size).zero_()
        for i in range(len(exs)):
            if len(parsed[i][3]) > 0:
                memory_lengths[i, -len(parsed[i][3]):] = parsed[i][3]
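        # memory lengths are right-aligned so the most recent memories occupy the last mem_size slots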
        xs = [memories, queries, memory_lengths, query_lengths]

        ys = None
        self.labels = [random.choice(ex['labels']) for ex in exs if 'labels' in ex]
        if len(self.labels) == len(exs):
            parsed = [self.dict.txt2vec(l) for l in self.labels]
            parsed = [torch.LongTensor(p) for p in parsed]
            label_lengths = torch.LongTensor([len(p) for p in parsed]).unsqueeze(1)
            self.longest_label = max(self.longest_label, label_lengths.max())
            padded = [torch.cat((p, torch.LongTensor(self.longest_label - len(p))
                        .fill_(self.END_TENSOR[0]))) for p in parsed]
            labels = torch.stack(padded)
            ys = [labels, label_lengths]

        cands = [ex['label_candidates'] for ex in exs if 'label_candidates' in ex]
        # Use words in dict as candidates if no candidates are provided
        if len(cands) < len(exs):
            cands = build_cands(exs, self.dict)
        # Avoid rebuilding candidate list every batch if it's the same
        if self.last_cands != cands:
            self.last_cands = cands
            self.last_cands_list = [list(c) for c in cands]
        cands = self.last_cands_list
        return xs, ys, cands, valid_inds

    def batch_act(self, observations):
        batchsize = len(observations)
        batch_reply = [{'id': self.getID()} for _ in range(batchsize)]

        xs, ys, cands, valid_inds = self.batchify(observations)

        if xs is None or len(xs[1]) == 0:
            return batch_reply

        # Either train or predict
        predictions = self.predict(xs, cands, ys)

        for i in range(len(valid_inds)):
            self.answers[valid_inds[i]] = predictions[i][0]
            batch_reply[valid_inds[i]]['text'] = predictions[i][0]
            batch_reply[valid_inds[i]]['text_candidates'] = predictions[i]
        return batch_reply

    def act(self):
        return self.batch_act([self.observation])[0]

    def save(self, path=None):
        path = self.opt.get('model_file', None) if path is None else path

        if path:
            checkpoint = {}
            checkpoint['memnn'] = self.model.state_dict()
            checkpoint['memnn_optim'] = self.optimizers['memnn'].state_dict()
            if self.decoder is not None:
                checkpoint['decoder'] = self.decoder.state_dict()
                checkpoint['decoder_optim'] = self.optimizers['decoder'].state_dict()
                checkpoint['longest_label'] = self.longest_label
            with open(path, 'wb') as write:
                torch.save(checkpoint, write)

    def load(self, path):
        with open(path, 'rb') as read:
            checkpoint = torch.load(read)
        self.model.load_state_dict(checkpoint['memnn'])
        self.optimizers['memnn'].load_state_dict(checkpoint['memnn_optim'])
        if self.decoder is not None:
            self.decoder.load_state_dict(checkpoint['decoder'])
            self.optimizers['decoder'].load_state_dict(checkpoint['decoder_optim'])
            self.longest_label = checkpoint['longest_label']
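
The parse() method above turns a newline-separated dialogue history into a query (the last line) plus a flattened memory tensor with per-memory lengths. A minimal standalone sketch of that split, using a throwaway vocabulary in place of the DictionaryAgent:

import torch

# hypothetical dialogue history: earlier lines become memories, the last line is the query
text = "hello there\nhow are you ?\nwhat did you do today ?"
vocab = {w: i + 1 for i, w in enumerate(sorted(set(text.replace('\n', ' ').split())))}

lines = text.split('\n')
query = torch.LongTensor([vocab[w] for w in lines[-1].split()])
memories = [torch.LongTensor([vocab[w] for w in s.split()]) for s in lines[:-1]]

query_length = torch.LongTensor([len(query)])
memory_lengths = torch.LongTensor([len(m) for m in memories])
memory = torch.cat(memories)  # flattened; individual memories recoverable via memory_lengths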