Example #1
0
    def test_graph_environment_step(self):
        """Greedily step a GraphEnvironment until every graph in the batch is done.

        Rebuilds the grammar cache from scratch, then repeatedly selects the
        first available node and its highest-prior action until the
        environment signals completion by raising StopIteration.
        """
        tmp_file = 'tmp2.pickle'
        gi = GrammarInitializer(tmp_file)
        gi.delete_cache()
        # now create a clean new one
        gi = GrammarInitializer(tmp_file)
        # initialize the grammar from the first 20 molecules
        gi.init_grammar(20)
        gi.grammar.check_attributes()
        mask_gen = HypergraphMaskGenerator(30, gi.grammar, priors=True)
        batch_size = 2
        env = GraphEnvironment(mask_gen,
                               reward_fun=lambda x: np.zeros(len(x)),
                               batch_size=batch_size)  # was hard-coded to 2, ignoring the local
        graphs, node_mask, full_logit_priors = env.reset()
        info = None  # guard: keep `info` defined even if the first step raises StopIteration
        while True:
            try:
                # pick the first unmasked node per batch element...
                next_node = np.argmax(node_mask, axis=1)
                # ...and the action with the largest logit prior for that node
                next_action_ = [np.argmax(full_logit_priors[b, next_node[b]])
                                for b in range(batch_size)]
                next_action = (next_node, next_action_)
                (graphs, node_mask, full_logit_priors), reward, done, info = env.step(next_action)
            except StopIteration:
                break

        print(info)
Example #2
0
    def test_decoder_with_environment_new(self):
        """Run DecoderWithEnvironmentNew end to end using a dummy greedy stepper."""
        tmp_file = 'tmp2.pickle'
        gi = GrammarInitializer(tmp_file)
        gi.delete_cache()
        # now create a clean new one
        gi = GrammarInitializer(tmp_file)
        # initialize the grammar from the first 20 molecules
        gi.init_grammar(20)
        gi.grammar.check_attributes()
        mask_gen = HypergraphMaskGenerator(30, gi.grammar, priors=True)
        batch_size = 2

        env = GraphEnvironment(mask_gen,
                               reward_fun=lambda x: 2 * np.ones(len(x)),
                               batch_size=batch_size)  # was hard-coded to 2, ignoring the local

        def dummy_stepper(state):
            # Greedy policy: first unmasked node, then its highest-prior action.
            graphs, node_mask, full_logit_priors = state
            next_node = np.argmax(node_mask, axis=1)
            next_action_ = [np.argmax(full_logit_priors[b, next_node[b]])
                            for b in range(batch_size)]
            next_action = (next_node, next_action_)
            # Dummy per-sample logits: one entry per batch element.
            # (Original used len(state) == 3, which did not match the batch size.)
            return next_action, np.zeros(batch_size)

        dummy_stepper.output_shape = [None, None, None]
        dummy_stepper.init_encoder_output = lambda x: None

        decoder = DecoderWithEnvironmentNew(dummy_stepper, env)
        out = decoder()
        print('done!')
Example #3
0
def get_node_decoder(grammar,
                     max_seq_length=15,
                     drop_rate=0.0,
                     decoder_type='attn',
                     rule_policy=None,
                     reward_fun=lambda x: -1 * np.ones(len(x)),
                     batch_size=None,
                     priors='conditional',
                     bins=10):
    """Build a node-selecting graph decoder driven by a GraphEnvironment.

    Args:
        grammar: grammar spec string; must contain 'hypergraph'.
        max_seq_length: maximum decode length, also passed to the codec.
        drop_rate: dropout rate for the underlying model.
        decoder_type: substring-matched: 'attn' -> transformer, 'rnn' -> rnn;
            'distr' (+ optional 'softmax') selects a distributional head.
        rule_policy: action-sampling policy; defaults to SoftmaxRandomSamplePolicy.
        reward_fun: callable mapping a batch of graphs to per-sample rewards.
        batch_size: environment batch size.
        priors: prior mode forwarded to the mask generator.
        bins: number of bins for distributional heads.

    Returns:
        (decoder, stepper) tuple.

    Raises:
        NotImplementedError: if decoder_type matches neither 'attn' nor 'rnn'.
    """
    codec = get_codec(True, grammar, max_seq_length)
    assert 'hypergraph' in grammar, "Only the hypergraph grammar can be used with attn_graph decoder type"
    if 'attn' in decoder_type:
        model_type = 'transformer'
    elif 'rnn' in decoder_type:
        model_type = 'rnn'
    else:
        # Previously model_type was left unbound here, causing a NameError below.
        raise NotImplementedError('Unknown decoder type: ' + str(decoder_type))

    if 'distr' in decoder_type:
        if 'softmax' in decoder_type:
            output_type = 'distributions_softmax'
        else:
            output_type = 'distributions_thompson'
    else:
        output_type = 'values'

    model = get_graph_model(codec,
                            drop_rate,
                            model_type,
                            output_type,
                            num_bins=bins)

    # don't support using this model in VAE-style models yet

    mask_gen = HypergraphMaskGenerator(max_len=max_seq_length,
                                       grammar=codec.grammar)
    mask_gen.priors = priors
    if rule_policy is None:
        rule_policy = SoftmaxRandomSamplePolicy()

    stepper = GraphDecoderWithNodeSelection(model, rule_policy=rule_policy)
    env = GraphEnvironment(mask_gen,
                           reward_fun=reward_fun,
                           batch_size=batch_size)
    decoder = DecoderWithEnvironmentNew(stepper, env)

    return decoder, stepper
