Example 1
def _decode_step(X, A, NX, NA, last_action, finished,
                 get_init, get_action,
                 random=True, n_node_types=get_mol_spec().num_atom_types,
                 n_edge_types=get_mol_spec().num_bond_types):
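    """Perform one step of batched sequential molecular graph generation.

    On the first call (X is None) an initial atom type is drawn for every
    graph from get_init(); on later calls get_action() scores append,
    connect and end actions for each unfinished graph and one action is
    sampled (random=True) or taken greedily. Returns the updated state as
    (X, A, NX, NA, last_action, finished).
    """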
    if X is None:
        init = get_init()

        if random:
            X = []
            for i in range(init.shape[0]):
                p = init[i, :]
                selected_atom = np.random.choice(np.arange(init.shape[1]), 1, p=p)[0]
                X.append(selected_atom)
            X = np.array(X, dtype=np.int32)
        else:
            X = np.argmax(init, axis=1)
        A = np.zeros((0, 3), dtype=np.int32)
        # keep NX and last_action as separate arrays to avoid accidental aliasing
        NX = np.ones([X.shape[0]], dtype=np.int32)
        last_action = np.ones([X.shape[0]], dtype=np.int32)
        NA = np.zeros([X.shape[0]], dtype=np.int32)
        finished = np.zeros(X.shape[0], dtype=bool)  # np.bool was removed from NumPy

        return X, A, NX, NA, last_action, finished
    else:
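        # continuation step: operate only on graphs that are not yet finished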
        unfinished = np.logical_not(finished)
        X_u = X[np.repeat(unfinished, NX)]
        A_u = A[np.repeat(unfinished, NA), :]
        NX_u = NX[unfinished]
        NA_u = NA[unfinished]
        last_action_u = last_action[unfinished]

        # graph-convolution bookkeeping: map each atom to its molecule index
        mol_ids_rep = NX_rep = np.repeat(np.arange(NX_u.shape[0]), NX_u)
        rep_ids_rep = np.zeros_like(mol_ids_rep)

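        # no bonds yet: feed empty per-bond-type adjacency lists plus empty
        # D_2 / D_3 index arrays to the network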
        if A.shape[0] == 0:
            D_2 = D_3 = np.zeros((0, 2), dtype=np.int32)
            A_u = [np.zeros((0, 2), dtype=np.int32) for _ in range(n_edge_types)]
            A_u += [D_2, D_3]
        else:
            cumsum = np.cumsum(np.pad(NX_u, [[1, 0]], mode='constant')[:-1])
            shift = np.repeat(cumsum, NA_u)
            A_u[:, :2] += np.stack([shift, ] * 2, axis=1)
            D_2, D_3 = data.get_d(A_u, X_u)
            A_u = [A_u[A_u[:, 2] == _i, :2] for _i in range(n_edge_types)]
            A_u += [D_2, D_3]

        mask = np.zeros([X_u.shape[0]], dtype=np.int32)
        last_append_index = np.cumsum(NX_u) - 1
        mask[last_append_index] = np.where(last_action_u == 1,
                                           np.ones_like(last_append_index, dtype=np.int32),
                                           np.ones_like(last_append_index, dtype=np.int32) * 2)

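        # query the policy network: per-atom "append" scores (new atom and
        # bond type), "connect" scores (bond from the newest atom to an
        # existing one) and a per-graph "end" score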
        decode_input = [X_u, A_u, NX_u, NX_rep, mask, mol_ids_rep, rep_ids_rep]
        append, connect, end = get_action(decode_input)

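        # second step: every graph holds a single atom, so only an append
        # makes sense; the best (atom, bond) pair is taken greedily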
        if A.shape[0] == 0:
            max_index = np.argmax(np.reshape(append, [-1, n_node_types * n_edge_types]), axis=1)
            atom_type, bond_type = np.unravel_index(max_index, [n_node_types, n_edge_types])
            X = np.reshape(np.stack([X, atom_type], axis=1), [-1])
            NX = np.array([2, ] * len(finished), dtype=np.int32)
            A = np.stack([np.zeros([len(finished), ], dtype=np.int32),
                          np.ones([len(finished), ], dtype=np.int32),
                          bond_type], axis=1)
            NA = np.ones([len(finished), ], dtype=np.int32)
            last_action = np.ones_like(NX, dtype=np.int32)

        else:
            # process for each molecule
            append, connect = np.split(append, np.cumsum(NX_u)), np.split(connect, np.cumsum(NX_u))
            end = end.tolist()

            unfinished_ids = np.where(np.logical_not(finished))[0].tolist()
            cumsum = np.cumsum(NX)
            cumsum_a = np.cumsum(NA)

            X_insert = []
            X_insert_ids = []
            A_insert = []
            A_insert_ids = []
            finished_ids = []

            for i, (unfinished_id, append_i, connect_i, end_i) \
                    in enumerate(zip(unfinished_ids, append, connect, end)):
                if random:
                    def _rand_id(*_x):
                        _x_reshaped = [np.reshape(_xi, [-1]) for _xi in _x]
                        _x_length = np.array([_x_reshape_i.shape[0] for _x_reshape_i in _x_reshaped],
                                             dtype=np.int32)
                        _begin = np.cumsum(np.pad(_x_length, [[1, 0]], mode='constant')[:-1])
                        _end = np.cumsum(_x_length) - 1
                        _p = np.concatenate(_x_reshaped)
                        _p = _p / np.sum(_p)
                        _rand_index = np.random.choice(np.arange(_p.shape[0]), 1, p=_p)[0]
                        _p_step = _p[_rand_index]
                        _x_index = np.where(np.logical_and(_begin <= _rand_index, _end >= _rand_index))[0][0]
                        _rand_index = _rand_index - _begin[_x_index]
                        _rand_index = np.unravel_index(_rand_index, _x[_x_index].shape)
                        return _x_index, _rand_index, _p_step

                    action_type, action_index, p_step = _rand_id(append_i, connect_i, np.array([end_i]))
                else:
                    _argmax = lambda _x: np.unravel_index(np.argmax(_x), _x.shape)
                    append_id, append_val = _argmax(append_i), np.max(append_i)
                    connect_id, connect_val = _argmax(connect_i), np.max(connect_i)
                    end_val = end_i
                    if end_val >= append_val and end_val >= connect_val:
                        action_type = 2
                        action_index = None
                    elif append_val >= connect_val and append_val >= end_val:
                        action_type = 0
                        action_index = append_id
                    else:
                        action_type = 1
                        action_index = connect_id
                if action_type == 2:
                    # finish growth
                    finished_ids.append(unfinished_id)
                elif action_type == 0:
                    # append action
                    append_pos, atom_type, bond_type = action_index
                    X_insert.append(atom_type)
                    X_insert_ids.append(unfinished_id)
                    A_insert.append([append_pos, NX[unfinished_id], bond_type])
                    A_insert_ids.append(unfinished_id)
                else:
                    # connect
                    connect_pos, bond_type = action_index
                    A_insert.append([NX[unfinished_id] - 1, connect_pos, bond_type])
                    A_insert_ids.append(unfinished_id)
            if len(A_insert_ids) > 0:
                A = np.insert(A, cumsum_a[A_insert_ids], A_insert, axis=0)
                NA[A_insert_ids] += 1
                last_action[A_insert_ids] = 0
            if len(X_insert_ids) > 0:
                X = np.insert(X, cumsum[X_insert_ids], X_insert, axis=0)
                NX[X_insert_ids] += 1
                last_action[X_insert_ids] = 1
            if len(finished_ids) > 0:
                finished[finished_ids] = True

        return X, A, NX, NA, last_action, finished
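
For orientation, a minimal driver loop around _decode_step is sketched below. It is illustrative only: get_init and get_action stand for whatever closures wrap the trained model, and max_steps is an assumed safety cap; neither is defined in this example.

def _sample_graphs(get_init, get_action, max_steps=100):
    # first call: X is None, so _decode_step draws the initial atoms
    state = _decode_step(None, None, None, None, None, None,
                         get_init, get_action)
    for _ in range(max_steps):
        X, A, NX, NA, last_action, finished = state
        if finished.all():
            break
        # each call grows every unfinished graph by one append/connect/end
        state = _decode_step(X, A, NX, NA, last_action, finished,
                             get_init, get_action)
    return state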
Example 2
def _get_model(configs):
    return models.CVanillaMolGen_RNN(get_mol_spec().num_atom_types,
                                     get_mol_spec().num_bond_types,
                                     D=2, **configs)
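
A usage sketch, assuming a configs.json written by _engine_cond below; the checkpoint path is illustrative:

import json
import os

ckpt_dir = 'ckpt/scaffold'  # illustrative; matches the default in _engine_cond
with open(os.path.join(ckpt_dir, 'configs.json')) as f:
    configs = json.load(f)
model = _get_model(configs)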
Example 3
def _engine_cond(cond_type='scaffold',
                 file_name='datasets/ChEMBL_scaffold.txt',
                 num_scaffolds=734,
                 is_full=False,
                 ckpt_dir='ckpt/scaffold',
                 num_folds=5,
                 fold_id=0,
                 batch_size=50,
                 batch_size_test=100,
                 num_workers=2,
                 k=5,
                 p=0.8,
                 F_e=16,
                 F_h=(32, 64, 128, 128, 256, 256),
                 F_skip=256,
                 F_c=(512, ),
                 Fh_policy=128,
                 activation='relu',
                 N_rnn=3,
                 gpu_ids=(0, 1, 2, 3),
                 lr=1e-3,
                 decay=0.015,
                 decay_step=100,
                 clip_grad=3.0,
                 iterations=30000,
                 summary_step=200):
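    """Train the conditional molecule generator.

    Builds data loaders for the chosen condition type ('scaffold', 'prop'
    or 'kinase'), creates the model or restores it from ckpt_dir, and runs
    the multi-GPU training loop. A progress record is written to
    ckpt_dir/log.out every summary_step iterations: the training loss when
    is_full is True, the held-out test loss otherwise.
    """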
    # resume training if a complete checkpoint already exists in ckpt_dir
    is_continuous = all(
        os.path.isfile(os.path.join(ckpt_dir, _n))
        for _n in ['log.out', 'ckpt.params', 'trainer.status'])

    if is_full:
        if cond_type != 'kinase':
            if cond_type == 'scaffold':
                cond = data.SparseFP(num_scaffolds)
                N_C = num_scaffolds
            elif cond_type == 'prop':
                cond = data.Delimited()
                N_C = 2
            else:
                raise ValueError('unsupported cond_type: {}'.format(cond_type))

            with open(file_name) as f:
                dataset = data.Lambda(f.readlines(),
                                      lambda _x: _x.strip('\n').strip('\r'))

            # get sampler and loader for training set
            sampler_train = data.BalancedSampler(
                cost=[len(l.split('\t')[0]) for l in dataset],
                batch_size=batch_size)
            loader_train = data.CMolRNNLoader(dataset,
                                              batch_sampler=sampler_train,
                                              num_workers=num_workers,
                                              k=k,
                                              p=p,
                                              conditional=cond)

            loader_test = []
        else:
            cond = data.Delimited()
            N_C = 2

            with open(file_name) as f:
                dataset = data.Lambda(f.readlines(),
                                      lambda _x: _x.strip('\n').strip('\r'))

            # get dataset
            def _filter(_line, _i):
                return int(_line.split('\t')[-1]) == _i

            db_train = data.Lambda(data.Filter(
                dataset, fn=lambda _x: not _filter(_x, fold_id)),
                                   fn=lambda _x: _x[:-2])
            db_test = data.Lambda(data.Filter(
                dataset, fn=lambda _x: _filter(_x, fold_id)),
                                  fn=lambda _x: _x[:-2])

            # get sampler and loader for test set
            loader_test = data.CMolRNNLoader(db_test,
                                             shuffle=True,
                                             num_workers=num_workers,
                                             k=k,
                                             p=p,
                                             conditional=cond,
                                             batch_size=batch_size_test)

            # get sampler and loader for training set
            loader_train = data.CMolRNNLoader(db_train,
                                              shuffle=True,
                                              num_workers=num_workers,
                                              k=k,
                                              p=p,
                                              conditional=cond,
                                              batch_size=batch_size)

        # get iterator
        it_train, it_test = iter(loader_train), iter(loader_test)
    else:
        if cond_type != 'kinase':
            if cond_type == 'scaffold':
                cond = data.SparseFP(num_scaffolds)
                N_C = num_scaffolds
            elif cond_type == 'prop':
                cond = data.Delimited()
                N_C = 2
            else:
                raise ValueError('unsupported cond_type: {}'.format(cond_type))

            with open(file_name) as f:
                dataset = data.Lambda(f.readlines(),
                                      lambda _x: _x.strip('\n').strip('\r'))

            # get dataset
            db_train = data.KFold(dataset,
                                  k=num_folds,
                                  fold_id=fold_id,
                                  is_train=True)
            db_test = data.KFold(dataset,
                                 k=num_folds,
                                 fold_id=fold_id,
                                 is_train=False)

            # get sampler and loader for training set
            sampler_train = data.BalancedSampler(
                cost=[len(l.split('\t')[0]) for l in db_train],
                batch_size=batch_size)
            loader_train = data.CMolRNNLoader(db_train,
                                              batch_sampler=sampler_train,
                                              num_workers=num_workers,
                                              k=k,
                                              p=p,
                                              conditional=cond)

            # get sampler and loader for test set
            sampler_test = data.BalancedSampler(
                cost=[len(l.split('\t')[0]) for l in db_test],
                batch_size=batch_size_test)
            loader_test = data.CMolRNNLoader(db_test,
                                             batch_sampler=sampler_test,
                                             num_workers=num_workers,
                                             k=k,
                                             p=p,
                                             conditional=cond)

        else:
            cond = data.Delimited()
            N_C = 2

            with open(file_name) as f:
                dataset = data.Lambda(f.readlines(),
                                      lambda _x: _x.strip('\n').strip('\r'))

            # get dataset
            def _filter(_line, _i):
                return int(_line.split('\t')[-1]) == _i

            db_train = data.Lambda(data.Filter(
                dataset, fn=lambda _x: not _filter(_x, fold_id)),
                                   fn=lambda _x: _x[:-2])
            db_test = data.Lambda(data.Filter(
                dataset, fn=lambda _x: _filter(_x, fold_id)),
                                  fn=lambda _x: _x[:-2])

            # get sampler and loader for training set
            loader_train = data.CMolRNNLoader(db_train,
                                              shuffle=True,
                                              num_workers=num_workers,
                                              k=k,
                                              p=p,
                                              conditional=cond,
                                              batch_size=batch_size)

            # get sampler and loader for test set
            loader_test = data.CMolRNNLoader(db_test,
                                             shuffle=True,
                                             num_workers=num_workers,
                                             k=k,
                                             p=p,
                                             conditional=cond,
                                             batch_size=batch_size_test)

        # get iterator
        it_train, it_test = iter(loader_train), iter(loader_test)

    # build model
    if not is_continuous:
        configs = {
            'N_C': N_C,
            'F_e': F_e,
            'F_h': F_h,
            'F_skip': F_skip,
            'F_c': F_c,
            'Fh_policy': Fh_policy,
            'activation': activation,
            'rename': True,
            'N_rnn': N_rnn
        }
        with open(os.path.join(ckpt_dir, 'configs.json'), 'w') as f:
            json.dump(configs, f)
    else:
        with open(os.path.join(ckpt_dir, 'configs.json')) as f:
            configs = json.load(f)

    model = models.CVanillaMolGen_RNN(get_mol_spec().num_atom_types,
                                      get_mol_spec().num_bond_types,
                                      D=2,
                                      **configs)

    ctx = [mx.gpu(i) for i in gpu_ids]
    model.collect_params().initialize(mx.init.Xavier(),
                                      force_reinit=True,
                                      ctx=ctx)
    if not is_continuous:
        if cond_type == 'kinase':
            model.load_params(os.path.join(ckpt_dir, 'ckpt.params.bk'),
                              ctx=ctx,
                              allow_missing=True)
    else:
        model.load_params(os.path.join(ckpt_dir, 'ckpt.params'), ctx=ctx)

    # construct optimizer
    opt = mx.optimizer.Adam(learning_rate=lr, clip_gradient=clip_grad)
    trainer = gluon.Trainer(model.collect_params(), opt)
    if is_continuous:
        trainer.load_states(os.path.join(ckpt_dir, 'trainer.status'))

    if not is_continuous:
        t0 = time.time()
        global_counter = 0
    else:
        with open(os.path.join(ckpt_dir, 'log.out')) as f:
            records = f.readlines()
            if records[-1] != 'Training finished\n':
                final_record = records[-1]
            else:
                final_record = records[-2]
        count, t_final = int(final_record.split('\t')[0]), float(
            final_record.split('\t')[1])
        t0 = time.time() - t_final * 60
        global_counter = count

    with open(os.path.join(ckpt_dir, 'log.out'),
              mode='w' if not is_continuous else 'a') as f:
        if not is_continuous:
            f.write('step\ttime(h)\tloss\tlr\n')
        while True:
            global_counter += 1

            try:
                inputs = [next(it_train) for _ in range(len(gpu_ids))]
            except StopIteration:
                it_train = iter(loader_train)
                inputs = [next(it_train) for _ in range(len(gpu_ids))]

            # move to gpu
            inputs = [
                data.CMolRNNLoader.from_numpy_to_tensor(input_i, j)
                for j, input_i in zip(gpu_ids, inputs)
            ]

            with autograd.record():
                loss = [(model(*input_i)).as_in_context(mx.gpu(gpu_ids[0]))
                        for input_i in inputs]
                loss = sum(loss) / len(gpu_ids)
                loss.backward()

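            # wait for pending asynchronous MXNet ops before freeing memory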
            nd.waitall()
            gc.collect()

            trainer.step(batch_size=1)

            if global_counter % decay_step == 0:
                trainer.set_learning_rate(trainer.learning_rate *
                                          (1.0 - decay))

            if global_counter % summary_step == 0:
                if is_full:
                    # loss was already averaged across GPUs in the training
                    # step; np.asscalar is removed in recent NumPy
                    loss = loss.asnumpy().item()
                else:
                    del loss, inputs
                    gc.collect()

                    try:
                        inputs = [next(it_test) for _ in range(len(gpu_ids))]
                    except StopIteration:
                        it_test = iter(loader_test)
                        inputs = [next(it_test) for _ in range(len(gpu_ids))]

                    with autograd.predict_mode():
                        # move to gpu
                        inputs = [
                            data.CMolRNNLoader.from_numpy_to_tensor(
                                input_i, j)
                            for j, input_i in zip(gpu_ids, inputs)
                        ]
                        loss = [
                            (model(*input_i)).as_in_context(mx.gpu(gpu_ids[0]))
                            for input_i in inputs
                        ]
                        # np.asscalar is removed in recent NumPy; use .item()
                        loss = (sum(loss) / len(gpu_ids)).asnumpy().item()

                model.save_params(os.path.join(ckpt_dir, 'ckpt.params'))
                trainer.save_states(os.path.join(ckpt_dir, 'trainer.status'))

                f.write('{}\t{}\t{}\t{}\n'.format(global_counter,
                                                  float(time.time() - t0) / 60,
                                                  loss, trainer.learning_rate))
                f.flush()

                del loss, inputs
                gc.collect()

            if global_counter >= iterations:
                break

        # save before exit
        model.save_params(os.path.join(ckpt_dir, 'ckpt.params'))
        trainer.save_states(os.path.join(ckpt_dir, 'trainer.status'))

        f.write('Training finished\n')
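
A minimal invocation sketch; all values are the defaults from the signature above except gpu_ids, which is narrowed to a single device for illustration:

if __name__ == '__main__':
    _engine_cond(cond_type='scaffold',
                 file_name='datasets/ChEMBL_scaffold.txt',
                 ckpt_dir='ckpt/scaffold',
                 gpu_ids=(0,))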