Example #1
 def __call__(self, x, batch_idxs, actor_ids, story_stop_idxs):
     actor_ids = actor_ids.clamp(0, self.cast_size - 1)
     idxs = self.convert_idxs(batch_idxs, actor_ids)
     if batch_idxs.dim() > 0:
         selected = self._get_actors(idxs)
         self.count += 1
         new_selected = self.gru(x, selected)
         self._update_actors(idxs, new_selected)
         if story_stop_idxs.dim() > 0:
             self.state = InPlaceZeroDetach(story_stop_idxs)(self.state)
         if self.count % self.bptt_len == 0:
             self.state = self.state.detach()
         return new_selected
Example #2
 def init_state(self, train = True):
   volatile = not train
   self.state = Variable(torch.zeros(self.batch_size, self.n_units).cuda(), volatile = volatile)
   self.zero_out = torch.zeros(self.n_units).cuda()
   self.hsm.init_state(train)
   self.sentence_embedder.init_state(train)
   self.actor_pool.init_state(train)
Example #3
 def __call__(self, x, batch_idxs, actor_ids, story_stop_idxs):
     self.count += 1
     actor_ids.clamp_(0, self.cast_size - 1)
     new_h = None
     if batch_idxs.dim() > 0:
         lt_idxs = torch.stack([batch_idxs, actor_ids[batch_idxs]]).t()
         select_read = SelectND(lt_idxs)
         select_update = SelectiveUpdate(lt_idxs)
         selected = select_read(self.state)
         new_selected = self.gru(x, selected)
         new_state = select_update(self.state, new_selected)
         if story_stop_idxs.dim() > 0:
             self.state = InPlaceZeroDetach(story_stop_idxs)(self.state)
         if self.count % 10 == 0:
             self.state = new_state.detach()
         return new_selected
Example #4
# Holds one GRU hidden state per (story, actor) slot for a fixed-size cast:
# mentioned actors are read out, advanced with a GRUCell, and written back.
class FixedCastActorHolder(nn.Module):
    def __init__(self, input_size, hidden_size, cast_size, batch_size,
                 bptt_len):
        super(FixedCastActorHolder, self).__init__()
        self.n_units = hidden_size
        self.gru = nn.GRUCell(input_size=input_size, hidden_size=self.n_units)
        self.cast_size = cast_size
        self.batch_size = batch_size
        self.bptt_len = bptt_len
        self.counts = torch.zeros(batch_size, self.cast_size).int()
        self.count = 0

    def init_state(self, train=True):
        volatile = not train
        self.state = Variable(torch.zeros(self.batch_size, self.cast_size,
                                          self.n_units).cuda(),
                              volatile=volatile)

    def __call__(self, x, batch_idxs, actor_ids, story_stop_idxs):
        self.count += 1
        actor_ids.clamp_(0, self.cast_size - 1)
        new_h = None
        if batch_idxs.dim() > 0:
            # (story index, actor id) pairs for the actors mentioned at this step
            lt_idxs = torch.stack([batch_idxs, actor_ids[batch_idxs]]).t()
            select_read = SelectND(lt_idxs)
            select_update = SelectiveUpdate(lt_idxs)
            # read the mentioned actors' states, advance them with the GRU cell,
            # and write the results back into the pooled state
            selected = select_read(self.state)
            new_selected = self.gru(x, selected)
            new_state = select_update(self.state, new_selected)
            # zero out and detach the slots of stories that just ended
            if story_stop_idxs.dim() > 0:
                self.state = InPlaceZeroDetach(story_stop_idxs)(self.state)
            # every 10 steps, keep only a detached copy to truncate backprop through time
            if self.count % 10 == 0:
                self.state = new_state.detach()
            return new_selected

    def attend(self, query_vectors):
        # dot-product attention of each story's query over its actor slots;
        # scores at or below 1e-4 (e.g. empty slots) are pushed to -1000 so the
        # softmax gives them roughly zero weight
        dots = self.state.bmm(query_vectors.unsqueeze(2)).squeeze()
        weights = F.softmax(F.threshold(dots, 0.0001, -1000.0))
        attended = weights.unsqueeze(1).bmm(self.state).squeeze()
        return attended

    def argmax(self, query_vectors):
        dots = self.state.bmm(query_vectors.unsqueeze(2)).squeeze()
        _, idxs = dots.max(1)
        idxs = idxs.data
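
For orientation, here is a self-contained sketch (modern PyTorch, CPU, illustrative shapes) of the read/update pattern that SelectND and SelectiveUpdate appear to implement in the __call__ above: gather the hidden states of the actors mentioned at this step, advance them with the GRU cell, and write them back into the pooled state without mutating it in place. All names and sizes below are assumptions for illustration, not the project's actual helpers.

import torch
import torch.nn as nn

batch_size, cast_size, input_size, hidden_size = 4, 5, 8, 16
gru = nn.GRUCell(input_size=input_size, hidden_size=hidden_size)
state = torch.zeros(batch_size, cast_size, hidden_size)   # pooled actor states

batch_idxs = torch.tensor([0, 2, 3])                       # stories with a mention
actor_ids = torch.tensor([1, 0, 4, 2])                     # mentioned actor per story
x = torch.randn(batch_idxs.size(0), input_size)            # one input vector per mention

rows, cols = batch_idxs, actor_ids[batch_idxs]             # (story, actor) pairs
selected = state[rows, cols]                               # read, as SelectND does
new_selected = gru(x, selected)                            # recurrent update
new_state = state.clone()                                  # out-of-place write-back,
new_state[rows, cols] = new_selected                       # as SelectiveUpdate does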
Example #5
 def end_stories(self, story_ends):
   self.state = InPlaceZeroDetach(story_ends)(self.state)
   self.actor_pool.end_stories(story_ends)
Example #6
# Word-level GRU language model over stories: word embeddings are concatenated
# with the pooled state of the actor mentioned at that position (zeros when no
# actor is mentioned), and the next word is predicted with a hierarchical softmax.
class ActorGRU(nn.Module):

  def __init__(self, params):
    self.__dict__.update(params)
    super(ActorGRU, self).__init__()
    self.embed = nn.Embedding(
      num_embeddings = self.n_vocab + 3,
      embedding_dim = self.word_dim,
    )
    self.gru = nn.GRUCell(
      input_size = self.word_dim + self.actor_n_units,
      hidden_size = self.n_units
    )
    self.dropout = nn.Dropout(
      p = self.drop_prob,
      inplace=False
    )
    self.bn = nn.BatchNorm1d(
      num_features = self.n_units
    )
    self.hsm = HSM(
      input_size = self.n_units + self.actor_n_units,
      #n_vocab = self.n_vocab
      huff_tree = self.huffman_tree,
    )
    self.sentence_embedder = SentenceEmbedder(
      batch_size = self.batch_size,
      vec_dim = self.word_dim
    )
    self.actor_pool = FixedSizeActorPool( #FixedCastActorHolder(
      input_size = self.word_dim,
      query_size = self.n_units,
      hidden_size = self.actor_n_units,
      cast_size = self.actor_cast_size,
      batch_size = self.batch_size,
      bptt_len = self.actor_bptt_len
    )
    self.batch_range = torch.arange(0, self.batch_size).long().cuda()

  def init_params(self):
    self.hsm.init_params()

  def init_state(self, train = True):
    volatile = not train
    self.state = Variable(torch.zeros(self.batch_size, self.n_units).cuda(), volatile = volatile)
    self.zero_out = torch.zeros(self.n_units).cuda()
    self.hsm.init_state(train)
    self.sentence_embedder.init_state(train)
    self.actor_pool.init_state(train)

  def stop_bptt(self):
    self.state.detach_()

  def concat_with_mentioned_actor_states(self, x, actor_locs, actor_ids):
    # append a zero block of size actor_n_units to every word vector, then, where an
    # actor is mentioned, fill that block with the actor's pooled state via ConcatWithSelection
    x = torch.cat([x, Variable(torch.zeros(x.size()[0], self.actor_n_units).cuda(), requires_grad=False)], 1)
    if actor_locs.dim() > 0:
      lt_idxs = torch.stack([actor_locs, actor_ids[actor_locs]]).t()
      selector = SelectND(lt_idxs)
      mentioned_actor_states = selector(self.actor_pool.state)
      cws = ConcatWithSelection(actor_locs, self.word_dim)
      x = cws(x, mentioned_actor_states)
    return x

  def end_stories(self, story_ends):
    self.state = InPlaceZeroDetach(story_ends)(self.state)
    self.actor_pool.end_stories(story_ends)

  def forward(self, batch):
    wX, wY, am_idX, am_idY, am_locX, am_locY, sam_loc, sam_id, se_loc, sentence_end, story_end = [batch[k].pin_memory().cuda(async=True)
      for k in ['wX', 'wY', 'am_idX', 'am_idY', 'am_locX', 'am_locY', 'sam_loc', 'sam_id', 'se_loc', 'sentence_end', 'story_end']]
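
ActorGRU copies its params dict straight into self.__dict__, so every attribute read in __init__ has to be supplied by the caller. The following is a hedged sketch of such a dict; the values are illustrative, huffman_tree is a placeholder for whatever structure the HSM layer expects, and HSM, SentenceEmbedder and FixedSizeActorPool (plus a CUDA device) must be available from the project for the constructor to run.

params = {
    'n_vocab': 20000,           # vocabulary size (the embedding adds 3 extra tokens)
    'word_dim': 128,            # word embedding dimension
    'n_units': 512,             # sentence-level GRU hidden size
    'actor_n_units': 256,       # per-actor hidden size in the actor pool
    'actor_cast_size': 10,      # actor slots per story
    'actor_bptt_len': 20,       # BPTT truncation interval for the actor pool
    'drop_prob': 0.5,           # dropout probability
    'batch_size': 32,
    'huffman_tree': None,       # placeholder: the tree consumed by the HSM output layer
}
model = ActorGRU(params)
model.init_params()
model.init_state(train=True)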
Example #7
 def init_state(self, train=True):
     volatile = not train
     self.state = Variable(torch.zeros(self.batch_size, self.cast_size,
                                       self.n_units).cuda(),
                           volatile=volatile)
Example #8
# Abstract base class for actor-state pools: subclasses decide how (story, actor)
# pairs map onto the pooled state by implementing convert_idxs, _get_actors and
# _update_actors; the shared __call__ handles the GRU update, story resets and
# BPTT truncation.
class ActorPool(nn.Module):
    def __init__(self, **params):
        self.__dict__.update(params)

        super(ActorPool, self).__init__()
        self.gru = nn.GRUCell(input_size=self.input_size,
                              hidden_size=self.hidden_size)
        try:
            self.attn_query = nn.Linear(in_features=self.query_size,
                                        out_features=self.hidden_size)
            self.mention_query = nn.Linear(in_features=self.query_size,
                                           out_features=self.hidden_size)
        except AttributeError:
            pass

    def init_state(self, train=True):
        self.count = 0

    def stop_bptt(self):
        self.state.detach_()

    def end_stories(self, story_ends):
        self.state = InPlaceZeroDetach(story_ends)(self.state)

    def convert_idxs(self, batch_idxs, actor_ids):
        raise NotImplementedError

    def get_actors(self, batch_idxs, actor_ids):
        return self._get_actors(self.convert_idxs(batch_idxs, actor_ids))

    def _get_actors(self, idxs):
        raise NotImplementedError

    def update_actors(self, batch_idxs, actor_ids, obs):
        self._update_actors(self.convert_idxs(batch_idxs, actor_ids), obs)

    def _update_actors(self, idxs, obs):
        raise NotImplementedError

    def __call__(self, x, batch_idxs, actor_ids, story_stop_idxs):
        actor_ids = actor_ids.clamp(0, self.cast_size - 1)
        idxs = self.convert_idxs(batch_idxs, actor_ids)
        if batch_idxs.dim() > 0:
            selected = self._get_actors(idxs)
            self.count += 1
            new_selected = self.gru(x, selected)
            self._update_actors(idxs, new_selected)
            if story_stop_idxs.dim() > 0:
                self.state = InPlaceZeroDetach(story_stop_idxs)(self.state)
            if self.count % self.bptt_len == 0:
                self.state = self.state.detach()
            return new_selected

    def attn_weights(self, query_vectors):
        raise NotImplementedError

    def attend(self, query_vectors):
        raise NotImplementedError

    def argmax(self, query_vectors):
        raise NotImplementedError
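
A hypothetical minimal subclass (not the project's FixedSizeActorPool) showing what the three abstract hooks have to provide: convert_idxs maps (story, actor) pairs onto positions in the pooled state, _get_actors reads those positions, and _update_actors writes new hidden states back without mutating the stored tensor in place. It keeps the pool as a flat (batch_size * cast_size, hidden_size) tensor; all names and sizes are illustrative.

import torch

class FlatActorPool(ActorPool):
    def init_state(self, train=True):
        super().init_state(train)
        # one row per (story, actor) slot
        self.state = torch.zeros(self.batch_size * self.cast_size,
                                 self.hidden_size)

    def convert_idxs(self, batch_idxs, actor_ids):
        # flatten (story index, actor id) pairs into row indices of self.state
        return batch_idxs * self.cast_size + actor_ids[batch_idxs]

    def _get_actors(self, idxs):
        return self.state[idxs]

    def _update_actors(self, idxs, obs):
        # clone before writing so the stored state is not modified in place
        new_state = self.state.clone()
        new_state[idxs] = obs
        self.state = new_state

pool = FlatActorPool(input_size=8, hidden_size=16, cast_size=5,
                     batch_size=4, bptt_len=20)
pool.init_state()
h = pool.get_actors(torch.tensor([0, 2]), torch.tensor([1, 0, 3, 2]))  # shape (2, 16)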