def states(self, state_iter):
    """Store the cell/hidden state obtained by bundling ``state_iter``.

    Args:
        state_iter: Iterable of per-example states, or ``None``. When it is
            ``None`` the current ``self.c`` / ``self.h`` are left untouched.
    """
    if state_iter is None:
        return

    # ``bundle`` merges the individual states into one object that exposes
    # the combined ``c`` and ``h``.
    bundled = bundle(state_iter)
    cell, hidden = bundled.c, bundled.h
    self.c, self.h = cell, hidden
def run(self, inp_transitions, run_internal_parser=False,
        use_internal_parser=False, validate_transitions=True):
    """Apply a batch of SHIFT/REDUCE/SKIP transitions, one time step at a time.

    Args:
        inp_transitions: 2-D array of given transitions, indexed as
            ``[example, time_step]``; ``shape[1]`` is the number of steps.
        run_internal_parser: If true (and ``self.transition_net`` exists),
            predict transitions from the tracker output and record them in
            ``self.memory`` so accuracy and loss can be computed afterwards.
        use_internal_parser: If true, the predicted transitions replace the
            given ones when stacks and buffers are updated. Only takes effect
            on steps where predictions were actually computed.
        validate_transitions: If true, predictions are passed through
            ``self.validate`` before being used.

    Returns:
        A 3-tuple ``(tops, transition_acc, transition_loss)`` where ``tops``
        holds the last element of every stack. ``transition_acc`` stays
        ``0.0`` and ``transition_loss`` stays ``None`` unless the object has
        both a ``tracker`` and a ``transition_net``.

    Raises:
        NotImplementedError: If the tracker inputs cannot be gathered from
            the buffers/stacks (see the comment at the raise site).
    """
    transition_loss = None
    transition_acc = 0.0
    num_transitions = inp_transitions.shape[1]

    # Transition Loop
    # ===============
    for t_step in range(num_transitions):
        transitions = inp_transitions[:, t_step]
        transition_arr = list(transitions)
        sub_batch_size = len(transition_arr)

        # A mask to select all non-SKIP transitions.
        cant_skip = np.array([t != T_SKIP for t in transitions])

        # Remember important details from this time step.
        self.memory = {}

        # Run if:
        # A. We have a tracking component and,
        # B. There is at least one transition that will not be skipped.
        if hasattr(self, 'tracker') and (self.use_skips or cant_skip.any()):

            # Prepare tracker input.
            try:
                top_buf = bundle(buf[-1] for buf in self.bufs)
                top_stack_1 = bundle(stack[-1] for stack in self.stacks)
                top_stack_2 = bundle(stack[-2] for stack in self.stacks)
            except Exception:
                # To elaborate on this exception, when cropping examples it is possible
                # that your first 1 or 2 actions is a reduce action. It is unclear if this
                # is a bug in cropping or a bug in how we think about cropping. In the meantime,
                # turn on the truncate batch flag, and set the eval_seq_length very high.
                #
                # ``Exception`` (rather than a bare ``except``) keeps
                # KeyboardInterrupt/SystemExit from being masked.
                raise NotImplementedError(
                    "Warning: You are probably trying to encode examples "
                    "with cropped transitions. Although, this is a reasonable "
                    "feature, when predicting/validating transitions, you "
                    "probably will not get the behavior that you expect. Disable "
                    "this exception if you dare.")
            # Uncomment to handle weirdly placed actions like discussed in the above exception.
            # =========
            # zeros = to_gpu(Variable(torch.from_numpy(
            #     np.zeros(self.bufs[0][0].size(), dtype=np.float32)),
            #     volatile=self.bufs[0][0].volatile))
            # top_buf = bundle(buf[-1] for buf in self.bufs)
            # top_stack_1 = bundle(stack[-1] if len(stack) > 0 else zeros for stack in self.stacks)
            # top_stack_2 = bundle(stack[-2] if len(stack) > 1 else zeros for stack in self.stacks)

            # Get hidden output from the tracker. Used to predict transitions.
            # The second returned value is not needed here.
            tracker_h, _ = self.tracker(top_buf, top_stack_1, top_stack_2)

            if hasattr(self, 'transition_net'):
                transition_output = self.transition_net(tracker_h)

            if hasattr(self, 'transition_net') and run_internal_parser:

                # Predict Actions
                # ===============
                t_logits = F.log_softmax(transition_output)
                t_given = transitions

                # TODO: Mask before predicting. This should simplify things and reduce computation.
                # The downside is that in the Action Phase, need to be smarter about which stacks/bufs
                # are selected.
                transition_preds = self.predict_actions(transition_output, cant_skip)

                # Constrain to valid actions
                # ==========================
                if validate_transitions:
                    transition_preds = self.validate(
                        transition_arr, transition_preds, self.stacks, self.bufs)

                t_preds = transition_preds

                # Indices of examples that have a transition.
                t_mask = np.arange(sub_batch_size)

                # Filter to non-SKIP values
                # =========================
                if not self.use_skips:
                    t_preds = t_preds[cant_skip]
                    t_given = t_given[cant_skip]
                    t_mask = t_mask[cant_skip]

                    # Be careful when filtering distributions. These values are used to
                    # calculate loss and need to be used in backprop.
                    index = (cant_skip * np.arange(cant_skip.shape[0]))[cant_skip]
                    index = to_gpu(Variable(torch.from_numpy(index).long(),
                                            volatile=t_logits.volatile))
                    t_logits = torch.index_select(t_logits, 0, index)

                # Memories
                # ========
                # Keep track of key values to determine accuracy and loss.
                # (optional) Filter to only non-skipped transitions. When filtering values
                # that will be backpropagated over, be careful that gradient flow isn't broken.

                # Actual transition predictions. Used to measure transition accuracy.
                self.memory["t_preds"] = t_preds

                # Distribution of transitions use to calculate transition loss.
                self.memory["t_logits"] = t_logits

                # Given transitions.
                self.memory["t_given"] = t_given

                # Record step index.
                self.memory["t_mask"] = t_mask

                # TODO: Write tests to make sure memories look right in the various settings.

                # If this FLAG is set, then use the predicted actions rather than the given.
                if use_internal_parser:
                    transition_arr = transition_preds.tolist()

        # Pre-Action Phase
        # ================

        # For SHIFT
        s_stacks, s_tops, s_trackings, s_idxs = [], [], [], []

        # For REDUCE
        r_stacks, r_lefts, r_rights, r_trackings, r_idxs = [], [], [], [], []

        # Pair every example with its tracker state, or with None when there
        # is no tracker (or its hidden state has not been set yet).
        batch = zip(transition_arr, self.bufs, self.stacks,
                    self.tracker.states
                    if hasattr(self, 'tracker') and self.tracker.h is not None
                    else itertools.repeat(None))

        for batch_idx, (transition, buf, stack, tracking) in enumerate(batch):
            if transition == T_SHIFT:  # shift
                self.t_shift(buf, stack, tracking, s_tops, s_trackings)
                s_idxs.append(batch_idx)
                s_stacks.append(stack)
            elif transition == T_REDUCE:  # reduce
                self.t_reduce(buf, stack, tracking, r_lefts, r_rights, r_trackings)
                r_stacks.append(stack)
                r_idxs.append(batch_idx)
            elif transition == T_SKIP:  # skip
                self.t_skip()

        # Action Phase
        # ============
        self.shift_phase(s_tops, s_trackings, s_stacks, s_idxs)
        self.shift_phase_hook(s_tops, s_trackings, s_stacks, s_idxs)
        self.reduce_phase(r_lefts, r_rights, r_trackings, r_stacks)
        self.reduce_phase_hook(r_lefts, r_rights, r_trackings, r_stacks, r_idxs=r_idxs)

        # Memory Phase
        # ============
        self.memories.append(self.memory)

    # Loss Phase
    # ==========
    if hasattr(self, 'tracker') and hasattr(self, 'transition_net'):
        t_preds, t_logits, t_given, _ = self.get_statistics()

        # We compute accuracy and loss after all transitions have complete,
        # since examples can have different lengths when not using skips.
        transition_acc = (t_preds == t_given).sum() / float(t_preds.shape[0])
        transition_loss = nn.NLLLoss()(t_logits, to_gpu(Variable(
            torch.from_numpy(t_given), volatile=t_logits.volatile)))
        transition_loss *= self.transition_weight

    self.loss_phase_hook()

    if self.debug:
        assert all(len(stack) == 3 for stack in self.stacks), \
            "Stacks should be fully reduced and have 3 elements: " \
            "two zeros and the sentence encoding."
        assert all(len(buf) == 1 for buf in self.bufs), \
            "Buffers should be fully shifted and have 1 zero."

    return [stack[-1] for stack in self.stacks], transition_acc, transition_loss
