import gc
import json
import os
import time

import numpy as np
import mxnet as mx
from mxnet import autograd, gluon, nd

# Project-local modules; the import paths below are assumed from the
# surrounding repo layout and may already be present at the top of the module.
from mx_mg import data, models
from mx_mg.data import get_mol_spec


def _decode_step(X, A, NX, NA, last_action, finished,
                 get_init, get_action,
                 random=True,
                 n_node_types=get_mol_spec().num_atom_types,
                 n_edge_types=get_mol_spec().num_bond_types):
    if X is None:
        # first step: sample (or take the argmax of) the initial atom
        # for each molecule in the batch
        init = get_init()

        if random:
            X = []
            for i in range(init.shape[0]):
                p = init[i, :]
                selected_atom = np.random.choice(np.arange(init.shape[1]), 1, p=p)[0]
                X.append(selected_atom)
            X = np.array(X, dtype=np.int32)
        else:
            X = np.argmax(init, axis=1)
        A = np.zeros((0, 3), dtype=np.int32)
        NX = last_action = np.ones([X.shape[0]], dtype=np.int32)
        NA = np.zeros([X.shape[0]], dtype=np.int32)
        finished = np.array([False, ] * X.shape[0], dtype=bool)

        return X, A, NX, NA, last_action, finished
    else:
        # restrict the batch to molecules that are still growing
        X_u = X[np.repeat(np.logical_not(finished), NX)]
        A_u = A[np.repeat(np.logical_not(finished), NA), :]
        NX_u = NX[np.logical_not(finished)]
        NA_u = NA[np.logical_not(finished)]
        last_action_u = last_action[np.logical_not(finished)]

        # indices used by the graph convolution
        mol_ids_rep = NX_rep = np.repeat(np.arange(NX_u.shape[0]), NX_u)
        rep_ids_rep = np.zeros_like(mol_ids_rep)

        if A.shape[0] == 0:
            D_2 = D_3 = np.zeros((0, 2), dtype=np.int32)
            A_u = [np.zeros((0, 2), dtype=np.int32)
                   for _ in range(get_mol_spec().num_bond_types)]
            A_u += [D_2, D_3]
        else:
            # shift atom indices so that each molecule's bonds point into the
            # packed atom array X_u
            cumsum = np.cumsum(np.pad(NX_u, [[1, 0]], mode='constant')[:-1])
            shift = np.repeat(cumsum, NA_u)
            A_u[:, :2] += np.stack([shift, ] * 2, axis=1)
            D_2, D_3 = data.get_d(A_u, X_u)
            A_u = [A_u[A_u[:, 2] == _i, :2] for _i in range(n_edge_types)]
            A_u += [D_2, D_3]

        # mark the most recently added atom of each molecule:
        # 1 if the last action was an append, 2 if it was a connect
        mask = np.zeros([X_u.shape[0]], dtype=np.int32)
        last_append_index = np.cumsum(NX_u) - 1
        mask[last_append_index] = np.where(last_action_u == 1,
                                           np.ones_like(last_append_index, dtype=np.int32),
                                           np.ones_like(last_append_index, dtype=np.int32) * 2)

        decode_input = [X_u, A_u, NX_u, NX_rep, mask, mol_ids_rep, rep_ids_rep]
        append, connect, end = get_action(decode_input)

        if A.shape[0] == 0:
            # second step: every molecule still has a single atom, so the only
            # valid action is to append a second one
            max_index = np.argmax(np.reshape(append, [-1, n_node_types * n_edge_types]), axis=1)
            atom_type, bond_type = np.unravel_index(max_index, [n_node_types, n_edge_types])
            X = np.reshape(np.stack([X, atom_type], axis=1), [-1])
            NX = np.array([2, ] * len(finished), dtype=np.int32)
            A = np.stack([np.zeros([len(finished), ], dtype=np.int32),
                          np.ones([len(finished), ], dtype=np.int32),
                          bond_type], axis=1)
            NA = np.ones([len(finished), ], dtype=np.int32)
            last_action = np.ones_like(NX, dtype=np.int32)
        else:
            # process each molecule separately
            append, connect = np.split(append, np.cumsum(NX_u)), \
                              np.split(connect, np.cumsum(NX_u))
            end = end.tolist()

            unfinished_ids = np.where(np.logical_not(finished))[0].tolist()
            cumsum = np.cumsum(NX)
            cumsum_a = np.cumsum(NA)

            X_insert = []
            X_insert_ids = []
            A_insert = []
            A_insert_ids = []
            finished_ids = []

            for i, (unfinished_id, append_i, connect_i, end_i) \
                    in enumerate(zip(unfinished_ids, append, connect, end)):
                if random:
                    def _rand_id(*_x):
                        # sample one action in proportion to the concatenated
                        # (append, connect, end) scores, then map the flat index
                        # back to (which tensor, index within that tensor)
                        _x_reshaped = [np.reshape(_xi, [-1]) for _xi in _x]
                        _x_length = np.array([_x_reshape_i.shape[0]
                                              for _x_reshape_i in _x_reshaped],
                                             dtype=np.int32)
                        _begin = np.cumsum(np.pad(_x_length, [[1, 0]], mode='constant')[:-1])
                        _end = np.cumsum(_x_length) - 1
                        _p = np.concatenate(_x_reshaped)
                        _p = _p / np.sum(_p)
                        _rand_index = np.random.choice(np.arange(_p.shape[0]), 1, p=_p)[0]
                        _p_step = _p[_rand_index]
                        _x_index = np.where(np.logical_and(_begin <= _rand_index,
                                                           _end >= _rand_index))[0][0]
                        _rand_index = _rand_index - _begin[_x_index]
                        _rand_index = np.unravel_index(_rand_index, _x[_x_index].shape)
                        return _x_index, _rand_index, _p_step

                    action_type, action_index, p_step = \
                        _rand_id(append_i, connect_i, np.array([end_i]))
                else:
                    # greedy decoding: pick whichever action has the highest score
                    _argmax = lambda _x: np.unravel_index(np.argmax(_x), _x.shape)
                    append_id, append_val = _argmax(append_i), np.max(append_i)
                    connect_id, connect_val = _argmax(connect_i), np.max(connect_i)
                    end_val = end_i
                    if end_val >= append_val and end_val >= connect_val:
                        action_type = 2
                        action_index = None
                    elif append_val >= connect_val and append_val >= end_val:
                        action_type = 0
                        action_index = append_id
                    else:
                        action_type = 1
                        action_index = connect_id

                if action_type == 2:
                    # finish growth
                    finished_ids.append(unfinished_id)
                elif action_type == 0:
                    # append a new atom and its bond
                    append_pos, atom_type, bond_type = action_index
                    X_insert.append(atom_type)
                    X_insert_ids.append(unfinished_id)
                    A_insert.append([append_pos, NX[unfinished_id], bond_type])
                    A_insert_ids.append(unfinished_id)
                else:
                    # connect the latest atom to an existing one
                    connect_ps, bond_type = action_index
                    A_insert.append([NX[unfinished_id] - 1, connect_ps, bond_type])
                    A_insert_ids.append(unfinished_id)

            if len(A_insert_ids) > 0:
                A = np.insert(A, cumsum_a[A_insert_ids], A_insert, axis=0)
                NA[A_insert_ids] += 1
                last_action[A_insert_ids] = 0
            if len(X_insert_ids) > 0:
                X = np.insert(X, cumsum[X_insert_ids], X_insert, axis=0)
                NX[X_insert_ids] += 1
                last_action[X_insert_ids] = 1
            if len(finished_ids) > 0:
                finished[finished_ids] = True

        return X, A, NX, NA, last_action, finished
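

# Usage sketch (illustration only, not part of the original pipeline): driving
# _decode_step until every molecule in the batch has emitted the "end" action.
# `get_init`, `get_action` and `max_steps` are assumed/hypothetical names; in
# practice the first two would be closures over a trained model that return the
# initial-atom distribution and the (append, connect, end) policy outputs.
def _example_decode(get_init, get_action, max_steps=100):
    X = A = NX = NA = last_action = finished = None
    for _ in range(max_steps):
        X, A, NX, NA, last_action, finished = _decode_step(
            X, A, NX, NA, last_action, finished,
            get_init=get_init, get_action=get_action, random=True)
        if finished.all():
            break
    return X, A, NX, NA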
def _get_model(configs):
    return models.CVanillaMolGen_RNN(get_mol_spec().num_atom_types,
                                     get_mol_spec().num_bond_types,
                                     D=2, **configs)
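

# Usage sketch (assumption): _get_model mirrors the construction inside
# _engine_cond, so a trained generator can be rebuilt from the configs.json
# saved in a checkpoint directory. The directory path here is hypothetical.
def _example_load_model(ckpt_dir='ckpt/scaffold'):
    with open(os.path.join(ckpt_dir, 'configs.json')) as f:
        configs = json.load(f)
    model = _get_model(configs)
    model.load_params(os.path.join(ckpt_dir, 'ckpt.params'), ctx=[mx.cpu()])
    return model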
def _engine_cond(cond_type='scaffold',
                 file_name='datasets/ChEMBL_scaffold.txt',
                 num_scaffolds=734, is_full=False,
                 ckpt_dir='ckpt/scaffold',
                 num_folds=5, fold_id=0,
                 batch_size=50, batch_size_test=100,
                 num_workers=2, k=5, p=0.8,
                 F_e=16, F_h=(32, 64, 128, 128, 256, 256), F_skip=256,
                 F_c=(512, ), Fh_policy=128,
                 activation='relu', N_rnn=3,
                 gpu_ids=(0, 1, 2, 3), lr=1e-3,
                 decay=0.015, decay_step=100, clip_grad=3.0,
                 iterations=30000, summary_step=200):
    # resume training if a complete checkpoint already exists
    is_continuous = all(os.path.isfile(os.path.join(ckpt_dir, _n))
                        for _n in ['log.out', 'ckpt.params', 'trainer.status'])

    if is_full:
        if cond_type != 'kinase':
            if cond_type == 'scaffold':
                cond = data.SparseFP(num_scaffolds)
                N_C = num_scaffolds
            elif cond_type == 'prop':
                cond = data.Delimited()
                N_C = 2
            else:
                raise ValueError('unsupported cond_type: {}'.format(cond_type))

            with open(file_name) as f:
                dataset = data.Lambda(f.readlines(), lambda _x: _x.strip('\n').strip('\r'))

            # get sampler and loader for training set
            sampler_train = data.BalancedSampler(
                cost=[len(l.split('\t')[0]) for l in dataset],
                batch_size=batch_size)
            loader_train = data.CMolRNNLoader(dataset, batch_sampler=sampler_train,
                                              num_workers=num_workers,
                                              k=k, p=p, conditional=cond)
            # no held-out set when training on the full dataset
            loader_test = []
        else:
            cond = data.Delimited()
            N_C = 2

            with open(file_name) as f:
                dataset = data.Lambda(f.readlines(), lambda _x: _x.strip('\n').strip('\r'))

            # get dataset: the last tab-delimited field stores the fold id
            def _filter(_line, _i):
                return int(_line.split('\t')[-1]) == _i

            db_train = data.Lambda(data.Filter(dataset,
                                               fn=lambda _x: not _filter(_x, fold_id)),
                                   fn=lambda _x: _x[:-2])
            db_test = data.Lambda(data.Filter(dataset,
                                              fn=lambda _x: _filter(_x, fold_id)),
                                  fn=lambda _x: _x[:-2])

            # get loader for test set
            loader_test = data.CMolRNNLoader(db_test, shuffle=True,
                                             num_workers=num_workers,
                                             k=k, p=p, conditional=cond,
                                             batch_size=batch_size_test)
            # get loader for training set
            loader_train = data.CMolRNNLoader(db_train, shuffle=True,
                                              num_workers=num_workers,
                                              k=k, p=p, conditional=cond,
                                              batch_size=batch_size)
    else:
        if cond_type != 'kinase':
            if cond_type == 'scaffold':
                cond = data.SparseFP(num_scaffolds)
                N_C = num_scaffolds
            elif cond_type == 'prop':
                cond = data.Delimited()
                N_C = 2
            else:
                raise ValueError('unsupported cond_type: {}'.format(cond_type))

            with open(file_name) as f:
                dataset = data.Lambda(f.readlines(), lambda _x: _x.strip('\n').strip('\r'))

            # get dataset
            db_train = data.KFold(dataset, k=num_folds, fold_id=fold_id, is_train=True)
            db_test = data.KFold(dataset, k=num_folds, fold_id=fold_id, is_train=False)

            # get sampler and loader for training set
            sampler_train = data.BalancedSampler(
                cost=[len(l.split('\t')[0]) for l in db_train],
                batch_size=batch_size)
            loader_train = data.CMolRNNLoader(db_train, batch_sampler=sampler_train,
                                              num_workers=num_workers,
                                              k=k, p=p, conditional=cond)

            # get sampler and loader for test set
            # (fixed: the original read len(l.split('\t'[0])), which measured
            # the whole line instead of the SMILES field)
            sampler_test = data.BalancedSampler(
                cost=[len(l.split('\t')[0]) for l in db_test],
                batch_size=batch_size_test)
            loader_test = data.CMolRNNLoader(db_test, batch_sampler=sampler_test,
                                             num_workers=num_workers,
                                             k=k, p=p, conditional=cond)
        else:
            cond = data.Delimited()
            N_C = 2

            with open(file_name) as f:
                dataset = data.Lambda(f.readlines(), lambda _x: _x.strip('\n').strip('\r'))

            # get dataset: the last tab-delimited field stores the fold id
            def _filter(_line, _i):
                return int(_line.split('\t')[-1]) == _i

            db_train = data.Lambda(data.Filter(dataset,
                                               fn=lambda _x: not _filter(_x, fold_id)),
                                   fn=lambda _x: _x[:-2])
            db_test = data.Lambda(data.Filter(dataset,
                                              fn=lambda _x: _filter(_x, fold_id)),
                                  fn=lambda _x: _x[:-2])

            # get loader for training set
            loader_train = data.CMolRNNLoader(db_train, shuffle=True,
                                              num_workers=num_workers,
                                              k=k, p=p, conditional=cond,
                                              batch_size=batch_size)
            # get loader for test set
            loader_test = data.CMolRNNLoader(db_test, shuffle=True,
                                             num_workers=num_workers,
                                             k=k, p=p, conditional=cond,
                                             batch_size=batch_size_test)

    # get iterators
    it_train, it_test = iter(loader_train), iter(loader_test)

    # build model
    if not is_continuous:
        configs = {'N_C': N_C,
                   'F_e': F_e, 'F_h': F_h, 'F_skip': F_skip, 'F_c': F_c,
                   'Fh_policy': Fh_policy,
                   'activation': activation,
                   'rename': True, 'N_rnn': N_rnn}
        with open(os.path.join(ckpt_dir, 'configs.json'), 'w') as f:
            json.dump(configs, f)
    else:
        with open(os.path.join(ckpt_dir, 'configs.json')) as f:
            configs = json.load(f)

    model = _get_model(configs)

    ctx = [mx.gpu(i) for i in gpu_ids]
    model.collect_params().initialize(mx.init.Xavier(), force_reinit=True, ctx=ctx)
    if not is_continuous:
        if cond_type == 'kinase':
            # fine-tuning: warm-start from the backed-up pre-trained checkpoint
            model.load_params(os.path.join(ckpt_dir, 'ckpt.params.bk'),
                              ctx=ctx, allow_missing=True)
    else:
        model.load_params(os.path.join(ckpt_dir, 'ckpt.params'), ctx=ctx)

    # construct optimizer
    opt = mx.optimizer.Adam(learning_rate=lr, clip_gradient=clip_grad)
    trainer = gluon.Trainer(model.collect_params(), opt)
    if is_continuous:
        trainer.load_states(os.path.join(ckpt_dir, 'trainer.status'))

    if not is_continuous:
        t0 = time.time()
        global_counter = 0
    else:
        # recover the step counter and elapsed time (in minutes) from the log
        with open(os.path.join(ckpt_dir, 'log.out')) as f:
            records = f.readlines()
        if records[-1] != 'Training finished\n':
            final_record = records[-1]
        else:
            final_record = records[-2]
        count, t_final = int(final_record.split('\t')[0]), float(final_record.split('\t')[1])
        t0 = time.time() - t_final * 60
        global_counter = count

    with open(os.path.join(ckpt_dir, 'log.out'),
              mode='w' if not is_continuous else 'a') as f:
        if not is_continuous:
            # fixed: the time column is recorded in minutes, not hours
            f.write('step\ttime(min)\tloss\tlr\n')
        while True:
            global_counter += 1

            try:
                inputs = [next(it_train) for _ in range(len(gpu_ids))]
            except StopIteration:
                it_train = iter(loader_train)
                inputs = [next(it_train) for _ in range(len(gpu_ids))]

            # move to gpu
            inputs = [data.CMolRNNLoader.from_numpy_to_tensor(input_i, j)
                      for j, input_i in zip(gpu_ids, inputs)]

            with autograd.record():
                loss = [(model(*input_i)).as_in_context(mx.gpu(gpu_ids[0]))
                        for input_i in inputs]
                loss = sum(loss) / len(gpu_ids)
                loss.backward()

            nd.waitall()
            gc.collect()

            trainer.step(batch_size=1)

            if global_counter % decay_step == 0:
                trainer.set_learning_rate(trainer.learning_rate * (1.0 - decay))

            if global_counter % summary_step == 0:
                if is_full:
                    # loss was already averaged across GPUs above
                    # (fixed: the original divided by len(gpu_ids) a second time)
                    loss = np.asscalar(loss.asnumpy())
                else:
                    del loss, inputs
                    gc.collect()

                    # evaluate on a held-out batch instead of the training loss
                    try:
                        inputs = [next(it_test) for _ in range(len(gpu_ids))]
                    except StopIteration:
                        it_test = iter(loader_test)
                        inputs = [next(it_test) for _ in range(len(gpu_ids))]

                    with autograd.predict_mode():
                        # move to gpu
                        inputs = [data.CMolRNNLoader.from_numpy_to_tensor(input_i, j)
                                  for j, input_i in zip(gpu_ids, inputs)]
                        loss = [(model(*input_i)).as_in_context(mx.gpu(gpu_ids[0]))
                                for input_i in inputs]
                        loss = np.asscalar((sum(loss) / len(gpu_ids)).asnumpy())

                model.save_params(os.path.join(ckpt_dir, 'ckpt.params'))
                trainer.save_states(os.path.join(ckpt_dir, 'trainer.status'))

                f.write('{}\t{}\t{}\t{}\n'.format(global_counter,
                                                  float(time.time() - t0) / 60,
                                                  loss, trainer.learning_rate))
                f.flush()

            del loss, inputs
            gc.collect()

            if global_counter >= iterations:
                break

        # save before exit
        model.save_params(os.path.join(ckpt_dir, 'ckpt.params'))
        trainer.save_states(os.path.join(ckpt_dir, 'trainer.status'))
        f.write('Training finished\n')
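

# Example entry point (illustration only): launching conditional training on
# the default scaffold-based ChEMBL split. The dataset path, checkpoint
# directory, fold and GPU ids below are assumptions, not values from the
# original project.
if __name__ == '__main__':
    _engine_cond(cond_type='scaffold',
                 file_name='datasets/ChEMBL_scaffold.txt',
                 ckpt_dir='ckpt/scaffold',
                 num_folds=5, fold_id=0,
                 gpu_ids=(0, 1),
                 iterations=30000)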