def default_args(**kwargs):
    args = {}

    # Required Args
    args['model_dim'] = 10
    args['word_embedding_dim'] = 12
    args['vocab_size'] = 14
    args['embedding_keep_rate'] = 1.0
    args['classifier_keep_rate'] = 1.0
    args['mlp_dim'] = 16
    args['num_mlp_layers'] = 2
    args['num_classes'] = 3
    args['use_sentence_pair'] = False

    initial_embeddings = np.arange(args['vocab_size']).repeat(
        args['word_embedding_dim']).reshape(
        args['vocab_size'], -1).astype(np.float32)
    args['initial_embeddings'] = initial_embeddings

    # Tracker Args
    args['tracking_lstm_hidden_dim'] = 4
    args['transition_weight'] = None

    # Layers
    context_args = Args()
    context_args.reshape_input = lambda x, batch_size, seq_length: x
    context_args.reshape_context = lambda x, batch_size, seq_length: x
    context_args.encoder = nn.Linear(
        args['word_embedding_dim'], args['model_dim'])
    context_args.input_dim = args['model_dim']
    args['context_args'] = context_args

    class Reduce(nn.Module):
        def forward(self, lefts, rights, tracking):
            batch_size = len(lefts)
            return torch.chunk(
                torch.cat(lefts, 0) - torch.cat(rights, 0), batch_size, 0)

    composition_args = Args()
    composition_args.lateral_tracking = True
    composition_args.use_tracking_in_composition = True
    composition_args.size = args['model_dim']
    composition_args.tracker_size = args['tracking_lstm_hidden_dim']
    composition_args.transition_weight = args['transition_weight']
    composition_args.wrap_items = lambda x: torch.cat(x, 0)
    composition_args.extract_h = lambda x: x
    composition_args.extract_c = None
    composition_args.composition = Reduce()
    composition_args.tracking_ln = False
    args['composition_args'] = composition_args

    for k in kwargs:
        args[k] = kwargs[k]

    return args
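
# Usage sketch (added illustration, not part of the original helper): any
# keyword passed to default_args replaces the corresponding default, which is
# how callers override individual settings. The override values and the helper
# name below are arbitrary.
def _example_default_args_override():
    args = default_args(model_dim=20, num_classes=5)
    assert args['model_dim'] == 20   # overridden by the caller
    assert args['vocab_size'] == 14  # untouched default
    return args
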
def init_model(
        FLAGS,
        logger,
        initial_embeddings,
        vocab_size,
        num_classes,
        data_manager,
        logfile_header=None):
    # Choose model.
    logger.Log("Building model.")
    if FLAGS.model_type == "CBOW":
        build_model = spinn.cbow.build_model
    elif FLAGS.model_type == "RNN":
        build_model = spinn.plain_rnn.build_model
    elif FLAGS.model_type == "SPINN":
        build_model = spinn.spinn_core_model.build_model
    elif FLAGS.model_type == "RLSPINN":
        build_model = spinn.rl_spinn.build_model
    elif FLAGS.model_type == "ChoiPyramid":
        build_model = spinn.choi_pyramid.build_model
    elif FLAGS.model_type == "Maillard":
        build_model = spinn.maillard_pyramid.build_model
    elif FLAGS.model_type == "LMS":
        build_model = spinn.lms.build_model
    else:
        raise NotImplementedError

    # Input Encoder.
    context_args = Args()

    if FLAGS.model_type == "LMS":
        intermediate_dim = FLAGS.model_dim * FLAGS.model_dim
    else:
        intermediate_dim = FLAGS.model_dim

    if FLAGS.encode == "projection":
        context_args.reshape_input = lambda x, batch_size, seq_length: x
        context_args.reshape_context = lambda x, batch_size, seq_length: x
        encoder = Linear()(FLAGS.word_embedding_dim, intermediate_dim)
        context_args.input_dim = intermediate_dim
    elif FLAGS.encode == "gru":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = intermediate_dim
        encoder = EncodeGRU(FLAGS.word_embedding_dim, intermediate_dim,
                            num_layers=FLAGS.encode_num_layers,
                            bidirectional=FLAGS.encode_bidirectional,
                            reverse=FLAGS.encode_reverse,
                            mix=(FLAGS.model_type != "CBOW"))
    elif FLAGS.encode == "attn":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = intermediate_dim
        encoder = IntraAttention(FLAGS.word_embedding_dim, intermediate_dim)
    elif FLAGS.encode == "pass":
        context_args.reshape_input = lambda x, batch_size, seq_length: x
        context_args.reshape_context = lambda x, batch_size, seq_length: x
        context_args.input_dim = FLAGS.word_embedding_dim

        def encoder(x):
            return x
    else:
        raise NotImplementedError

    context_args.encoder = encoder

    # Composition Function.
    composition_args = Args()
    composition_args.lateral_tracking = FLAGS.lateral_tracking
    composition_args.tracking_ln = FLAGS.tracking_ln
    composition_args.use_tracking_in_composition = FLAGS.use_tracking_in_composition
    composition_args.size = FLAGS.model_dim
    composition_args.tracker_size = FLAGS.tracking_lstm_hidden_dim
    composition_args.use_internal_parser = FLAGS.use_internal_parser
    composition_args.transition_weight = FLAGS.transition_weight
    composition_args.wrap_items = lambda x: torch.cat(x, 0)
    composition_args.extract_h = lambda x: x

    if FLAGS.reduce == "treelstm":
        assert FLAGS.model_dim % 2 == 0, 'model_dim must be an even number.'
        assert FLAGS.model_type != 'LMS', 'Must use reduce=lms for LMS.'
        if FLAGS.model_dim != FLAGS.word_embedding_dim:
            print('If you are setting different hidden layer and word '
                  'embedding sizes, make sure you specify an encoder')
        composition_args.wrap_items = lambda x: bundle(x)
        composition_args.extract_h = lambda x: x.h
        composition_args.extract_c = lambda x: x.c
        composition_args.size = FLAGS.model_dim // 2
        composition = ReduceTreeLSTM(
            FLAGS.model_dim // 2,
            tracker_size=FLAGS.tracking_lstm_hidden_dim,
            use_tracking_in_composition=FLAGS.use_tracking_in_composition,
            composition_ln=FLAGS.composition_ln)
    elif FLAGS.reduce == "tanh":
        class ReduceTanh(nn.Module):
            def forward(self, lefts, rights, tracking=None):
                batch_size = len(lefts)
                ret = torch.cat(lefts, 0) + F.tanh(torch.cat(rights, 0))
                return torch.chunk(ret, batch_size, 0)
        composition = ReduceTanh()
    elif FLAGS.reduce == "treegru":
        composition = ReduceTreeGRU(FLAGS.model_dim,
                                    FLAGS.tracking_lstm_hidden_dim,
                                    FLAGS.use_tracking_in_composition)
    elif FLAGS.reduce == "lms":
        composition_args.wrap_items = lambda x: bundle(x)
        composition_args.extract_h = lambda x: x.h
        composition_args.extract_c = lambda x: x.c
        composition_args.size = FLAGS.model_dim
        composition = ReduceTensor(FLAGS.model_dim)
    else:
        raise NotImplementedError

    composition_args.composition = composition

    model = build_model(data_manager, initial_embeddings, vocab_size,
                        num_classes, FLAGS, context_args, composition_args)

    # Debug
    def set_debug(self):
        self.debug = FLAGS.debug
    model.apply(set_debug)

    # Print model size.
    logger.Log("Architecture: {}".format(model))
    if logfile_header:
        logfile_header.model_architecture = str(model)
    total_params = sum([reduce(lambda x, y: x * y, w.size(), 1.0)
                        for w in model.parameters()])
    logger.Log("Total params: {}".format(total_params))
    if logfile_header:
        logfile_header.total_params = int(total_params)

    return model
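
# Hedged usage sketch (added illustration, not from the original file). FLAGS
# is normally produced by the project's flag parser; the namespace below lists
# only the attributes init_model reads directly on the SPINN / encode="pass" /
# reduce="tanh" path, and build_model will consume further flags not shown.
#
#     from argparse import Namespace
#     FLAGS = Namespace(model_type="SPINN", encode="pass", reduce="tanh",
#                       model_dim=8, word_embedding_dim=8,
#                       lateral_tracking=True, tracking_ln=False,
#                       use_tracking_in_composition=True,
#                       tracking_lstm_hidden_dim=4, use_internal_parser=False,
#                       transition_weight=None, debug=False)
#     model = init_model(FLAGS, logger, initial_embeddings, vocab_size=100,
#                        num_classes=3, data_manager=data_manager)
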
def __init__(self,
             model_dim=None,
             word_embedding_dim=None,
             vocab_size=None,
             initial_embeddings=None,
             num_classes=None,
             mlp_dim=None,
             embedding_keep_rate=None,
             classifier_keep_rate=None,
             tracking_lstm_hidden_dim=4,
             transition_weight=None,
             encode_style=None,
             encode_reverse=None,
             encode_bidirectional=None,
             encode_num_layers=None,
             use_skips=False,
             lateral_tracking=None,
             use_tracking_in_composition=None,
             use_sentence_pair=False,
             use_difference_feature=False,
             use_product_feature=False,
             num_mlp_layers=None,
             mlp_bn=None,
             use_projection=None,
             **kwargs
             ):
    super(BaseModel, self).__init__()

    self.use_sentence_pair = use_sentence_pair
    self.use_difference_feature = use_difference_feature
    self.use_product_feature = use_product_feature

    self.hidden_dim = hidden_dim = model_dim // 2

    args = Args()
    args.lateral_tracking = lateral_tracking
    args.use_tracking_in_composition = use_tracking_in_composition
    args.size = model_dim // 2
    args.tracker_size = tracking_lstm_hidden_dim
    args.transition_weight = transition_weight

    self.initial_embeddings = initial_embeddings
    self.word_embedding_dim = word_embedding_dim
    self.model_dim = model_dim
    classifier_dropout_rate = 1. - classifier_keep_rate

    vocab = Vocab()
    vocab.size = initial_embeddings.shape[0] if initial_embeddings is not None else vocab_size
    vocab.vectors = initial_embeddings

    # Build parsing component.
    self.spinn = self.build_spinn(args, vocab, use_skips)

    # Build classifier.
    features_dim = self.get_features_dim()
    self.mlp = MLP(features_dim, mlp_dim, num_classes,
                   num_mlp_layers, mlp_bn, classifier_dropout_rate)

    # The input embeddings represent the hidden and cell state, so multiply by 2.
    self.embedding_dropout_rate = 1. - embedding_keep_rate
    input_embedding_dim = args.size * 2

    # Projection will effectively be done by the encoding network.
    use_projection = encode_style is None

    # Create dynamic embedding layer.
    self.embed = Embed(input_embedding_dim, vocab.size,
                       vectors=vocab.vectors,
                       use_projection=use_projection)

    # Optionally build input encoder.
    if encode_style is not None:
        self.encode = self.build_input_encoder(
            encode_style=encode_style,
            word_embedding_dim=word_embedding_dim,
            model_dim=model_dim,
            num_layers=encode_num_layers,
            bidirectional=encode_bidirectional,
            reverse=encode_reverse,
            dropout=self.embedding_dropout_rate)
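
# Hedged note (added, not from the original file): this constructor calls
# self.build_spinn(), self.get_features_dim(), and optionally
# self.build_input_encoder(), which are defined elsewhere on the class or a
# subclass, so it is normally exercised through a concrete model class. A
# rough keyword sketch with arbitrary values ("SomeConcreteModel" is
# hypothetical):
#
#     embeddings = np.random.randn(100, 300).astype(np.float32)
#     model = SomeConcreteModel(
#         model_dim=600, word_embedding_dim=300, vocab_size=100,
#         initial_embeddings=embeddings, num_classes=3, mlp_dim=1024,
#         embedding_keep_rate=0.9, classifier_keep_rate=0.9,
#         num_mlp_layers=2, mlp_bn=False, lateral_tracking=True,
#         use_tracking_in_composition=True)
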
def init_model(
        FLAGS,
        logger,
        initial_embeddings,
        vocab_size,
        num_classes,
        data_manager,
        logfile_header=None):
    # Choose model.
    logger.Log("Building model.")
    if FLAGS.model_type == "CBOW":
        build_model = spinn.cbow.build_model
    elif FLAGS.model_type == "RNN":
        build_model = spinn.plain_rnn.build_model
    elif FLAGS.model_type == "SPINN":
        build_model = spinn.spinn_core_model.build_model
    elif FLAGS.model_type == "RLSPINN":
        build_model = spinn.rl_spinn.build_model
    elif FLAGS.model_type == "Pyramid":
        build_model = spinn.pyramid.build_model
    else:
        raise NotImplementedError

    # Input Encoder.
    context_args = Args()
    context_args.reshape_input = lambda x, batch_size, seq_length: x
    context_args.reshape_context = lambda x, batch_size, seq_length: x
    context_args.input_dim = FLAGS.word_embedding_dim

    if FLAGS.encode == "projection":
        encoder = Linear()(FLAGS.word_embedding_dim, FLAGS.model_dim)
    elif FLAGS.encode == "gru":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = FLAGS.model_dim
        encoder = EncodeGRU(FLAGS.word_embedding_dim, FLAGS.model_dim,
                            num_layers=FLAGS.encode_num_layers,
                            bidirectional=FLAGS.encode_bidirectional,
                            reverse=FLAGS.encode_reverse)
    elif FLAGS.encode == "attn":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = FLAGS.model_dim
        encoder = IntraAttention(FLAGS.word_embedding_dim, FLAGS.model_dim)
    elif FLAGS.encode == "pass":
        def encoder(x):
            return x
    else:
        raise NotImplementedError

    context_args.encoder = encoder

    # Composition Function.
    composition_args = Args()
    composition_args.lateral_tracking = FLAGS.lateral_tracking
    composition_args.tracking_ln = FLAGS.tracking_ln
    composition_args.use_tracking_in_composition = FLAGS.use_tracking_in_composition
    composition_args.size = FLAGS.model_dim
    composition_args.tracker_size = FLAGS.tracking_lstm_hidden_dim
    composition_args.use_internal_parser = FLAGS.use_internal_parser
    composition_args.transition_weight = FLAGS.transition_weight
    composition_args.wrap_items = lambda x: torch.cat(x, 0)
    composition_args.extract_h = lambda x: x
    composition_args.extract_c = None
    composition_args.detach = FLAGS.transition_detach
    composition_args.evolution = FLAGS.evolution

    if FLAGS.reduce == "treelstm":
        assert FLAGS.model_dim % 2 == 0, 'model_dim must be an even number.'
        if FLAGS.model_dim != FLAGS.word_embedding_dim:
            print('If you are setting different hidden layer and word '
                  'embedding sizes, make sure you specify an encoder')
        composition_args.wrap_items = lambda x: bundle(x)
        composition_args.extract_h = lambda x: x.h
        composition_args.extract_c = lambda x: x.c
        composition_args.size = FLAGS.model_dim // 2
        composition = ReduceTreeLSTM(
            FLAGS.model_dim // 2,
            tracker_size=FLAGS.tracking_lstm_hidden_dim,
            use_tracking_in_composition=FLAGS.use_tracking_in_composition,
            composition_ln=FLAGS.composition_ln)
    elif FLAGS.reduce == "tanh":
        class ReduceTanh(nn.Module):
            def forward(self, lefts, rights, tracking=None):
                batch_size = len(lefts)
                ret = torch.cat(lefts, 0) + F.tanh(torch.cat(rights, 0))
                return torch.chunk(ret, batch_size, 0)
        composition = ReduceTanh()
    elif FLAGS.reduce == "treegru":
        composition = ReduceTreeGRU(FLAGS.model_dim,
                                    FLAGS.tracking_lstm_hidden_dim,
                                    FLAGS.use_tracking_in_composition)
    else:
        raise NotImplementedError

    composition_args.composition = composition

    model = build_model(data_manager, initial_embeddings, vocab_size,
                        num_classes, FLAGS, context_args, composition_args)

    # Build optimizer.
    if FLAGS.optimizer_type == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=FLAGS.learning_rate,
                               betas=(0.9, 0.999), eps=1e-08)
    elif FLAGS.optimizer_type == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(),
                                  lr=FLAGS.learning_rate, eps=1e-08)
    else:
        raise NotImplementedError

    # Build trainer.
    if FLAGS.evolution:
        trainer = ModelTrainer_ES(model, optimizer)
    else:
        trainer = ModelTrainer(model, optimizer)

    # Print model size.
    logger.Log("Architecture: {}".format(model))
    if logfile_header:
        logfile_header.model_architecture = str(model)
    total_params = sum([reduce(lambda x, y: x * y, w.size(), 1.0)
                        for w in model.parameters()])
    logger.Log("Total params: {}".format(total_params))
    if logfile_header:
        logfile_header.total_params = int(total_params)

    return model, optimizer, trainer
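
# Hedged usage sketch (added illustration, not from the original file): unlike
# the earlier variant, this version also builds the optimizer and the trainer
# (evolution-strategy or standard, depending on FLAGS.evolution) and returns
# all three, e.g.
#
#     model, optimizer, trainer = init_model(FLAGS, logger, initial_embeddings,
#                                            vocab_size, num_classes,
#                                            data_manager)
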
def __init__(self,
             model_dim=None,
             word_embedding_dim=None,
             vocab_size=None,
             initial_embeddings=None,
             num_classes=None,
             mlp_dim=None,
             embedding_keep_rate=None,
             classifier_keep_rate=None,
             tracking_lstm_hidden_dim=4,
             transition_weight=None,
             use_encode=None,
             encode_reverse=None,
             encode_bidirectional=None,
             encode_num_layers=None,
             use_skips=False,
             lateral_tracking=None,
             use_tracking_in_composition=None,
             # use_sentence_pair=False,
             use_difference_feature=False,
             use_product_feature=False,
             num_mlp_layers=None,
             mlp_bn=None,
             model_specific_params={},
             **kwargs
             ):
    super(SentencePairModel, self).__init__()
    logger.info('ATTSPINN SentencePairModel init...')

    # self.use_sentence_pair = use_sentence_pair
    self.use_difference_feature = use_difference_feature
    self.use_product_feature = use_product_feature

    self.hidden_dim = hidden_dim = model_dim // 2
    # features_dim = hidden_dim * 2 if use_sentence_pair else hidden_dim
    features_dim = model_dim

    # [premise, hypothesis, diff, product]
    if self.use_difference_feature:
        features_dim += self.hidden_dim
    if self.use_product_feature:
        features_dim += self.hidden_dim
    mlp_input_dim = features_dim

    self.initial_embeddings = initial_embeddings
    self.word_embedding_dim = word_embedding_dim
    self.model_dim = model_dim
    classifier_dropout_rate = 1. - classifier_keep_rate

    args = Args()
    args.lateral_tracking = lateral_tracking
    args.use_tracking_in_composition = use_tracking_in_composition
    args.size = model_dim // 2
    args.tracker_size = tracking_lstm_hidden_dim
    args.transition_weight = transition_weight
    args.using_diff_in_mlstm = model_specific_params['using_diff_in_mlstm']
    args.using_prod_in_mlstm = model_specific_params['using_prod_in_mlstm']
    args.using_null_in_attention = model_specific_params['using_null_in_attention']

    vocab = Vocab()
    vocab.size = initial_embeddings.shape[0] if initial_embeddings is not None else vocab_size
    vocab.vectors = initial_embeddings

    # The input embeddings represent the hidden and cell state, so multiply by 2.
    self.embedding_dropout_rate = 1. - embedding_keep_rate
    input_embedding_dim = args.size * 2

    # Create dynamic embedding layer.
    self.embed = Embed(input_embedding_dim, vocab.size, vectors=vocab.vectors)

    self.use_encode = use_encode
    if use_encode:
        self.encode_reverse = encode_reverse
        self.encode_bidirectional = encode_bidirectional
        self.bi = 2 if self.encode_bidirectional else 1
        self.encode_num_layers = encode_num_layers
        self.encode = nn.LSTM(model_dim, model_dim // self.bi,
                              num_layers=encode_num_layers,
                              batch_first=True,
                              bidirectional=self.encode_bidirectional,
                              dropout=self.embedding_dropout_rate)

    self.spinn = self.build_spinn(args, vocab, use_skips)

    self.attention = self.build_attention(args)

    self.mlp = MLP(mlp_input_dim, mlp_dim, num_classes,
                   num_mlp_layers, mlp_bn, classifier_dropout_rate)
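
# Hedged usage sketch (added, not from the original file): the three
# model_specific_params keys are exactly the ones read above; all other values
# are arbitrary, and `embeddings` is assumed to be a (vocab_size,
# word_embedding_dim) float32 matrix.
#
#     model = SentencePairModel(
#         model_dim=600, word_embedding_dim=300, vocab_size=100,
#         initial_embeddings=embeddings, num_classes=3, mlp_dim=1024,
#         embedding_keep_rate=0.9, classifier_keep_rate=0.9,
#         num_mlp_layers=2, mlp_bn=False, use_encode=False,
#         model_specific_params={'using_diff_in_mlstm': True,
#                                'using_prod_in_mlstm': True,
#                                'using_null_in_attention': False})
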