def init_model(
        FLAGS,
        logger,
        initial_embeddings,
        vocab_size,
        num_classes,
        data_manager,
        logfile_header=None):
    """Assemble the network requested by ``FLAGS`` and report its size.

    A ``build_model`` factory is picked from ``FLAGS.model_type``, the input
    encoder settings are derived from ``FLAGS.encode`` and the composition
    settings from ``FLAGS.reduce``. After the model is built, ``FLAGS.debug``
    is written to ``.debug`` through ``model.apply`` and the architecture and
    parameter count are logged.

    Args:
        FLAGS: Configuration object whose attributes are read here.
        logger: Object with a ``Log(message)`` method.
        initial_embeddings: Forwarded unchanged to ``build_model``.
        vocab_size: Forwarded unchanged to ``build_model``.
        num_classes: Forwarded unchanged to ``build_model``.
        data_manager: Forwarded unchanged to ``build_model``.
        logfile_header: Optional object; when truthy it receives
            ``model_architecture`` and ``total_params`` attributes.

    Returns:
        The model returned by the selected ``build_model``.

    Raises:
        NotImplementedError: For an unrecognised ``model_type``, ``encode``
            or ``reduce`` value.
    """
    logger.Log("Building model.")

    # Choose model. Each entry is wrapped in a lambda so that only the
    # selected ``spinn`` submodule attribute is ever looked up.
    factories = {
        "CBOW": lambda: spinn.cbow.build_model,
        "RNN": lambda: spinn.plain_rnn.build_model,
        "SPINN": lambda: spinn.spinn_core_model.build_model,
        "RLSPINN": lambda: spinn.rl_spinn.build_model,
        "ChoiPyramid": lambda: spinn.choi_pyramid.build_model,
        "Maillard": lambda: spinn.maillard_pyramid.build_model,
        "LMS": lambda: spinn.lms.build_model,
    }
    pick_factory = factories.get(FLAGS.model_type)
    if pick_factory is None:
        raise NotImplementedError
    build_model = pick_factory()

    # Input Encoder.
    def keep_shape(x, batch_size, seq_length):
        return x

    def as_sequences(x, batch_size, seq_length):
        return x.view(batch_size, seq_length, -1)

    def as_rows(x, batch_size, seq_length):
        return x.view(batch_size * seq_length, -1)

    context_args = Args()
    # The LMS model type uses a squared intermediate dimension.
    intermediate_dim = (FLAGS.model_dim * FLAGS.model_dim
                        if FLAGS.model_type == "LMS" else FLAGS.model_dim)

    if FLAGS.encode == "projection":
        context_args.reshape_input = keep_shape
        context_args.reshape_context = keep_shape
        encoder = Linear()(FLAGS.word_embedding_dim, intermediate_dim)
        context_args.input_dim = intermediate_dim
    elif FLAGS.encode == "gru":
        context_args.reshape_input = as_sequences
        context_args.reshape_context = as_rows
        context_args.input_dim = intermediate_dim
        encoder = EncodeGRU(
            FLAGS.word_embedding_dim,
            intermediate_dim,
            num_layers=FLAGS.encode_num_layers,
            bidirectional=FLAGS.encode_bidirectional,
            reverse=FLAGS.encode_reverse,
            mix=(FLAGS.model_type != "CBOW"))
    elif FLAGS.encode == "attn":
        context_args.reshape_input = as_sequences
        context_args.reshape_context = as_rows
        context_args.input_dim = intermediate_dim
        encoder = IntraAttention(FLAGS.word_embedding_dim, intermediate_dim)
    elif FLAGS.encode == "pass":
        context_args.reshape_input = keep_shape
        context_args.reshape_context = keep_shape
        context_args.input_dim = FLAGS.word_embedding_dim

        def encoder(x):
            return x
    else:
        raise NotImplementedError

    context_args.encoder = encoder

    # Composition Function.
    composition_args = Args()
    # Settings copied verbatim from the flag of the same name.
    for flag_name in ("lateral_tracking",
                      "tracking_ln",
                      "use_tracking_in_composition",
                      "use_internal_parser",
                      "transition_weight"):
        setattr(composition_args, flag_name, getattr(FLAGS, flag_name))
    composition_args.size = FLAGS.model_dim
    composition_args.tracker_size = FLAGS.tracking_lstm_hidden_dim
    composition_args.wrap_items = lambda x: torch.cat(x, 0)
    composition_args.extract_h = lambda x: x

    def bundle_items(x):
        return bundle(x)

    def read_h(x):
        return x.h

    def read_c(x):
        return x.c

    def use_bundled_states(size):
        # Shared by "treelstm" and "lms": items are bundled and expose
        # separate ``h`` and ``c`` parts.
        composition_args.wrap_items = bundle_items
        composition_args.extract_h = read_h
        composition_args.extract_c = read_c
        composition_args.size = size

    if FLAGS.reduce == "treelstm":
        assert FLAGS.model_dim % 2 == 0, 'model_dim must be an even number.'
        assert FLAGS.model_type != 'LMS', 'Must use reduce=lms for LMS.'
        if FLAGS.model_dim != FLAGS.word_embedding_dim:
            print('If you are setting different hidden layer and word '
                  'embedding sizes, make sure you specify an encoder')
        use_bundled_states(FLAGS.model_dim // 2)
        composition = ReduceTreeLSTM(
            FLAGS.model_dim // 2,
            tracker_size=FLAGS.tracking_lstm_hidden_dim,
            use_tracking_in_composition=FLAGS.use_tracking_in_composition,
            composition_ln=FLAGS.composition_ln)
    elif FLAGS.reduce == "tanh":
        class ReduceTanh(nn.Module):
            def forward(self, lefts, rights, tracking=None):
                batch_size = len(lefts)
                ret = torch.cat(lefts, 0) + F.tanh(torch.cat(rights, 0))
                return torch.chunk(ret, batch_size, 0)
        composition = ReduceTanh()
    elif FLAGS.reduce == "treegru":
        composition = ReduceTreeGRU(
            FLAGS.model_dim,
            FLAGS.tracking_lstm_hidden_dim,
            FLAGS.use_tracking_in_composition)
    elif FLAGS.reduce == "lms":
        use_bundled_states(FLAGS.model_dim)
        composition = ReduceTensor(FLAGS.model_dim)
    else:
        raise NotImplementedError

    composition_args.composition = composition

    model = build_model(
        data_manager,
        initial_embeddings,
        vocab_size,
        num_classes,
        FLAGS,
        context_args,
        composition_args)

    # Debug
    def set_debug(self):
        self.debug = FLAGS.debug
    model.apply(set_debug)

    # Print model size.
    logger.Log("Architecture: {}".format(model))
    if logfile_header:
        logfile_header.model_architecture = str(model)

    # Sum the element counts of all parameters; each count starts from 1.0,
    # so the total is a float.
    total_params = 0
    for w in model.parameters():
        n_elements = 1.0
        for dim in w.size():
            n_elements = n_elements * dim
        total_params = total_params + n_elements

    logger.Log("Total params: {}".format(total_params))
    if logfile_header:
        logfile_header.total_params = int(total_params)

    return model
def init_model(
        FLAGS,
        logger,
        initial_embeddings,
        vocab_size,
        num_classes,
        data_manager,
        logfile_header=None):
    """Build the model, optimizer and trainer described by ``FLAGS``.

    Args:
        FLAGS: Configuration object whose attributes are read here
            (``model_type``, ``encode``, ``reduce``, ``optimizer_type``, ...).
        logger: Object with a ``Log(message)`` method.
        initial_embeddings: Forwarded unchanged to ``build_model``.
        vocab_size: Forwarded unchanged to ``build_model``.
        num_classes: Forwarded unchanged to ``build_model``.
        data_manager: Forwarded unchanged to ``build_model``.
        logfile_header: Optional object; when truthy it receives
            ``model_architecture`` and ``total_params`` attributes.

    Returns:
        A 3-tuple ``(model, optimizer, trainer)``.

    Raises:
        NotImplementedError: For an unrecognised ``model_type``, ``encode``,
            ``reduce`` or ``optimizer_type`` value.
        AssertionError: If ``reduce == "treelstm"`` and ``model_dim`` is odd.
    """
    # Choose model.
    logger.Log("Building model.")
    if FLAGS.model_type == "CBOW":
        build_model = spinn.cbow.build_model
    elif FLAGS.model_type == "RNN":
        build_model = spinn.plain_rnn.build_model
    elif FLAGS.model_type == "SPINN":
        build_model = spinn.spinn_core_model.build_model
    elif FLAGS.model_type == "RLSPINN":
        build_model = spinn.rl_spinn.build_model
    elif FLAGS.model_type == "Pyramid":
        build_model = spinn.pyramid.build_model
    else:
        raise NotImplementedError(
            "Unsupported model_type: {}".format(FLAGS.model_type))

    # Input Encoder.
    # Defaults: no reshaping, and the model consumes raw word embeddings.
    context_args = Args()
    context_args.reshape_input = lambda x, batch_size, seq_length: x
    context_args.reshape_context = lambda x, batch_size, seq_length: x
    context_args.input_dim = FLAGS.word_embedding_dim

    if FLAGS.encode == "projection":
        encoder = Linear()(FLAGS.word_embedding_dim, FLAGS.model_dim)
    elif FLAGS.encode == "gru":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = FLAGS.model_dim
        encoder = EncodeGRU(FLAGS.word_embedding_dim, FLAGS.model_dim,
                            num_layers=FLAGS.encode_num_layers,
                            bidirectional=FLAGS.encode_bidirectional,
                            reverse=FLAGS.encode_reverse)
    elif FLAGS.encode == "attn":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = FLAGS.model_dim
        encoder = IntraAttention(FLAGS.word_embedding_dim, FLAGS.model_dim)
    elif FLAGS.encode == "pass":
        def encoder(x):
            return x
    else:
        raise NotImplementedError(
            "Unsupported encode: {}".format(FLAGS.encode))

    context_args.encoder = encoder

    # Composition Function.
    composition_args = Args()
    composition_args.lateral_tracking = FLAGS.lateral_tracking
    composition_args.tracking_ln = FLAGS.tracking_ln
    composition_args.use_tracking_in_composition = FLAGS.use_tracking_in_composition
    composition_args.size = FLAGS.model_dim
    composition_args.tracker_size = FLAGS.tracking_lstm_hidden_dim
    composition_args.use_internal_parser = FLAGS.use_internal_parser
    composition_args.transition_weight = FLAGS.transition_weight
    composition_args.wrap_items = lambda x: torch.cat(x, 0)
    composition_args.extract_h = lambda x: x
    composition_args.extract_c = None
    composition_args.detach = FLAGS.transition_detach
    composition_args.evolution = FLAGS.evolution

    if FLAGS.reduce == "treelstm":
        assert FLAGS.model_dim % 2 == 0, 'model_dim must be an even number.'
        if FLAGS.model_dim != FLAGS.word_embedding_dim:
            print('If you are setting different hidden layer and word '
                  'embedding sizes, make sure you specify an encoder')
        # Floor division keeps the size an int on Python 3 as well; plain
        # ``/`` would hand a float to the layer constructor there.
        half_dim = FLAGS.model_dim // 2
        composition_args.wrap_items = lambda x: bundle(x)
        composition_args.extract_h = lambda x: x.h
        composition_args.extract_c = lambda x: x.c
        composition_args.size = half_dim
        composition = ReduceTreeLSTM(
            half_dim,
            tracker_size=FLAGS.tracking_lstm_hidden_dim,
            use_tracking_in_composition=FLAGS.use_tracking_in_composition,
            composition_ln=FLAGS.composition_ln)
    elif FLAGS.reduce == "tanh":
        class ReduceTanh(nn.Module):
            def forward(self, lefts, rights, tracking=None):
                batch_size = len(lefts)
                ret = torch.cat(lefts, 0) + F.tanh(torch.cat(rights, 0))
                return torch.chunk(ret, batch_size, 0)
        composition = ReduceTanh()
    elif FLAGS.reduce == "treegru":
        composition = ReduceTreeGRU(FLAGS.model_dim,
                                    FLAGS.tracking_lstm_hidden_dim,
                                    FLAGS.use_tracking_in_composition)
    else:
        raise NotImplementedError(
            "Unsupported reduce: {}".format(FLAGS.reduce))

    composition_args.composition = composition

    model = build_model(data_manager, initial_embeddings, vocab_size,
                        num_classes, FLAGS, context_args, composition_args)

    # Build optimizer.
    if FLAGS.optimizer_type == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=FLAGS.learning_rate,
                               betas=(0.9, 0.999), eps=1e-08)
    elif FLAGS.optimizer_type == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=FLAGS.learning_rate,
                                  eps=1e-08)
    else:
        raise NotImplementedError(
            "Unsupported optimizer_type: {}".format(FLAGS.optimizer_type))

    # Build trainer.
    if FLAGS.evolution:
        trainer = ModelTrainer_ES(model, optimizer)
    else:
        trainer = ModelTrainer(model, optimizer)

    # Print model size.
    logger.Log("Architecture: {}".format(model))
    if logfile_header:
        logfile_header.model_architecture = str(model)

    # Sum the element counts of all parameters. Written as an explicit loop
    # so it does not depend on ``reduce`` being available as a builtin; each
    # count starts from 1.0, so the total is a float as before.
    total_params = 0
    for w in model.parameters():
        n_elements = 1.0
        for dim in w.size():
            n_elements *= dim
        total_params += n_elements

    logger.Log("Total params: {}".format(total_params))
    if logfile_header:
        logfile_header.total_params = int(total_params)

    return model, optimizer, trainer