Example #1
    def __init__(self,
                 word_dim,
                 hidden_dim,
                 low_dim,
                 multiproc,
                 composition_ln=False,
                 trainable_temperature=False,
                 parent_selection=False,
                 use_sentence_pair=False):
        super(ChartParser, self).__init__()
        self.word_dim = word_dim
        self.hidden_dim = hidden_dim
        self.low_dim = low_dim
        self.multiproc = multiproc
        self.use_sentence_pair = use_sentence_pair

        self.treelstm_layer = BinaryTreeLSTMLayer(
            low_dim,
            composition_ln=composition_ln)  #CAT: low_dim from hidden_dim
        self.parent_selection = parent_selection

        self.cat = Catalan()
        self.reduce_dim = Linear()(in_features=hidden_dim,
                                   out_features=low_dim)

        # TODO: Add something to blocks to make this use case more elegant.
        self.comp_query = Linear()(in_features=low_dim, out_features=1)

        self.trainable_temperature = trainable_temperature
        if self.trainable_temperature:
            self.temperature_param = nn.Parameter(torch.ones(1, 1),
                                                  requires_grad=True)
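
A minimal construction sketch for the module above (hypothetical dimension values; assumes the module defining ChartParser and its dependencies is importable):

# Hypothetical usage; the sizes below are placeholders, not values from the source.
parser = ChartParser(word_dim=300,
                     hidden_dim=300,
                     low_dim=64,
                     multiproc=False,
                     composition_ln=True,
                     use_sentence_pair=True)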
Example #2
 def __init__(self, size, out_dim, relu_size, tracker_size):
     # Initialize layers.
     super(RLAction2, self).__init__()
     self.relu_size = 200  # note: overrides the relu_size argument with a fixed width
     self.tracker_l = Linear()(tracker_size, out_dim, bias=False)
     self.ll_after = Linear()(out_dim, self.relu_size, bias=True)
     self.post_relu = Linear()(self.relu_size, 2, bias=True)
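
A self-contained shape sketch of the same three-layer stack using plain torch.nn.Linear as a stand-in for the project's Linear factory. The ReLU between ll_after and post_relu is an assumption suggested by the layer names, and the sizes are illustrative:

import torch
import torch.nn as nn
import torch.nn.functional as F

tracker_size, out_dim, relu_size = 64, 128, 200  # illustrative sizes
tracker_l = nn.Linear(tracker_size, out_dim, bias=False)
ll_after = nn.Linear(out_dim, relu_size, bias=True)
post_relu = nn.Linear(relu_size, 2, bias=True)

h = torch.randn(8, tracker_size)                     # a batch of tracker states
logits = post_relu(F.relu(ll_after(tracker_l(h))))   # shape (8, 2)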
Example #3
    def __init__(self,
                 size,
                 tracker_size,
                 lateral_tracking=True,
                 tracking_ln=True):
        '''Args:
            size: input size (parser hidden state) = FLAGS.model_dim
            tracker_size: FLAGS.tracking_lstm_hidden_dim
            (see FLAGS for the rest)'''
        super(Tracker, self).__init__()

        # Initialize layers.
        if lateral_tracking:
            self.buf = Linear()(size, 4 * tracker_size, bias=True)
            self.stack1 = Linear()(size, 4 * tracker_size, bias=False)
            self.stack2 = Linear()(size, 4 * tracker_size, bias=False)
            self.lateral = Linear(initializer=HeKaimingInitializer)(
                tracker_size, 4 * tracker_size, bias=False)
            self.state_size = tracker_size
        else:
            self.state_size = size * 3

        if tracking_ln:
            self.buf_ln = LayerNormalization(size)
            self.stack1_ln = LayerNormalization(size)
            self.stack2_ln = LayerNormalization(size)

        self.lateral_tracking = lateral_tracking
        self.tracking_ln = tracking_ln

        self.reset_state()
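
A usage sketch (assuming the module defining Tracker is importable; the sizes are illustrative) showing how the two branches above change the exposed state size:

# With lateral tracking the tracker exposes its own LSTM-sized state;
# without it, the state is the concatenation of the three tracking inputs.
t1 = Tracker(size=128, tracker_size=64, lateral_tracking=True)
assert t1.state_size == 64
t2 = Tracker(size=128, tracker_size=64, lateral_tracking=False)
assert t2.state_size == 128 * 3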
Example #4
    def __init__(self,
                 word_dim,
                 hidden_dim,
                 intra_attention,
                 composition_ln=False,
                 trainable_temperature=False,
                 right_branching=False,
                 debug_branching=False,
                 uniform_branching=False,
                 random_branching=False,
                 st_gumbel=False):
        super(BinaryTreeLSTM, self).__init__()
        self.word_dim = word_dim
        self.hidden_dim = hidden_dim
        self.intra_attention = intra_attention
        self.treelstm_layer = BinaryTreeLSTMLayer(
            hidden_dim, composition_ln=composition_ln)
        self.right_branching = right_branching
        self.debug_branching = debug_branching
        self.uniform_branching = uniform_branching
        self.random_branching = random_branching
        self.st_gumbel = st_gumbel

        # TODO: Add something to blocks to make this use case more elegant.
        self.comp_query = Linear()(in_features=hidden_dim, out_features=1)

        self.trainable_temperature = trainable_temperature
        if self.trainable_temperature:
            self.temperature_param = nn.Parameter(torch.ones(1, 1),
                                                  requires_grad=True)
Example #5
    def __init__(self, args, vocab, predict_use_cell):
        super(SPINN, self).__init__()

        # Optional debug mode.
        self.debug = False

        self.transition_weight = args.transition_weight

        self.wrap_items = args.wrap_items
        self.extract_h = args.extract_h

        # Reduce function for semantic composition.
        self.reduce = args.composition
        if args.tracker_size is not None or args.use_internal_parser:
            self.tracker = Tracker(args.size,
                                   args.tracker_size,
                                   lateral_tracking=args.lateral_tracking,
                                   tracking_ln=args.tracking_ln)
            if args.transition_weight is not None:
                # TODO: Might be interesting to try a different network here.
                self.predict_use_cell = predict_use_cell
                if self.tracker.lateral_tracking:
                    tinp_size = (self.tracker.state_size * 2
                                 if predict_use_cell
                                 else self.tracker.state_size)
                else:
                    tinp_size = self.tracker.state_size
                self.transition_net = Linear()(tinp_size, 2)

        self.choices = np.array([T_SHIFT, T_REDUCE], dtype=np.int32)

        self.shift_probabilities = ShiftProbabilities()
Example #6
    def __init__(self, size, tracker_size, lateral_tracking=True):
        super(Tracker, self).__init__()

        # Initialize layers.
        self.buf = Linear()(size, 4 * tracker_size, bias=False)
        self.stack1 = Linear()(size, 4 * tracker_size, bias=False)
        self.stack2 = Linear()(size, 4 * tracker_size, bias=False)

        if lateral_tracking:
            self.lateral = Linear(initializer=HeKaimingInitializer)(
                tracker_size, 4 * tracker_size)
        else:
            self.transform = Linear(initializer=HeKaimingInitializer)(
                4 * tracker_size, tracker_size)

        self.lateral_tracking = lateral_tracking
        self.state_size = tracker_size

        self.reset_state()
Example #7
 def __init__(self, hidden_dim, composition_ln=False):
     super(BinaryTreeLSTMLayer, self).__init__()
     self.hidden_dim = hidden_dim
     self.comp_linear = Linear(initializer=HeKaimingInitializer)(
         in_features=2 * hidden_dim, out_features=5 * hidden_dim)
     self.composition_ln = composition_ln
     if composition_ln:
         self.left_h_ln = LayerNormalization(hidden_dim)
         self.right_h_ln = LayerNormalization(hidden_dim)
         self.left_c_ln = LayerNormalization(hidden_dim)
         self.right_c_ln = LayerNormalization(hidden_dim)
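
The 5 * hidden_dim output of comp_linear is consistent with the five gates of a binary TreeLSTM (input, left forget, right forget, output, candidate). A self-contained sketch of that composition in plain PyTorch, offered as an assumption about how the layer's forward pass likely uses the projection, not as code from the source:

import torch
import torch.nn as nn

hidden_dim = 8
comp_linear = nn.Linear(2 * hidden_dim, 5 * hidden_dim)

def compose(lh, lc, rh, rc):
    # Project the concatenated child hidden states and split into gates.
    gates = comp_linear(torch.cat([lh, rh], dim=1))
    i, fl, fr, o, g = gates.chunk(5, dim=1)
    c = (torch.sigmoid(fl) * lc + torch.sigmoid(fr) * rc
         + torch.sigmoid(i) * torch.tanh(g))
    h = torch.sigmoid(o) * torch.tanh(c)
    return h, c

lh, lc, rh, rc = (torch.randn(4, hidden_dim) for _ in range(4))
h, c = compose(lh, lc, rh, rc)   # each of shape (4, hidden_dim)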
Example #8
    def __init__(self,
                 model_dim=None,
                 word_embedding_dim=None,
                 vocab_size=None,
                 initial_embeddings=None,
                 num_classes=None,
                 embedding_keep_rate=None,
                 use_sentence_pair=False,
                 classifier_keep_rate=None,
                 mlp_dim=None,
                 num_mlp_layers=None,
                 mlp_ln=None,
                 context_args=None,
                 gated=None,
                 selection_keep_rate=None,
                 pyramid_selection_keep_rate=None,
                 **kwargs):
        super(Pyramid, self).__init__()

        self.use_sentence_pair = use_sentence_pair
        self.model_dim = model_dim
        self.gated = gated
        self.selection_keep_rate = selection_keep_rate

        classifier_dropout_rate = 1. - classifier_keep_rate

        args = Args()
        args.size = model_dim
        args.input_dropout_rate = 1. - embedding_keep_rate

        vocab = Vocab()
        vocab.size = (initial_embeddings.shape[0]
                      if initial_embeddings is not None else vocab_size)
        vocab.vectors = initial_embeddings

        self.embed = Embed(word_embedding_dim,
                           vocab.size,
                           vectors=vocab.vectors)

        self.composition_fn = SimpleTreeLSTM(model_dim // 2,
                                             composition_ln=False)
        self.selection_fn = Linear(initializer=HeKaimingInitializer)(model_dim,
                                                                     1)

        # TODO: Set up layer norm.

        mlp_input_dim = model_dim * 2 if use_sentence_pair else model_dim

        self.mlp = MLP(mlp_input_dim, mlp_dim, num_classes, num_mlp_layers,
                       mlp_ln, classifier_dropout_rate)

        self.encode = context_args.encoder
        self.reshape_input = context_args.reshape_input
        self.reshape_context = context_args.reshape_context
Example #9
    def __init__(self, word_dim, hidden_dim, intra_attention,
                 composition_ln=False, trainable_temperature=False):
        super(BinaryTreeLSTM, self).__init__()
        self.word_dim = word_dim
        self.hidden_dim = hidden_dim
        self.intra_attention = intra_attention
        self.treelstm_layer = BinaryTreeLSTMLayer(
            hidden_dim, composition_ln=composition_ln)

        # TODO: Add something to blocks to make this use case more elegant.
        self.comp_query = Linear(
            initializer=HeKaimingInitializer)(
            in_features=hidden_dim,
            out_features=1)
        self.trainable_temperature = trainable_temperature
        if self.trainable_temperature:
            self.temperature_param = nn.Parameter(
                torch.ones(1, 1), requires_grad=True)
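
comp_query maps a hidden_dim vector to a single score. A hedged sketch of how such scores are typically turned into a distribution over candidate compositions (plain nn.Linear as a stand-in for the project's Linear factory; an illustration, not the model's forward pass):

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_dim = 8
comp_query = nn.Linear(hidden_dim, 1)

candidates = torch.randn(4, 6, hidden_dim)    # 4 sentences, 6 candidate parents each
scores = comp_query(candidates).squeeze(-1)   # (4, 6)
probs = F.softmax(scores, dim=-1)             # selection distribution per sentence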
Example #10
    def __init__(self, model_dim=None,
                 word_embedding_dim=None,
                 vocab_size=None,
                 initial_embeddings=None,
                 fine_tune_loaded_embeddings=None,
                 num_classes=None,
                 embedding_keep_rate=None,
                 tracking_lstm_hidden_dim=4,
                 transition_weight=None,
                 encode_reverse=None,
                 encode_bidirectional=None,
                 encode_num_layers=None,
                 lateral_tracking=None,
                 tracking_ln=None,
                 use_tracking_in_composition=None,
                 predict_use_cell=None,
                 use_sentence_pair=False,
                 use_difference_feature=False,
                 use_product_feature=False,
                 mlp_dim=None,
                 num_mlp_layers=None,
                 mlp_ln=None,
                 classifier_keep_rate=None,
                 context_args=None,
                 composition_args=None,
                 with_attention=False,
                 data_type=None,
                 target_vocabulary=None,
                 onmt_module=None,
                 **kwargs
                 ):
        super(BaseModel, self).__init__()

        assert not (
            use_tracking_in_composition and not lateral_tracking
        ), "Lateral tracking must be on to use tracking in composition."

        self.use_sentence_pair = use_sentence_pair
        self.use_difference_feature = use_difference_feature
        self.use_product_feature = use_product_feature

        self.hidden_dim = composition_args.size
        self.wrap_items = composition_args.wrap_items
        self.extract_h = composition_args.extract_h

        if data_type == "mt":
            self.post_projection = Linear()(context_args.input_dim,
                                            int(context_args.input_dim / 2),
                                            bias=True)
        self.initial_embeddings = initial_embeddings
        self.word_embedding_dim = word_embedding_dim
        self.model_dim = model_dim
        self.data_type = data_type

        classifier_dropout_rate = 1. - classifier_keep_rate

        vocab = Vocab()
        vocab.size = initial_embeddings.shape[0] if initial_embeddings is not None else vocab_size
        vocab.vectors = initial_embeddings

        # Build parsing component.
        self.spinn = self.build_spinn(
            composition_args, vocab, predict_use_cell)

        # Build classifier.
        features_dim = self.get_features_dim()
        if data_type != "mt":
            self.mlp = MLP(features_dim, mlp_dim, num_classes,
                           num_mlp_layers, mlp_ln, classifier_dropout_rate)
            # self.generator = nn.Sequential(nn.Linear(self.model_dim, len(self.target_vocabulary)), nn.LogSoftmax())

        self.embedding_dropout_rate = 1. - embedding_keep_rate

        # Create dynamic embedding layer.
        self.embed = Embed(
            word_embedding_dim,
            vocab.size,
            vectors=vocab.vectors,
            fine_tune=fine_tune_loaded_embeddings)

        self.input_dim = context_args.input_dim

        self.encode = context_args.encoder
        self.reshape_input = context_args.reshape_input
        self.reshape_context = context_args.reshape_context

        self.inverted_vocabulary = None
Example #11
def init_model(
        FLAGS,
        logger,
        initial_embeddings,
        vocab_size,
        num_classes,
        data_manager,
        logfile_header=None):
    # Choose model.
    logger.Log("Building model.")
    if FLAGS.model_type == "CBOW":
        build_model = spinn.cbow.build_model
    elif FLAGS.model_type == "RNN":
        build_model = spinn.plain_rnn.build_model
    elif FLAGS.model_type == "SPINN":
        build_model = spinn.spinn_core_model.build_model
    elif FLAGS.model_type == "RLSPINN":
        build_model = spinn.rl_spinn.build_model
    elif FLAGS.model_type == "ChoiPyramid":
        build_model = spinn.choi_pyramid.build_model
    elif FLAGS.model_type == "Maillard":
        build_model = spinn.maillard_pyramid.build_model
    elif FLAGS.model_type == "LMS":
        build_model = spinn.lms.build_model
    else:
        raise NotImplementedError

    # Input Encoder.
    context_args = Args()
    if FLAGS.model_type == "LMS":
        intermediate_dim = FLAGS.model_dim * FLAGS.model_dim
    else:
        intermediate_dim = FLAGS.model_dim

    if FLAGS.encode == "projection":
        context_args.reshape_input = lambda x, batch_size, seq_length: x
        context_args.reshape_context = lambda x, batch_size, seq_length: x
        encoder = Linear()(FLAGS.word_embedding_dim, intermediate_dim)
        context_args.input_dim = intermediate_dim
    elif FLAGS.encode == "gru":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = intermediate_dim
        encoder = EncodeGRU(FLAGS.word_embedding_dim, intermediate_dim,
                            num_layers=FLAGS.encode_num_layers,
                            bidirectional=FLAGS.encode_bidirectional,
                            reverse=FLAGS.encode_reverse,
                            mix=(FLAGS.model_type != "CBOW"))
    elif FLAGS.encode == "attn":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = intermediate_dim
        encoder = IntraAttention(FLAGS.word_embedding_dim, intermediate_dim)
    elif FLAGS.encode == "pass":
        context_args.reshape_input = lambda x, batch_size, seq_length: x
        context_args.reshape_context = lambda x, batch_size, seq_length: x
        context_args.input_dim = FLAGS.word_embedding_dim
        def encoder(x): return x
    else:
        raise NotImplementedError

    context_args.encoder = encoder

    # Composition Function.
    composition_args = Args()
    composition_args.lateral_tracking = FLAGS.lateral_tracking
    composition_args.tracking_ln = FLAGS.tracking_ln
    composition_args.use_tracking_in_composition = FLAGS.use_tracking_in_composition
    composition_args.size = FLAGS.model_dim
    composition_args.tracker_size = FLAGS.tracking_lstm_hidden_dim
    composition_args.use_internal_parser = FLAGS.use_internal_parser
    composition_args.transition_weight = FLAGS.transition_weight
    composition_args.wrap_items = lambda x: torch.cat(x, 0)
    composition_args.extract_h = lambda x: x

    if FLAGS.reduce == "treelstm":
        assert FLAGS.model_dim % 2 == 0, 'model_dim must be an even number.'
        assert FLAGS.model_type != 'LMS', 'Must use reduce=lms for LMS.'
        if FLAGS.model_dim != FLAGS.word_embedding_dim:
            print('If you are setting different hidden layer and word '
                  'embedding sizes, make sure you specify an encoder')
        composition_args.wrap_items = lambda x: bundle(x)
        composition_args.extract_h = lambda x: x.h
        composition_args.extract_c = lambda x: x.c
        composition_args.size = FLAGS.model_dim // 2
        composition = ReduceTreeLSTM(
            FLAGS.model_dim // 2,
            tracker_size=FLAGS.tracking_lstm_hidden_dim,
            use_tracking_in_composition=FLAGS.use_tracking_in_composition,
            composition_ln=FLAGS.composition_ln)
    elif FLAGS.reduce == "tanh":
        class ReduceTanh(nn.Module):
            def forward(self, lefts, rights, tracking=None):
                batch_size = len(lefts)
                ret = torch.cat(lefts, 0) + F.tanh(torch.cat(rights, 0))
                return torch.chunk(ret, batch_size, 0)
        composition = ReduceTanh()
    elif FLAGS.reduce == "treegru":
        composition = ReduceTreeGRU(FLAGS.model_dim,
                                    FLAGS.tracking_lstm_hidden_dim,
                                    FLAGS.use_tracking_in_composition)
    elif FLAGS.reduce == "lms":
        composition_args.wrap_items = lambda x: bundle(x)
        composition_args.extract_h = lambda x: x.h
        composition_args.extract_c = lambda x: x.c
        composition_args.size = FLAGS.model_dim
        composition = ReduceTensor(FLAGS.model_dim)
    else:
        raise NotImplementedError

    composition_args.composition = composition

    model = build_model(data_manager, initial_embeddings, vocab_size,
                        num_classes, FLAGS, context_args, composition_args)

    # Debug
    def set_debug(self):
        self.debug = FLAGS.debug
    model.apply(set_debug)

    # Print model size.
    logger.Log("Architecture: {}".format(model))
    if logfile_header:
        logfile_header.model_architecture = str(model)
    total_params = sum([reduce(lambda x, y: x * y, w.size(), 1.0)
                        for w in model.parameters()])
    logger.Log("Total params: {}".format(total_params))
    if logfile_header:
        logfile_header.total_params = int(total_params)

    return model
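
The parameter count above relies on reduce, which in Python 3 must be imported from functools. An equivalent count that avoids it, offered only as a suggestion:

total_params = sum(p.numel() for p in model.parameters())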
Example #12
 def __init__(self):
     super(MyModel, self).__init__()
     self.l = Linear(SimpleInitializer, SimpleBiasInitializer)(10, 10)
Example #13
 def __init__(self, hidden_dim):
     super(Switch, self).__init__()
     self.fc1 = Linear()(in_features=3 * hidden_dim,
                         out_features=hidden_dim,
                         bias=False)
     self.fc2 = Linear()(in_features=hidden_dim, out_features=1, bias=False)
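
A shape sketch of the two layers above using plain nn.Linear; the tanh between them is a placeholder assumption, since the forward pass is not shown in the source:

import torch
import torch.nn as nn

hidden_dim = 16
fc1 = nn.Linear(3 * hidden_dim, hidden_dim, bias=False)
fc2 = nn.Linear(hidden_dim, 1, bias=False)

x = torch.randn(4, 3 * hidden_dim)   # e.g. three concatenated hidden states
score = fc2(torch.tanh(fc1(x)))      # (4, 1); the tanh is an assumed nonlinearity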
Example #14
 def __init__(self, hidden_dim):
     super(NonCompLayer, self).__init__()
     self.fc = Linear()(2 * hidden_dim, hidden_dim)
Example #15
def init_model(
        FLAGS,
        logger,
        initial_embeddings,
        vocab_size,
        num_classes,
        data_manager,
        logfile_header=None):
    # Choose model.
    logger.Log("Building model.")
    if FLAGS.model_type == "CBOW":
        build_model = spinn.cbow.build_model
    elif FLAGS.model_type == "RNN":
        build_model = spinn.plain_rnn.build_model
    elif FLAGS.model_type == "SPINN":
        build_model = spinn.spinn_core_model.build_model
    elif FLAGS.model_type == "RLSPINN":
        build_model = spinn.rl_spinn.build_model
    elif FLAGS.model_type == "Pyramid":
        build_model = spinn.pyramid.build_model
    else:
        raise NotImplementedError

    # Input Encoder.
    context_args = Args()
    context_args.reshape_input = lambda x, batch_size, seq_length: x
    context_args.reshape_context = lambda x, batch_size, seq_length: x
    context_args.input_dim = FLAGS.word_embedding_dim

    if FLAGS.encode == "projection":
        encoder = Linear()(FLAGS.word_embedding_dim, FLAGS.model_dim)
    elif FLAGS.encode == "gru":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = FLAGS.model_dim
        encoder = EncodeGRU(FLAGS.word_embedding_dim, FLAGS.model_dim,
                            num_layers=FLAGS.encode_num_layers,
                            bidirectional=FLAGS.encode_bidirectional,
                            reverse=FLAGS.encode_reverse)
    elif FLAGS.encode == "attn":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = FLAGS.model_dim
        encoder = IntraAttention(FLAGS.word_embedding_dim, FLAGS.model_dim)
    elif FLAGS.encode == "pass":
        def encoder(x): return x
    else:
        raise NotImplementedError

    context_args.encoder = encoder

    # Composition Function.
    composition_args = Args()
    composition_args.lateral_tracking = FLAGS.lateral_tracking
    composition_args.tracking_ln = FLAGS.tracking_ln
    composition_args.use_tracking_in_composition = FLAGS.use_tracking_in_composition
    composition_args.size = FLAGS.model_dim
    composition_args.tracker_size = FLAGS.tracking_lstm_hidden_dim
    composition_args.use_internal_parser = FLAGS.use_internal_parser
    composition_args.transition_weight = FLAGS.transition_weight
    composition_args.wrap_items = lambda x: torch.cat(x, 0)
    composition_args.extract_h = lambda x: x
    composition_args.extract_c = None

    composition_args.detach = FLAGS.transition_detach
    composition_args.evolution = FLAGS.evolution

    if FLAGS.reduce == "treelstm":
        assert FLAGS.model_dim % 2 == 0, 'model_dim must be an even number.'
        if FLAGS.model_dim != FLAGS.word_embedding_dim:
            print('If you are setting different hidden layer and word '
                  'embedding sizes, make sure you specify an encoder')
        composition_args.wrap_items = lambda x: bundle(x)
        composition_args.extract_h = lambda x: x.h
        composition_args.extract_c = lambda x: x.c
        composition_args.size = FLAGS.model_dim // 2
        composition = ReduceTreeLSTM(FLAGS.model_dim // 2,
                                     tracker_size=FLAGS.tracking_lstm_hidden_dim,
                                     use_tracking_in_composition=FLAGS.use_tracking_in_composition,
                                     composition_ln=FLAGS.composition_ln)
    elif FLAGS.reduce == "tanh":
        class ReduceTanh(nn.Module):
            def forward(self, lefts, rights, tracking=None):
                batch_size = len(lefts)
                ret = torch.cat(lefts, 0) + F.tanh(torch.cat(rights, 0))
                return torch.chunk(ret, batch_size, 0)
        composition = ReduceTanh()
    elif FLAGS.reduce == "treegru":
        composition = ReduceTreeGRU(FLAGS.model_dim,
                                    FLAGS.tracking_lstm_hidden_dim,
                                    FLAGS.use_tracking_in_composition)
    else:
        raise NotImplementedError

    composition_args.composition = composition

    model = build_model(data_manager, initial_embeddings, vocab_size,
                        num_classes, FLAGS, context_args, composition_args)

    # Build optimizer.
    if FLAGS.optimizer_type == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=FLAGS.learning_rate,
                               betas=(0.9, 0.999), eps=1e-08)
    elif FLAGS.optimizer_type == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=FLAGS.learning_rate, eps=1e-08)
    else:
        raise NotImplementedError

    # Build trainer.
    if FLAGS.evolution:
        trainer = ModelTrainer_ES(model, optimizer)
    else:
        trainer = ModelTrainer(model, optimizer)

    # Print model size.
    logger.Log("Architecture: {}".format(model))
    if logfile_header:
        logfile_header.model_architecture = str(model)
    total_params = sum([reduce(lambda x, y: x * y, w.size(), 1.0) for w in model.parameters()])
    logger.Log("Total params: {}".format(total_params))
    if logfile_header:
        logfile_header.total_params = int(total_params)

    return model, optimizer, trainer
Example #16
    def __init__(self,
                 model_dim=None,
                 model_type=None,
                 word_embedding_dim=None,
                 vocab_size=None,
                 initial_embeddings=None,
                 fine_tune_loaded_embeddings=None,
                 num_classes=None,
                 embedding_keep_rate=None,
                 tracking_lstm_hidden_dim=4,
                 transition_weight=None,
                 encode_reverse=None,
                 encode_bidirectional=None,
                 encode_num_layers=None,
                 lateral_tracking=None,
                 tracking_ln=None,
                 use_tracking_in_composition=None,
                 predict_use_cell=None,
                 use_sentence_pair=False,
                 use_difference_feature=False,
                 use_product_feature=False,
                 mlp_dim=None,
                 num_mlp_layers=None,
                 mlp_ln=None,
                 classifier_keep_rate=None,
                 context_args=None,
                 composition_args=None,
                 with_attention=False,
                 data_type=None,
                 target_vocabulary=None,
                 onmt_module=None,
                 FLAGS=None,
                 data_manager=None,
                 **kwargs):
        super(NMTModel, self).__init__()

        assert not (
            use_tracking_in_composition and not lateral_tracking
        ), "Lateral tracking must be on to use tracking in composition."

        self.kwargs = kwargs

        self.model_dim = model_dim
        self.model_type = model_type
        self.data_type = data_type
        self.target_vocabulary = target_vocabulary

        if self.model_type == "SPINN":
            encoder_builder = spinn_builder
        elif self.model_type == "RLSPINN":
            encoder_builder = rl_builder
        elif self.model_type == "LMS":
            encoder_builder = lms_builder
        elif self.model_type == "RNN":
            encoder_builder = rnn_builder

        if self.model_type == "SPINN" or "RNN" or "LMS":
            self.encoder = encoder_builder(
                model_dim=model_dim,
                word_embedding_dim=word_embedding_dim,
                vocab_size=vocab_size,
                initial_embeddings=initial_embeddings,
                fine_tune_loaded_embeddings=fine_tune_loaded_embeddings,
                num_classes=num_classes,
                embedding_keep_rate=embedding_keep_rate,
                tracking_lstm_hidden_dim=tracking_lstm_hidden_dim,
                transition_weight=transition_weight,
                use_sentence_pair=use_sentence_pair,
                lateral_tracking=lateral_tracking,
                tracking_ln=tracking_ln,
                use_tracking_in_composition=use_tracking_in_composition,
                predict_use_cell=predict_use_cell,
                use_difference_feature=use_difference_feature,
                use_product_feature=use_product_feature,
                classifier_keep_rate=classifier_keep_rate,
                mlp_dim=mlp_dim,
                num_mlp_layers=num_mlp_layers,
                mlp_ln=mlp_ln,
                context_args=context_args,
                composition_args=composition_args,
                with_attention=with_attention,
                data_type=data_type,
                onmt_module=onmt_module,
                FLAGS=FLAGS,
                data_manager=data_manager)
        else:
            self.encoder = rl_builder(data_manager=data_manager,
                                      initial_embeddings=initial_embeddings,
                                      vocab_size=vocab_size,
                                      num_classes=num_classes,
                                      FLAGS=FLAGS,
                                      context_args=context_args,
                                      composition_args=composition_args)
        if self.model_type == "LMS":
            self.model_dim **= 2
        # TODO: Move these imports to the top of the file; the onmt_module
        # path must be added to sys.path before they can be imported.
        sys.path.append(onmt_module)
        from onmt.decoders.decoder import InputFeedRNNDecoder, StdRNNDecoder, RNNDecoderBase
        from onmt.encoders.rnn_encoder import RNNEncoder
        from onmt.modules import Embeddings

        self.output_embeddings = Embeddings(self.model_dim,
                                            len(target_vocabulary) + 1, 0)

        # Below, model_dim is multiplied by 2 so that the output dimension is the same as the
        # input word embedding dimension, and not half.
        # Look at TreeRNN for details (there is a down projection).
        if self.model_type == "RNN":
            self.is_bidirectional = True
            self.down_project = Linear()(2 * self.model_dim,
                                         self.model_dim,
                                         bias=True)
            self.down_project_context = Linear()(2 * self.model_dim,
                                                 self.model_dim,
                                                 bias=True)
        else:
            if self.model_type == "LMS":
                self.spinn = self.encoder.lms
            else:
                self.spinn = self.encoder.spinn
            self.is_bidirectional = False

        self.decoder = StdRNNDecoder("LSTM",
                                     self.is_bidirectional,
                                     1,
                                     self.model_dim,
                                     embeddings=self.output_embeddings)
        self.generator = nn.Sequential(
            nn.Linear(self.model_dim,
                      len(self.target_vocabulary) + 1), nn.LogSoftmax())
Example #17
    def __init__(self,
                 model_dim=None,
                 word_embedding_dim=None,
                 vocab_size=None,
                 use_product_feature=None,
                 use_difference_feature=None,
                 initial_embeddings=None,
                 num_classes=None,
                 embedding_keep_rate=None,
                 use_sentence_pair=False,
                 classifier_keep_rate=None,
                 mlp_dim=None,
                 num_mlp_layers=None,
                 mlp_ln=None,
                 composition_ln=None,
                 context_args=None,
                 trainable_temperature=None,
                 test_temperature_multiplier=None,
                 selection_dim=None,
                 gumbel=None,
                 **kwargs):
        super(Pyramid, self).__init__()

        self.use_sentence_pair = use_sentence_pair
        self.use_difference_feature = use_difference_feature
        self.use_product_feature = use_product_feature
        self.model_dim = model_dim
        self.test_temperature_multiplier = test_temperature_multiplier
        self.trainable_temperature = trainable_temperature
        self.gumbel = gumbel
        self.selection_dim = selection_dim

        self.classifier_dropout_rate = 1. - classifier_keep_rate
        self.embedding_dropout_rate = 1. - embedding_keep_rate

        vocab = Vocab()
        vocab.size = (initial_embeddings.shape[0]
                      if initial_embeddings is not None else vocab_size)
        vocab.vectors = initial_embeddings

        self.embed = Embed(word_embedding_dim,
                           vocab.size,
                           vectors=vocab.vectors)

        self.composition_fn = SimpleTreeLSTM(model_dim // 2,
                                             composition_ln=composition_ln)
        self.selection_fn_1 = Linear(initializer=HeKaimingInitializer)(
            model_dim, selection_dim)
        self.selection_fn_2 = Linear(initializer=HeKaimingInitializer)(
            selection_dim, 1)

        def selection_fn(selection_input):
            selection_hidden = F.tanh(self.selection_fn_1(selection_input))
            return self.selection_fn_2(selection_hidden)

        self.selection_fn = selection_fn

        mlp_input_dim = self.get_features_dim()

        self.mlp = MLP(mlp_input_dim, mlp_dim, num_classes, num_mlp_layers,
                       mlp_ln, self.classifier_dropout_rate)

        if self.trainable_temperature:
            self.temperature = nn.Parameter(torch.ones(1, 1),
                                            requires_grad=True)

        self.encode = context_args.encoder
        self.reshape_input = context_args.reshape_input
        self.reshape_context = context_args.reshape_context

        # For sample printing and logging
        self.merge_sequence_memory = None
        self.inverted_vocabulary = None
        self.temperature_to_display = 0.0
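
A shape sketch of the two-stage selection_fn defined above, with plain nn.Linear standing in for the project's Linear factory and illustrative sizes:

import torch
import torch.nn as nn

model_dim, selection_dim = 32, 16
selection_fn_1 = nn.Linear(model_dim, selection_dim)
selection_fn_2 = nn.Linear(selection_dim, 1)

selection_input = torch.randn(4, model_dim)
selection_hidden = torch.tanh(selection_fn_1(selection_input))
scores = selection_fn_2(selection_hidden)   # (4, 1) merge-selection scores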