from functools import reduce

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# NOTE: the project-internal helpers used below (Args, Linear, EncodeGRU,
# IntraAttention, bundle, ReduceTreeLSTM, ReduceTreeGRU, ReduceTensor,
# ModelTrainer, ModelTrainer_ES, and the spinn.* model modules) are assumed
# to be importable from the SPINN codebase's utility modules.


def default_args(**kwargs):
    args = {}

    # Required Args
    args['model_dim'] = 10
    args['word_embedding_dim'] = 12
    args['vocab_size'] = 14
    args['embedding_keep_rate'] = 1.0
    args['classifier_keep_rate'] = 1.0
    args['mlp_dim'] = 16
    args['num_mlp_layers'] = 2
    args['num_classes'] = 3
    args['use_sentence_pair'] = False

    initial_embeddings = np.arange(args['vocab_size']).repeat(
        args['word_embedding_dim']).reshape(
        args['vocab_size'], -1).astype(np.float32)
    args['initial_embeddings'] = initial_embeddings

    # Tracker Args
    args['tracking_lstm_hidden_dim'] = 4
    args['transition_weight'] = None

    # Layers
    context_args = Args()
    context_args.reshape_input = lambda x, batch_size, seq_length: x
    context_args.reshape_context = lambda x, batch_size, seq_length: x
    context_args.encoder = nn.Linear(
        args['word_embedding_dim'], args['model_dim'])
    context_args.input_dim = args['model_dim']
    args['context_args'] = context_args

    class Reduce(nn.Module):
        def forward(self, lefts, rights, tracking):
            batch_size = len(lefts)
            return torch.chunk(
                torch.cat(lefts, 0) - torch.cat(rights, 0), batch_size, 0)

    composition_args = Args()
    composition_args.lateral_tracking = True
    composition_args.use_tracking_in_composition = True
    composition_args.size = args['model_dim']
    composition_args.tracker_size = args['tracking_lstm_hidden_dim']
    composition_args.transition_weight = args['transition_weight']
    composition_args.wrap_items = lambda x: torch.cat(x, 0)
    composition_args.extract_h = lambda x: x
    composition_args.extract_c = None
    composition_args.composition = Reduce()
    composition_args.tracking_ln = False
    args['composition_args'] = composition_args

    # Apply any caller overrides on top of the defaults.
    for k in kwargs.keys():
        args[k] = kwargs[k]

    return args
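
def _example_default_args_usage():
    # Illustrative sketch only (not part of the original source): a test would
    # typically override just the fields under test via **kwargs. Every key
    # used here is one that default_args already populates above.
    args = default_args(model_dim=20, num_classes=2)
    assert args['model_dim'] == 20      # override applied
    assert args['vocab_size'] == 14     # untouched defaults preserved
    return args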
def init_model(
        FLAGS,
        logger,
        initial_embeddings,
        vocab_size,
        num_classes,
        data_manager,
        logfile_header=None):
    # Choose model.
    logger.Log("Building model.")
    if FLAGS.model_type == "CBOW":
        build_model = spinn.cbow.build_model
    elif FLAGS.model_type == "RNN":
        build_model = spinn.plain_rnn.build_model
    elif FLAGS.model_type == "SPINN":
        build_model = spinn.spinn_core_model.build_model
    elif FLAGS.model_type == "RLSPINN":
        build_model = spinn.rl_spinn.build_model
    elif FLAGS.model_type == "ChoiPyramid":
        build_model = spinn.choi_pyramid.build_model
    elif FLAGS.model_type == "Maillard":
        build_model = spinn.maillard_pyramid.build_model
    elif FLAGS.model_type == "LMS":
        build_model = spinn.lms.build_model
    else:
        raise NotImplementedError

    # Input Encoder.
    context_args = Args()

    if FLAGS.model_type == "LMS":
        intermediate_dim = FLAGS.model_dim * FLAGS.model_dim
    else:
        intermediate_dim = FLAGS.model_dim

    if FLAGS.encode == "projection":
        context_args.reshape_input = lambda x, batch_size, seq_length: x
        context_args.reshape_context = lambda x, batch_size, seq_length: x
        encoder = Linear()(FLAGS.word_embedding_dim, intermediate_dim)
        context_args.input_dim = intermediate_dim
    elif FLAGS.encode == "gru":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = intermediate_dim
        encoder = EncodeGRU(FLAGS.word_embedding_dim, intermediate_dim,
                            num_layers=FLAGS.encode_num_layers,
                            bidirectional=FLAGS.encode_bidirectional,
                            reverse=FLAGS.encode_reverse,
                            mix=(FLAGS.model_type != "CBOW"))
    elif FLAGS.encode == "attn":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = intermediate_dim
        encoder = IntraAttention(FLAGS.word_embedding_dim, intermediate_dim)
    elif FLAGS.encode == "pass":
        context_args.reshape_input = lambda x, batch_size, seq_length: x
        context_args.reshape_context = lambda x, batch_size, seq_length: x
        context_args.input_dim = FLAGS.word_embedding_dim

        def encoder(x):
            return x
    else:
        raise NotImplementedError

    context_args.encoder = encoder

    # Composition Function.
    composition_args = Args()
    composition_args.lateral_tracking = FLAGS.lateral_tracking
    composition_args.tracking_ln = FLAGS.tracking_ln
    composition_args.use_tracking_in_composition = FLAGS.use_tracking_in_composition
    composition_args.size = FLAGS.model_dim
    composition_args.tracker_size = FLAGS.tracking_lstm_hidden_dim
    composition_args.use_internal_parser = FLAGS.use_internal_parser
    composition_args.transition_weight = FLAGS.transition_weight
    composition_args.wrap_items = lambda x: torch.cat(x, 0)
    composition_args.extract_h = lambda x: x

    if FLAGS.reduce == "treelstm":
        assert FLAGS.model_dim % 2 == 0, 'model_dim must be an even number.'
        assert FLAGS.model_type != 'LMS', 'Must use reduce=lms for LMS.'
        if FLAGS.model_dim != FLAGS.word_embedding_dim:
            print('If you are setting different hidden layer and word '
                  'embedding sizes, make sure you specify an encoder')
        composition_args.wrap_items = lambda x: bundle(x)
        composition_args.extract_h = lambda x: x.h
        composition_args.extract_c = lambda x: x.c
        composition_args.size = FLAGS.model_dim // 2
        composition = ReduceTreeLSTM(
            FLAGS.model_dim // 2,
            tracker_size=FLAGS.tracking_lstm_hidden_dim,
            use_tracking_in_composition=FLAGS.use_tracking_in_composition,
            composition_ln=FLAGS.composition_ln)
    elif FLAGS.reduce == "tanh":
        class ReduceTanh(nn.Module):
            def forward(self, lefts, rights, tracking=None):
                batch_size = len(lefts)
                ret = torch.cat(lefts, 0) + F.tanh(torch.cat(rights, 0))
                return torch.chunk(ret, batch_size, 0)
        composition = ReduceTanh()
    elif FLAGS.reduce == "treegru":
        composition = ReduceTreeGRU(FLAGS.model_dim,
                                    FLAGS.tracking_lstm_hidden_dim,
                                    FLAGS.use_tracking_in_composition)
    elif FLAGS.reduce == "lms":
        composition_args.wrap_items = lambda x: bundle(x)
        composition_args.extract_h = lambda x: x.h
        composition_args.extract_c = lambda x: x.c
        composition_args.size = FLAGS.model_dim
        composition = ReduceTensor(FLAGS.model_dim)
    else:
        raise NotImplementedError

    composition_args.composition = composition

    model = build_model(data_manager, initial_embeddings, vocab_size,
                        num_classes, FLAGS, context_args, composition_args)

    # Debug
    def set_debug(self):
        self.debug = FLAGS.debug
    model.apply(set_debug)

    # Print model size.
    logger.Log("Architecture: {}".format(model))
    if logfile_header:
        logfile_header.model_architecture = str(model)
    total_params = sum([reduce(lambda x, y: x * y, w.size(), 1.0)
                        for w in model.parameters()])
    logger.Log("Total params: {}".format(total_params))
    if logfile_header:
        logfile_header.total_params = int(total_params)

    return model
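
def _example_init_model_usage(FLAGS, logger, initial_embeddings, vocab_size,
                              num_classes, data_manager):
    # Illustrative sketch only (not part of the original source): FLAGS is the
    # parsed flag namespace, logger exposes a .Log method, and the embeddings
    # and data_manager come from the surrounding training setup. This variant
    # of init_model returns only the model.
    model = init_model(FLAGS, logger, initial_embeddings, vocab_size,
                       num_classes, data_manager)
    logger.Log("Built {} model (encode={}, reduce={}).".format(
        FLAGS.model_type, FLAGS.encode, FLAGS.reduce))
    return model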
def init_model(
        FLAGS,
        logger,
        initial_embeddings,
        vocab_size,
        num_classes,
        data_manager,
        logfile_header=None):
    # Choose model.
    logger.Log("Building model.")
    if FLAGS.model_type == "CBOW":
        build_model = spinn.cbow.build_model
    elif FLAGS.model_type == "RNN":
        build_model = spinn.plain_rnn.build_model
    elif FLAGS.model_type == "SPINN":
        build_model = spinn.spinn_core_model.build_model
    elif FLAGS.model_type == "RLSPINN":
        build_model = spinn.rl_spinn.build_model
    elif FLAGS.model_type == "Pyramid":
        build_model = spinn.pyramid.build_model
    else:
        raise NotImplementedError

    # Input Encoder.
    context_args = Args()
    context_args.reshape_input = lambda x, batch_size, seq_length: x
    context_args.reshape_context = lambda x, batch_size, seq_length: x
    context_args.input_dim = FLAGS.word_embedding_dim

    if FLAGS.encode == "projection":
        encoder = Linear()(FLAGS.word_embedding_dim, FLAGS.model_dim)
    elif FLAGS.encode == "gru":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = FLAGS.model_dim
        encoder = EncodeGRU(FLAGS.word_embedding_dim, FLAGS.model_dim,
                            num_layers=FLAGS.encode_num_layers,
                            bidirectional=FLAGS.encode_bidirectional,
                            reverse=FLAGS.encode_reverse)
    elif FLAGS.encode == "attn":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = FLAGS.model_dim
        encoder = IntraAttention(FLAGS.word_embedding_dim, FLAGS.model_dim)
    elif FLAGS.encode == "pass":
        def encoder(x):
            return x
    else:
        raise NotImplementedError

    context_args.encoder = encoder

    # Composition Function.
    composition_args = Args()
    composition_args.lateral_tracking = FLAGS.lateral_tracking
    composition_args.tracking_ln = FLAGS.tracking_ln
    composition_args.use_tracking_in_composition = FLAGS.use_tracking_in_composition
    composition_args.size = FLAGS.model_dim
    composition_args.tracker_size = FLAGS.tracking_lstm_hidden_dim
    composition_args.use_internal_parser = FLAGS.use_internal_parser
    composition_args.transition_weight = FLAGS.transition_weight
    composition_args.wrap_items = lambda x: torch.cat(x, 0)
    composition_args.extract_h = lambda x: x
    composition_args.extract_c = None
    composition_args.detach = FLAGS.transition_detach
    composition_args.evolution = FLAGS.evolution

    if FLAGS.reduce == "treelstm":
        assert FLAGS.model_dim % 2 == 0, 'model_dim must be an even number.'
        if FLAGS.model_dim != FLAGS.word_embedding_dim:
            print('If you are setting different hidden layer and word '
                  'embedding sizes, make sure you specify an encoder')
        composition_args.wrap_items = lambda x: bundle(x)
        composition_args.extract_h = lambda x: x.h
        composition_args.extract_c = lambda x: x.c
        composition_args.size = FLAGS.model_dim // 2
        composition = ReduceTreeLSTM(
            FLAGS.model_dim // 2,
            tracker_size=FLAGS.tracking_lstm_hidden_dim,
            use_tracking_in_composition=FLAGS.use_tracking_in_composition,
            composition_ln=FLAGS.composition_ln)
    elif FLAGS.reduce == "tanh":
        class ReduceTanh(nn.Module):
            def forward(self, lefts, rights, tracking=None):
                batch_size = len(lefts)
                ret = torch.cat(lefts, 0) + F.tanh(torch.cat(rights, 0))
                return torch.chunk(ret, batch_size, 0)
        composition = ReduceTanh()
    elif FLAGS.reduce == "treegru":
        composition = ReduceTreeGRU(FLAGS.model_dim,
                                    FLAGS.tracking_lstm_hidden_dim,
                                    FLAGS.use_tracking_in_composition)
    else:
        raise NotImplementedError

    composition_args.composition = composition

    model = build_model(data_manager, initial_embeddings, vocab_size,
                        num_classes, FLAGS, context_args, composition_args)

    # Build optimizer.
    if FLAGS.optimizer_type == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=FLAGS.learning_rate,
                               betas=(0.9, 0.999), eps=1e-08)
    elif FLAGS.optimizer_type == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=FLAGS.learning_rate,
                                  eps=1e-08)
    else:
        raise NotImplementedError

    # Build trainer.
    if FLAGS.evolution:
        trainer = ModelTrainer_ES(model, optimizer)
    else:
        trainer = ModelTrainer(model, optimizer)

    # Print model size.
    logger.Log("Architecture: {}".format(model))
    if logfile_header:
        logfile_header.model_architecture = str(model)
    total_params = sum([reduce(lambda x, y: x * y, w.size(), 1.0)
                        for w in model.parameters()])
    logger.Log("Total params: {}".format(total_params))
    if logfile_header:
        logfile_header.total_params = int(total_params)

    return model, optimizer, trainer
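
def _example_init_model_with_trainer_usage(FLAGS, logger, initial_embeddings,
                                           vocab_size, num_classes,
                                           data_manager):
    # Illustrative sketch only (not part of the original source): this second
    # variant of init_model also builds the optimizer (FLAGS.optimizer_type)
    # and a ModelTrainer / ModelTrainer_ES, so the caller unpacks three values.
    model, optimizer, trainer = init_model(
        FLAGS, logger, initial_embeddings, vocab_size, num_classes,
        data_manager)
    return model, optimizer, trainer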