Example #4
0
def get_node_decoder(grammar,
                     max_seq_length=15,
                     drop_rate=0.0,
                     decoder_type='attn',
                     rule_policy=None,
                     reward_fun=lambda x: -1 * np.ones(len(x)),
                     batch_size=None,
                     priors='conditional',
                     bins=10):
    """Build a node-selecting graph decoder driven by a GraphEnvironment.

    Args:
        grammar: grammar spec string; must contain 'hypergraph'.
        max_seq_length: maximum decode length, also passed to the codec.
        drop_rate: dropout rate for the underlying model.
        decoder_type: substring-matched: 'attn' -> transformer, 'rnn' -> rnn,
            'conditional' (+ optional 'sparse') -> conditional models;
            'distr' (+ optional 'softmax') selects a distributional head;
            'sparse' also switches to the sparse policy and stepper.
        rule_policy: action-sampling policy; defaults to the (sparse) softmax policy.
        reward_fun: callable mapping a batch of graphs to per-sample rewards.
        batch_size: environment batch size.
        priors: prior mode; forced to 'term_dist_only' for conditional_sparse.
        bins: number of bins for distributional heads.

    Returns:
        (decoder, stepper) tuple.

    Raises:
        NotImplementedError: if decoder_type matches no known model family.
    """
    codec = get_codec(True, grammar, max_seq_length)
    assert 'hypergraph' in grammar, "Only the hypergraph grammar can be used with attn_graph decoder type"
    if 'attn' in decoder_type:
        model_type = 'transformer'
    elif 'rnn' in decoder_type:
        model_type = 'rnn'
    elif 'conditional' in decoder_type:
        if 'sparse' in decoder_type:
            model_type = 'conditional_sparse'
        else:
            model_type = 'conditional'
    else:
        # Previously model_type was left unbound here, causing a NameError below.
        raise NotImplementedError('Unknown decoder type: ' + str(decoder_type))

    if 'distr' in decoder_type:
        if 'softmax' in decoder_type:
            output_type = 'distributions_softmax'
        else:
            output_type = 'distributions_thompson'
    else:
        output_type = 'values'

    model = get_graph_model(codec,
                            drop_rate,
                            model_type,
                            output_type,
                            num_bins=bins)

    # sparse conditional models only support terminal-distance priors
    if model_type == 'conditional_sparse':
        priors = 'term_dist_only'

    mask_gen = HypergraphMaskGenerator(max_len=max_seq_length,
                                       grammar=codec.grammar,
                                       priors=priors)
    if rule_policy is None:
        rule_policy = SoftmaxRandomSamplePolicySparse(
        ) if 'sparse' in decoder_type else SoftmaxRandomSamplePolicy()

    stepper_type = GraphDecoderWithNodeSelectionSparse if 'sparse' in decoder_type else GraphDecoderWithNodeSelection
    stepper = stepper_type(model, rule_policy=rule_policy)
    env = GraphEnvironment(mask_gen,
                           reward_fun=reward_fun,
                           batch_size=batch_size)
    decoder = DecoderWithEnvironmentNew(stepper, env)

    return decoder, stepper
