Code example #1
0
def batch_make_att_masks(node_list, tree_decoder=None, walker=None, dtype=None):
    """Decode each tree in *node_list* and build one-hot decision targets and rule masks.

    Args:
        node_list: sequence of molecule-tree roots to decode.
        tree_decoder: decoder that drives the walk over each tree; a fresh one
            is created via ``create_tree_decoder()`` when ``None``.
        walker: ``OnehotBuilder`` recording the rule chosen at each step;
            created when ``None``.
        dtype: numpy dtype of the returned arrays (defaults to ``np.byte``).

    Returns:
        Tuple ``(true_binary, rule_masks)``, both of shape
        ``(len(node_list), cmd_args.max_decode_steps, DECISION_DIM)``:
        ``true_binary`` one-hot encodes the rule used at each decode step and
        ``rule_masks`` marks which rules were admissible at that step.
        Steps beyond ``walker.num_steps`` are padded on the last decision slot.
    """
    if dtype is None:
        dtype = np.byte
    if walker is None:
        walker = OnehotBuilder()
    if tree_decoder is None:
        tree_decoder = create_tree_decoder()

    shape = (len(node_list), cmd_args.max_decode_steps, DECISION_DIM)
    true_binary = np.zeros(shape, dtype=dtype)
    rule_masks = np.zeros(shape, dtype=dtype)

    for i, node in enumerate(node_list):
        # The walker accumulates per-step rule choices/masks during decoding.
        tree_decoder.decode(node, walker)

        # One-hot mark the rule actually applied at each real decoding step.
        true_binary[i, np.arange(walker.num_steps), walker.global_rule_used[:walker.num_steps]] = 1
        # Pad the remaining steps with the terminal (last) decision slot.
        true_binary[i, np.arange(walker.num_steps, cmd_args.max_decode_steps), -1] = 1

        for j in range(walker.num_steps):
            rule_masks[i, j, walker.mask_list[j]] = 1

        # Was `1.0` in the original — use an int literal for consistency with
        # the (byte by default) dtype; stored value is identical.
        rule_masks[i, np.arange(walker.num_steps, cmd_args.max_decode_steps), -1] = 1

    return true_binary, rule_masks
Code example #2
0
    def __init__(self, *args, **kwargs):
        """Build the MolVAE from its JSON config, load saved weights, and
        set up the decoding helpers (walker, tree decoder, grammar)."""
        # Read the model config; use a context manager so the file handle is
        # closed promptly (the original `json.load(open(...))` leaked it).
        with open(cmd_args.model_config, 'r') as f:
            model_config = json.load(f)
        self.ae = MolVAE(model_config)

        # load model weights
        model_weights = paddle.load(cmd_args.saved_model)
        self.ae.set_state_dict(model_weights)

        self.onehot_walker = OnehotBuilder()
        self.tree_decoder = create_tree_decoder()
        self.grammar = parser.Grammar(cmd_args.grammar_file)
Code example #3
0
def process_chunk(smiles_list):
    """Parse a chunk of SMILES strings and return their one-hot decision
    targets and rule masks as a ``(onehot, masks)`` pair of byte arrays."""
    grammar = parser.Grammar(cmd_args.grammar_file)

    mol_trees = []
    for smiles in smiles_list:
        parsed = parser.parse(smiles, grammar)
        # Each SMILES is expected to yield exactly one annotated parse tree.
        assert isinstance(parsed, list) and len(parsed) == 1
        mol_trees.append(AnnotatedTree2MolTree(parsed[0]))

    decoder = create_tree_decoder()
    builder = OnehotBuilder()
    onehot, masks = batch_make_att_masks(mol_trees, decoder, builder, dtype=np.byte)

    return (onehot, masks)
Code example #4
0
    def __init__(self, *args, **kwargs):
        """Instantiate the configured autoencoder (VAE or plain AE), load its
        saved weights, and set up the decoding helpers."""
        if cmd_args.ae_type == 'vae':
            self.ae = MolVAE()
        elif cmd_args.ae_type == 'autoenc':
            self.ae = MolAutoEncoder()
        else:
            raise Exception('unknown ae type %s' % cmd_args.ae_type)
        if cmd_args.mode == 'gpu':
            self.ae = self.ae.cuda()

        assert cmd_args.saved_model is not None

        # BUG FIX: the original mixed tab- and space-indentation in this
        # if/else, which raises TabError under Python 3; normalized to
        # 4-space indentation. Behavior is unchanged.
        if cmd_args.mode == 'cpu':
            # map_location remaps GPU-saved tensors onto CPU storage
            self.ae.load_state_dict(torch.load(cmd_args.saved_model, map_location=lambda storage, loc: storage))
        else:
            self.ae.load_state_dict(torch.load(cmd_args.saved_model))

        self.onehot_walker = OnehotBuilder()
        self.tree_decoder = create_tree_decoder()
        self.grammar = parser.Grammar(cmd_args.grammar_file)