def init_model(
        FLAGS,
        logger,
        initial_embeddings,
        vocab_size,
        num_classes,
        data_manager,
        logfile_header=None):
    # Choose model.
    logger.Log("Building model.")
    if FLAGS.model_type == "CBOW":
        build_model = spinn.cbow.build_model
    elif FLAGS.model_type == "RNN":
        build_model = spinn.plain_rnn.build_model
    elif FLAGS.model_type == "SPINN":
        build_model = spinn.spinn_core_model.build_model
    elif FLAGS.model_type == "RLSPINN":
        build_model = spinn.rl_spinn.build_model
    elif FLAGS.model_type == "ChoiPyramid":
        build_model = spinn.choi_pyramid.build_model
    elif FLAGS.model_type == "Maillard":
        build_model = spinn.maillard_pyramid.build_model
    elif FLAGS.model_type == "LMS":
        build_model = spinn.lms.build_model
    else:
        raise NotImplementedError

    # Input Encoder.
    context_args = Args()

    if FLAGS.model_type == "LMS":
        intermediate_dim = FLAGS.model_dim * FLAGS.model_dim
    else:
        intermediate_dim = FLAGS.model_dim

    if FLAGS.encode == "projection":
        context_args.reshape_input = lambda x, batch_size, seq_length: x
        context_args.reshape_context = lambda x, batch_size, seq_length: x
        encoder = Linear()(FLAGS.word_embedding_dim, intermediate_dim)
        context_args.input_dim = intermediate_dim
    elif FLAGS.encode == "gru":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = intermediate_dim
        encoder = EncodeGRU(FLAGS.word_embedding_dim,
                            intermediate_dim,
                            num_layers=FLAGS.encode_num_layers,
                            bidirectional=FLAGS.encode_bidirectional,
                            reverse=FLAGS.encode_reverse,
                            mix=(FLAGS.model_type != "CBOW"))
    elif FLAGS.encode == "attn":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = intermediate_dim
        encoder = IntraAttention(FLAGS.word_embedding_dim, intermediate_dim)
    elif FLAGS.encode == "pass":
        context_args.reshape_input = lambda x, batch_size, seq_length: x
        context_args.reshape_context = lambda x, batch_size, seq_length: x
        context_args.input_dim = FLAGS.word_embedding_dim

        def encoder(x):
            return x
    else:
        raise NotImplementedError

    context_args.encoder = encoder

    # Composition Function.
    composition_args = Args()
    composition_args.lateral_tracking = FLAGS.lateral_tracking
    composition_args.tracking_ln = FLAGS.tracking_ln
    composition_args.use_tracking_in_composition = FLAGS.use_tracking_in_composition
    composition_args.size = FLAGS.model_dim
    composition_args.tracker_size = FLAGS.tracking_lstm_hidden_dim
    composition_args.use_internal_parser = FLAGS.use_internal_parser
    composition_args.transition_weight = FLAGS.transition_weight
    composition_args.wrap_items = lambda x: torch.cat(x, 0)
    composition_args.extract_h = lambda x: x

    if FLAGS.reduce == "treelstm":
        assert FLAGS.model_dim % 2 == 0, 'model_dim must be an even number.'
        assert FLAGS.model_type != 'LMS', 'Must use reduce=lms for LMS.'
        if FLAGS.model_dim != FLAGS.word_embedding_dim:
            print('If you are setting different hidden layer and word '
                  'embedding sizes, make sure you specify an encoder')
        composition_args.wrap_items = lambda x: bundle(x)
        composition_args.extract_h = lambda x: x.h
        composition_args.extract_c = lambda x: x.c
        composition_args.size = FLAGS.model_dim // 2
        composition = ReduceTreeLSTM(
            FLAGS.model_dim // 2,
            tracker_size=FLAGS.tracking_lstm_hidden_dim,
            use_tracking_in_composition=FLAGS.use_tracking_in_composition,
            composition_ln=FLAGS.composition_ln)
    elif FLAGS.reduce == "tanh":
        class ReduceTanh(nn.Module):
            def forward(self, lefts, rights, tracking=None):
                batch_size = len(lefts)
                ret = torch.cat(lefts, 0) + F.tanh(torch.cat(rights, 0))
                return torch.chunk(ret, batch_size, 0)
        composition = ReduceTanh()
    elif FLAGS.reduce == "treegru":
        composition = ReduceTreeGRU(FLAGS.model_dim,
                                    FLAGS.tracking_lstm_hidden_dim,
                                    FLAGS.use_tracking_in_composition)
    elif FLAGS.reduce == "lms":
        composition_args.wrap_items = lambda x: bundle(x)
        composition_args.extract_h = lambda x: x.h
        composition_args.extract_c = lambda x: x.c
        composition_args.size = FLAGS.model_dim
        composition = ReduceTensor(FLAGS.model_dim)
    else:
        raise NotImplementedError

    composition_args.composition = composition

    model = build_model(data_manager, initial_embeddings, vocab_size,
                        num_classes, FLAGS, context_args, composition_args)

    # Debug
    def set_debug(self):
        self.debug = FLAGS.debug
    model.apply(set_debug)

    # Print model size.
    logger.Log("Architecture: {}".format(model))
    if logfile_header:
        logfile_header.model_architecture = str(model)
    total_params = sum([reduce(lambda x, y: x * y, w.size(), 1.0)
                        for w in model.parameters()])
    logger.Log("Total params: {}".format(total_params))
    if logfile_header:
        logfile_header.total_params = int(total_params)

    return model
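
# Hedged usage sketch (not part of the original source): a minimal example of
# how this variant of init_model might be called from a training script. The
# FLAGS, logger, data_manager, and embedding arguments are hypothetical
# placeholders for objects the surrounding codebase constructs.
def _example_build_model(FLAGS, logger, data_manager,
                         initial_embeddings, vocab_size, num_classes):
    # This variant returns only the model; any optimizer or trainer must be
    # built separately by the caller.
    model = init_model(
        FLAGS,
        logger,
        initial_embeddings,
        vocab_size,
        num_classes,
        data_manager)
    return model
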
def init_model(
        FLAGS,
        logger,
        initial_embeddings,
        vocab_size,
        num_classes,
        data_manager,
        logfile_header=None):
    # Choose model.
    logger.Log("Building model.")
    if FLAGS.model_type == "CBOW":
        build_model = spinn.cbow.build_model
    elif FLAGS.model_type == "RNN":
        build_model = spinn.plain_rnn.build_model
    elif FLAGS.model_type == "SPINN":
        build_model = spinn.spinn_core_model.build_model
    elif FLAGS.model_type == "RLSPINN":
        build_model = spinn.rl_spinn.build_model
    elif FLAGS.model_type == "Pyramid":
        build_model = spinn.pyramid.build_model
    else:
        raise NotImplementedError

    # Input Encoder.
    context_args = Args()
    context_args.reshape_input = lambda x, batch_size, seq_length: x
    context_args.reshape_context = lambda x, batch_size, seq_length: x
    context_args.input_dim = FLAGS.word_embedding_dim

    if FLAGS.encode == "projection":
        encoder = Linear()(FLAGS.word_embedding_dim, FLAGS.model_dim)
    elif FLAGS.encode == "gru":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = FLAGS.model_dim
        encoder = EncodeGRU(FLAGS.word_embedding_dim,
                            FLAGS.model_dim,
                            num_layers=FLAGS.encode_num_layers,
                            bidirectional=FLAGS.encode_bidirectional,
                            reverse=FLAGS.encode_reverse)
    elif FLAGS.encode == "attn":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = FLAGS.model_dim
        encoder = IntraAttention(FLAGS.word_embedding_dim, FLAGS.model_dim)
    elif FLAGS.encode == "pass":
        def encoder(x):
            return x
    else:
        raise NotImplementedError

    context_args.encoder = encoder

    # Composition Function.
    composition_args = Args()
    composition_args.lateral_tracking = FLAGS.lateral_tracking
    composition_args.tracking_ln = FLAGS.tracking_ln
    composition_args.use_tracking_in_composition = FLAGS.use_tracking_in_composition
    composition_args.size = FLAGS.model_dim
    composition_args.tracker_size = FLAGS.tracking_lstm_hidden_dim
    composition_args.use_internal_parser = FLAGS.use_internal_parser
    composition_args.transition_weight = FLAGS.transition_weight
    composition_args.wrap_items = lambda x: torch.cat(x, 0)
    composition_args.extract_h = lambda x: x
    composition_args.extract_c = None
    composition_args.detach = FLAGS.transition_detach
    composition_args.evolution = FLAGS.evolution

    if FLAGS.reduce == "treelstm":
        assert FLAGS.model_dim % 2 == 0, 'model_dim must be an even number.'
        if FLAGS.model_dim != FLAGS.word_embedding_dim:
            print('If you are setting different hidden layer and word '
                  'embedding sizes, make sure you specify an encoder')
        composition_args.wrap_items = lambda x: bundle(x)
        composition_args.extract_h = lambda x: x.h
        composition_args.extract_c = lambda x: x.c
        # Use integer division so the halved hidden size stays an int.
        composition_args.size = FLAGS.model_dim // 2
        composition = ReduceTreeLSTM(
            FLAGS.model_dim // 2,
            tracker_size=FLAGS.tracking_lstm_hidden_dim,
            use_tracking_in_composition=FLAGS.use_tracking_in_composition,
            composition_ln=FLAGS.composition_ln)
    elif FLAGS.reduce == "tanh":
        class ReduceTanh(nn.Module):
            def forward(self, lefts, rights, tracking=None):
                batch_size = len(lefts)
                ret = torch.cat(lefts, 0) + F.tanh(torch.cat(rights, 0))
                return torch.chunk(ret, batch_size, 0)
        composition = ReduceTanh()
    elif FLAGS.reduce == "treegru":
        composition = ReduceTreeGRU(FLAGS.model_dim,
                                    FLAGS.tracking_lstm_hidden_dim,
                                    FLAGS.use_tracking_in_composition)
    else:
        raise NotImplementedError

    composition_args.composition = composition

    model = build_model(data_manager, initial_embeddings, vocab_size,
                        num_classes, FLAGS, context_args, composition_args)

    # Build optimizer.
    if FLAGS.optimizer_type == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=FLAGS.learning_rate,
                               betas=(0.9, 0.999), eps=1e-08)
    elif FLAGS.optimizer_type == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=FLAGS.learning_rate,
                                  eps=1e-08)
    else:
        raise NotImplementedError

    # Build trainer.
    if FLAGS.evolution:
        trainer = ModelTrainer_ES(model, optimizer)
    else:
        trainer = ModelTrainer(model, optimizer)

    # Print model size.
    logger.Log("Architecture: {}".format(model))
    if logfile_header:
        logfile_header.model_architecture = str(model)
    total_params = sum([reduce(lambda x, y: x * y, w.size(), 1.0)
                        for w in model.parameters()])
    logger.Log("Total params: {}".format(total_params))
    if logfile_header:
        logfile_header.total_params = int(total_params)

    return model, optimizer, trainer
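
# Hedged usage sketch (not part of the original source): this variant of
# init_model also builds the optimizer and trainer, so a caller unpacks a
# (model, optimizer, trainer) triple. The FLAGS, logger, and data_manager
# arguments are hypothetical placeholders provided by the surrounding codebase.
def _example_build_model_and_trainer(FLAGS, logger, data_manager,
                                     initial_embeddings, vocab_size,
                                     num_classes):
    model, optimizer, trainer = init_model(
        FLAGS,
        logger,
        initial_embeddings,
        vocab_size,
        num_classes,
        data_manager)
    # A typical training loop would then feed batches from data_manager to the
    # trainer; that loop is outside the scope of this sketch.
    return model, optimizer, trainer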