Example #5
0
def get_decoder(
        molecules=True,
        grammar=True,
        z_size=200,
        decoder_hidden_n=200,
        feature_len=12,  # TODO: remove this — unused; feature length comes from the codec
        max_seq_length=15,
        drop_rate=0.0,
        decoder_type='step',
        task=None,
        node_policy=None,
        rule_policy=None,
        reward_fun=lambda x: -1 * np.ones(len(x)),
        batch_size=None,
        priors=True):
    """Factory returning a (decoder, stepper) pair for the requested decoder_type.

    Dispatches on decoder_type:
      - 'old': resetting RNN decoder wrapped as a continuous one-step decoder.
      - other '*graph*' types (except 'attn_graph'/'rnn_graph'): delegates to
        get_node_decoder and returns its result directly.
      - 'attn_graph'/'rnn_graph' (deprecated): graph encoder + multi-output head,
        environment-driven if 'node' is in the type.
      - 'step'/'action'/'action_resnet'/'attention'/'random': plain RNN-family
        steppers, optionally wrapped in a grammar MaskingHead.

    Raises NotImplementedError for any other decoder_type.
    """
    codec = get_codec(molecules, grammar, max_seq_length)

    if decoder_type == 'old':
        stepper = ResettingRNNDecoder(z_size=z_size,
                                      hidden_n=decoder_hidden_n,
                                      feature_len=codec.feature_len(),
                                      max_seq_length=max_seq_length,
                                      steps=max_seq_length,
                                      drop_rate=drop_rate)
        stepper = OneStepDecoderContinuous(stepper)
    elif 'graph' in decoder_type and decoder_type not in [
            'attn_graph', 'rnn_graph'
    ]:
        # Modern graph decoders: fully handled by get_node_decoder.
        # NOTE(review): node_policy is not forwarded here — presumably only
        # rule_policy is used by the new path; confirm against get_node_decoder.
        return get_node_decoder(grammar, max_seq_length, drop_rate,
                                decoder_type, rule_policy, reward_fun,
                                batch_size, priors)

    elif decoder_type in ['attn_graph', 'rnn_graph']:  # deprecated
        assert 'hypergraph' in grammar, "Only the hypergraph grammar can be used with attn_graph decoder type"
        # choose the encoder backbone: transformer for 'attn', RNN for 'rnn'
        if 'attn' in decoder_type:
            encoder = GraphEncoder(grammar=codec.grammar,
                                   d_model=512,
                                   drop_rate=drop_rate,
                                   model_type='transformer')
        elif 'rnn' in decoder_type:
            encoder = GraphEncoder(grammar=codec.grammar,
                                   d_model=512,
                                   drop_rate=drop_rate,
                                   model_type='rnn')

        model = MultipleOutputHead(
            model=encoder,
            output_spec={
                'node': 1,  # to be used to select next node to expand
                'action': codec.feature_len()
            },  # to select the action for chosen node
            drop_rate=drop_rate)

        # don't support using this model in VAE-style models yet
        model.init_encoder_output = lambda x: None
        mask_gen = HypergraphMaskGenerator(max_len=max_seq_length,
                                           grammar=codec.grammar)
        mask_gen.priors = priors
        # bias=codec.grammar.get_log_frequencies())
        # default both policies to softmax random sampling
        if node_policy is None:
            node_policy = SoftmaxRandomSamplePolicy()
        if rule_policy is None:
            rule_policy = SoftmaxRandomSamplePolicy()
        if 'node' in decoder_type:
            # environment-driven variant with explicit node selection
            stepper = GraphDecoderWithNodeSelection(model,
                                                    node_policy=node_policy,
                                                    rule_policy=rule_policy)
            env = GraphEnvironment(mask_gen,
                                   reward_fun=reward_fun,
                                   batch_size=batch_size)
            decoder = DecoderWithEnvironmentNew(stepper, env)
        else:

            stepper = GraphDecoder(model=model, mask_gen=mask_gen)
            decoder = to_gpu(
                SimpleDiscreteDecoderWithEnv(stepper,
                                             rule_policy,
                                             task=task,
                                             batch_size=batch_size))
        # graph types return early: the masking/policy wrapping below does not apply
        return decoder, stepper

    else:
        # plain sequence decoders, distinguished by how the last action is used
        if decoder_type == 'step':
            stepper = SimpleRNNDecoder(z_size=z_size,
                                       hidden_n=decoder_hidden_n,
                                       feature_len=codec.feature_len(),
                                       max_seq_length=max_seq_length,
                                       drop_rate=drop_rate,
                                       use_last_action=False)

        elif decoder_type == 'action':
            stepper = SimpleRNNDecoder(
                z_size=z_size,  # + feature_len,
                hidden_n=decoder_hidden_n,
                feature_len=codec.feature_len(),
                max_seq_length=max_seq_length,
                drop_rate=drop_rate,
                use_last_action=True)

        elif decoder_type == 'action_resnet':
            stepper = ResNetRNNDecoder(
                z_size=z_size,  # + feature_len,
                hidden_n=decoder_hidden_n,
                feature_len=codec.feature_len(),
                max_seq_length=max_seq_length,
                drop_rate=drop_rate,
                use_last_action=True)

        elif decoder_type == 'attention':
            stepper = SelfAttentionDecoderStep(num_actions=codec.feature_len(),
                                               max_seq_len=max_seq_length,
                                               drop_rate=drop_rate,
                                               enc_output_size=z_size)
        elif decoder_type == 'random':
            stepper = RandomDecoder(feature_len=codec.feature_len(),
                                    max_seq_length=max_seq_length)
        else:
            raise NotImplementedError('Unknown decoder type: ' +
                                      str(decoder_type))

    if grammar is not False and '_graph' not in decoder_type:
        # add a masking layer so only grammar-valid actions are sampled
        mask_gen = get_codec(molecules, grammar, max_seq_length).mask_gen
        stepper = MaskingHead(stepper, mask_gen)

    policy = SoftmaxRandomSamplePolicy(
    )  # bias=codec.grammar.get_log_frequencies())

    decoder = to_gpu(
        SimpleDiscreteDecoderWithEnv(
            stepper, policy, task=task,
            batch_size=batch_size))  # , bypass_actions=True))

    return decoder, stepper
Example #6
0
def make_environment(grammar, batch_size=2):
    """Construct a GraphEnvironment over the given grammar.

    Uses a HypergraphMaskGenerator with conditional priors and a constant
    reward of 2 for every sample in the batch.
    """
    generator = HypergraphMaskGenerator(30, grammar, priors='conditional')
    constant_reward = lambda x: 2 * np.ones(len(x))
    return GraphEnvironment(generator,
                            reward_fun=constant_reward,
                            batch_size=batch_size)