def test_hypergraph_mask_gen(self):
    molecules = True
    grammar_cache = 'tmp.pickle'
    grammar = 'hypergraph:' + grammar_cache
    # create a grammar cache inferred from our sample molecules
    g = HypergraphGrammar(cache_file=grammar_cache)
    if os.path.isfile(g.cache_file):
        os.remove(g.cache_file)
    g.strings_to_actions(get_zinc_smiles(5))
    mask_gen1 = get_codec(molecules, grammar, 30).mask_gen
    mask_gen2 = get_codec(molecules, grammar, 30).mask_gen
    mask_gen1.priors = False
    mask_gen2.priors = True
    policy1 = SoftmaxRandomSamplePolicy(
        bias=mask_gen1.grammar.get_log_frequencies())
    policy2 = SoftmaxRandomSamplePolicy()
    lp = []
    for mg in [mask_gen1, mask_gen2]:
        mg.reset()
        mg.apply_action([None])
        logit_priors = mg.action_prior_logits()  # that includes any priors
        lp.append(
            torch.from_numpy(logit_priors).to(device=device,
                                              dtype=torch.float32))
    dummy_model_output = torch.ones_like(lp[0])
    eff_logits = []
    for this_lp, policy in zip(lp, [policy1, policy2]):
        # combine each mask generator's prior logits with the model output;
        # policy1 then adds its frequency bias, policy2 relies on mask_gen2's priors
        eff_logits.append(policy.effective_logits(this_lp + dummy_model_output))
    # both routes should produce the same effective logits
    assert torch.max((eff_logits[0] - eff_logits[1]).abs()) < 1e-6
def test_classic_mask_gen_equations(self):
    molecules = False
    grammar = 'classic'
    codec = get_codec(molecules, grammar, max_seq_length)
    actions = run_random_gen(codec.mask_gen)
    all_eqs = codec.actions_to_strings(actions)
    # the only way of testing correctness we have here is whether
    # the equations parse correctly
    parsed_eqs = codec.strings_to_actions(all_eqs)
def test_classic_mask_gen_molecules(self):
    molecules = True
    grammar = 'classic'
    codec = get_codec(molecules, grammar, max_seq_length)
    actions = run_random_gen(codec.mask_gen)
    new_smiles = codec.actions_to_strings(actions)
    # the SMILES produced by that grammar are NOT guaranteed to be valid,
    # so we can only check that the decoding completes without errors
    # and is grammatically valid
    parsed_smiles = codec.strings_to_actions(new_smiles)
def get_node_decoder(grammar,
                     max_seq_length=15,
                     drop_rate=0.0,
                     decoder_type='attn',
                     rule_policy=None,
                     reward_fun=lambda x: -1 * np.ones(len(x)),
                     batch_size=None,
                     priors='conditional',
                     bins=10):
    codec = get_codec(True, grammar, max_seq_length)
    assert 'hypergraph' in grammar, \
        "Only the hypergraph grammar can be used with attn_graph decoder type"
    if 'attn' in decoder_type:
        model_type = 'transformer'
    elif 'rnn' in decoder_type:
        model_type = 'rnn'
    else:
        raise NotImplementedError('Unknown decoder type: ' + str(decoder_type))
    if 'distr' in decoder_type:
        if 'softmax' in decoder_type:
            output_type = 'distributions_softmax'
        else:
            output_type = 'distributions_thompson'
    else:
        output_type = 'values'
    model = get_graph_model(codec, drop_rate, model_type, output_type,
                            num_bins=bins)
    # encoder = GraphEncoder(grammar=codec.grammar,
    #                        d_model=512,
    #                        drop_rate=drop_rate,
    #                        model_type=model_type)
    # model = MultipleOutputHead(model=encoder,
    #                            output_spec={'node': 1,  # to be used to select next node to expand
    #                                         'action': codec.feature_len()},  # to select the action for chosen node
    #                            drop_rate=drop_rate)
    # don't support using this model in VAE-style models yet
    mask_gen = HypergraphMaskGenerator(max_len=max_seq_length,
                                       grammar=codec.grammar)
    mask_gen.priors = priors
    if rule_policy is None:
        rule_policy = SoftmaxRandomSamplePolicy()
    stepper = GraphDecoderWithNodeSelection(model, rule_policy=rule_policy)
    env = GraphEnvironment(mask_gen, reward_fun=reward_fun,
                           batch_size=batch_size)
    decoder = DecoderWithEnvironmentNew(stepper, env)
    return decoder, stepper
def get_node_decoder(grammar,
                     max_seq_length=15,
                     drop_rate=0.0,
                     decoder_type='attn',
                     rule_policy=None,
                     reward_fun=lambda x: -1 * np.ones(len(x)),
                     batch_size=None,
                     priors='conditional',
                     bins=10):
    codec = get_codec(True, grammar, max_seq_length)
    assert 'hypergraph' in grammar, \
        "Only the hypergraph grammar can be used with attn_graph decoder type"
    if 'attn' in decoder_type:
        model_type = 'transformer'
    elif 'rnn' in decoder_type:
        model_type = 'rnn'
    elif 'conditional' in decoder_type:
        if 'sparse' in decoder_type:
            model_type = 'conditional_sparse'
        else:
            model_type = 'conditional'
    else:
        raise NotImplementedError('Unknown decoder type: ' + str(decoder_type))
    if 'distr' in decoder_type:
        if 'softmax' in decoder_type:
            output_type = 'distributions_softmax'
        else:
            output_type = 'distributions_thompson'
    else:
        output_type = 'values'
    model = get_graph_model(codec, drop_rate, model_type, output_type,
                            num_bins=bins)
    if model_type == 'conditional_sparse':
        priors = 'term_dist_only'
    mask_gen = HypergraphMaskGenerator(max_len=max_seq_length,
                                       grammar=codec.grammar,
                                       priors=priors)
    if rule_policy is None:
        rule_policy = SoftmaxRandomSamplePolicySparse() \
            if 'sparse' in decoder_type else SoftmaxRandomSamplePolicy()
    stepper_type = GraphDecoderWithNodeSelectionSparse \
        if 'sparse' in decoder_type else GraphDecoderWithNodeSelection
    stepper = stepper_type(model, rule_policy=rule_policy)
    env = GraphEnvironment(mask_gen, reward_fun=reward_fun,
                           batch_size=batch_size)
    decoder = DecoderWithEnvironmentNew(stepper, env)
    return decoder, stepper
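# Usage sketch (not part of the original module): wires get_node_decoder up for
# a short rollout. The grammar cache name 'tmp.pickle' and the zero reward are
# illustrative assumptions; any cache built via HypergraphGrammar, as in the
# tests above, would work.
def _example_node_decoder_rollout():
    import numpy as np
    decoder, stepper = get_node_decoder(
        grammar='hypergraph:tmp.pickle',        # assumed pre-built grammar cache
        max_seq_length=30,
        decoder_type='attn',                    # routes to the transformer graph model
        reward_fun=lambda x: np.zeros(len(x)),  # neutral placeholder reward
        batch_size=4,
        priors='conditional')
    # decoders are callable and roll out a whole batch, as in the tests below
    return decoder()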
def test_hypergraph_mask_gen(self):
    molecules = True
    grammar_cache = 'tmp.pickle'
    grammar = 'hypergraph:' + grammar_cache
    # create a grammar cache inferred from our sample molecules
    g = HypergraphGrammar(cache_file=grammar_cache)
    if os.path.isfile(g.cache_file):
        os.remove(g.cache_file)
    g.strings_to_actions(smiles)
    codec = get_codec(molecules, grammar, max_seq_length)
    self.generate_and_validate(codec)
def test_graph_encoder_with_head(self):
    codec = get_codec(molecules=True,
                      grammar='hypergraph:' + tmp_file,
                      max_seq_length=max_seq_length)
    encoder = GraphEncoder(grammar=gi.grammar, d_model=512, drop_rate=0.0)
    mol_graphs = [HyperGraph.from_mol(mol) for mol in get_zinc_molecules(5)]
    model = MultipleOutputHead(
        model=encoder,
        output_spec={'node': 1,  # to be used to select next node to expand
                     'action': codec.feature_len()},  # to select the action for chosen node
        drop_rate=0.1).to(device)
    out = model(mol_graphs)
def get_vae(molecules=True,
            grammar=True,
            weights_file=None,
            epsilon_std=1,
            decoder_type='step',
            **kwargs):
    model_args = get_model_args(molecules=molecules, grammar=grammar)
    for key, value in kwargs.items():
        if key in model_args:
            model_args[key] = value
    sample_z = model_args.pop('sample_z')

    encoder_args = ['feature_len', 'max_seq_length', 'cnn_encoder_params',
                    'drop_rate', 'encoder_type', 'rnn_encoder_hidden_n']
    encoder = get_encoder(**{key: value for key, value in model_args.items()
                             if key in encoder_args})

    decoder_args = ['z_size', 'decoder_hidden_n', 'feature_len',
                    'max_seq_length', 'drop_rate', 'batch_size']
    decoder, _ = get_decoder(molecules,
                             grammar,
                             decoder_type=decoder_type,
                             **{key: value for key, value in model_args.items()
                                if key in decoder_args})

    model = generative_playground.models.heads.vae.VariationalAutoEncoderHead(
        encoder=encoder,
        decoder=decoder,
        sample_z=sample_z,
        epsilon_std=epsilon_std,
        z_size=model_args['z_size'])
    if weights_file is not None:
        model.load(weights_file)

    settings = get_settings(molecules=molecules, grammar=grammar)
    codec = get_codec(molecules, grammar,
                      max_seq_length=settings['max_seq_length'])
    # codec.set_model(model)  # TODO: do we ever use this?
    return model, codec
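# Hedged sketch (assumption-labelled): builds a VAE with the helper above and
# shows how **kwargs overrides flow into get_model_args keys. decoder_type
# 'action' and drop_rate 0.2 are arbitrary illustrative choices.
def _example_build_vae():
    model, codec = get_vae(molecules=True,
                           grammar=True,
                           decoder_type='action',
                           drop_rate=0.2)  # overrides get_model_args' default of 0.5
    return model, codec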
def generic_decoder_test(self, decoder_type, grammar):
    codec = get_codec(molecules=True, grammar=grammar,
                      max_seq_length=max_seq_length)
    decoder, pre_decoder = get_decoder(decoder_type=decoder_type,
                                       max_seq_length=max_seq_length,
                                       grammar=grammar,
                                       feature_len=codec.feature_len(),
                                       z_size=z_size,
                                       batch_size=batch_size)
    # returns all sorts of things: out_actions_all, out_logits_all,
    # out_rewards_all, out_terminals_all, (info[0], to_pytorch(info[1]))
    out = decoder()
    all_sum = torch.sum(out['logits'])
    all_sum.backward()
    return all_sum
def get_thompson_globals(
        num_bins=50,  # TODO: replace with a Value Distribution object
        reward_fun_=None,
        grammar_cache='hyper_grammar_guac_10k_with_clique_collapse.pickle',  # 'hyper_grammar.pickle'
        max_seq_length=60,
        decay=0.95,
        updates_to_refresh=10):
    grammar_name = 'hypergraph:' + grammar_cache
    codec = get_codec(True, grammar_name, max_seq_length)
    reward_proc = RewardProcessor(num_bins)
    rule_choice_repo_factory = lambda x: RuleChoiceRepository(
        reward_proc=reward_proc, mask=x, decay=decay)
    exp_repo_ = ExperienceRepository(
        grammar=codec.grammar,
        reward_preprocessor=reward_proc,
        decay=decay,
        conditional_keys=list(codec.grammar.conditional_frequencies.keys()),
        rule_choice_repo_factory=rule_choice_repo_factory)
    # TODO: weave this into the nodes to do node-level action averages as regularization
    local_exp_repo_factory = lambda graph: ExperienceRepository(
        grammar=codec.grammar,
        reward_preprocessor=reward_proc,
        decay=decay,
        conditional_keys=[i for i in range(len(graph))],
        rule_choice_repo_factory=rule_choice_repo_factory)
    # named globals_ to avoid shadowing the builtin
    globals_ = GlobalParametersThompson(
        codec.grammar,
        max_seq_length,
        exp_repo_,
        decay=decay,
        updates_to_refresh=updates_to_refresh,
        reward_fun=reward_fun_,
        reward_proc=reward_proc,
        rule_choice_repo_factory=rule_choice_repo_factory,
        state_store=None)
    return globals_
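# Sketch of the reward_fun_ contract assumed by get_thompson_globals: a callable
# mapping a list of SMILES strings to an array of rewards. The zero reward and
# the default grammar cache are placeholders.
def _example_thompson_globals():
    import numpy as np
    return get_thompson_globals(num_bins=50,
                                reward_fun_=lambda smiles: np.zeros(len(smiles)),
                                max_seq_length=60,
                                decay=0.95,
                                updates_to_refresh=10)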
def __init__(
        self,
        batch_size=20,
        reward_fun_=None,
        grammar_cache='hyper_grammar_guac_10k_with_clique_collapse.pickle',  # 'hyper_grammar.pickle'
        max_depth=60,
        lr=0.05,
        grad_clip=5,
        entropy_weight=3,
        decay=None,
        num_bins=None,
        updates_to_refresh=None,
        plotter=None,
        degenerate=False):  # use a null model if True
    grammar_name = 'hypergraph:' + grammar_cache
    codec = get_codec(True, grammar_name, max_depth)
    super().__init__(codec.grammar, max_depth, reward_fun_, {}, plotter=plotter)
    if not degenerate:
        # create optimizer factory
        optimizer_factory = optimizer_factory_gen(lr, grad_clip)
        # create model
        model = CondtionalProbabilityModel(codec.grammar).to(device)
        # create loss object
        loss_type = 'advantage_record'
        loss_fun = PolicyGradientLoss(loss_type, entropy_wgt=entropy_weight)
        self.model = model
        self.process_reward = MCTSRewardProcessor(loss_fun, model,
                                                  optimizer_factory, batch_size)
    else:
        self.model = PassthroughModel()
        self.process_reward = lambda reward, log_ps, actions, params: None
    self.decay = decay
    self.reward_proc = RewardProcessor(num_bins)
    self.updates_to_refresh = updates_to_refresh
def get_model_args(molecules,
                   grammar,
                   drop_rate=0.5,
                   sample_z=False,
                   encoder_type='rnn'):
    settings = get_settings(molecules, grammar)
    codec = get_codec(molecules, grammar, settings['max_seq_length'])
    model_args = {
        'z_size': settings['z_size'],
        'decoder_hidden_n': settings['decoder_hidden_n'],
        'feature_len': codec.feature_len(),
        'max_seq_length': settings['max_seq_length'],
        'cnn_encoder_params': settings['cnn_encoder_params'],
        'drop_rate': drop_rate,
        'sample_z': sample_z,
        'encoder_type': encoder_type,
        'rnn_encoder_hidden_n': settings['rnn_encoder_hidden_n']
    }
    return model_args
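# Minimal illustration (not from the source) of how get_model_args output is
# consumed: get_vae above filters this dict into encoder and decoder kwargs.
# The key subset here is a shortened illustrative choice.
def _example_model_args():
    args = get_model_args(molecules=True, grammar=True, drop_rate=0.3)
    encoder_keys = {'feature_len', 'max_seq_length', 'drop_rate'}
    encoder_kwargs = {k: v for k, v in args.items() if k in encoder_keys}
    return encoder_kwargs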
def __init__(self,
             molecules=True,
             grammar=True,
             reward_fun=None,
             batch_size=1,
             max_steps=None,
             save_dataset=None):
    settings = get_settings(molecules, grammar)
    self.codec = get_codec(molecules, grammar, settings['max_seq_length'])
    self.action_dim = self.codec.feature_len()
    self.state_dim = self.action_dim
    if max_steps is None:
        self._max_episode_steps = settings['max_seq_length']
    else:
        self._max_episode_steps = max_steps
    self.reward_fun = reward_fun
    self.batch_size = batch_size
    self.save_dataset = save_dataset
    self.smiles = None
    self.seq_len = None
    self.valid = None
    self.actions = None
    self.done_rewards = None
    self.reset()
def train_policy_gradient(molecules=True,
                          grammar=True,
                          smiles_source='ZINC',
                          EPOCHS=None,
                          BATCH_SIZE=None,
                          reward_fun_on=None,
                          reward_fun_off=None,
                          max_steps=277,
                          lr_on=2e-4,
                          lr_discrim=1e-4,
                          lr_schedule=None,
                          discrim_wt=2,
                          p_thresh=0.5,
                          drop_rate=0.0,
                          plot_ignore_initial=0,
                          randomize_reward=False,
                          save_file_root_name=None,
                          reward_sm=0.0,
                          preload_file_root_name=None,
                          anchor_file=None,
                          anchor_weight=0.0,
                          decoder_type='action',
                          plot_prefix='',
                          dashboard='policy gradient',
                          smiles_save_file=None,
                          on_policy_loss_type='best',
                          priors=True,
                          node_temperature_schedule=lambda x: 1.0,
                          rule_temperature_schedule=lambda x: 1.0,
                          eps=0.0,
                          half_float=False,
                          extra_repetition_penalty=0.0,
                          entropy_wgt=1.0):
    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../../'

    def full_path(x):
        return os.path.realpath(root_location + 'pretrained/' + x)

    zinc_data = get_smiles_from_database(source=smiles_source)
    zinc_set = set(zinc_data)
    lookbacks = [BATCH_SIZE, 10 * BATCH_SIZE, 100 * BATCH_SIZE]
    # need to have something there to begin with, else the DataLoader constructor barfs
    history_data = [deque(['O'], maxlen=lb) for lb in lookbacks]

    if save_file_root_name is not None:
        gen_save_file = save_file_root_name + '_gen.h5'
        disc_save_file = save_file_root_name + '_disc.h5'
    if preload_file_root_name is not None:
        gen_preload_file = preload_file_root_name + '_gen.h5'
        disc_preload_file = preload_file_root_name + '_disc.h5'

    settings = get_settings(molecules=molecules, grammar=grammar)
    codec = get_codec(molecules, grammar, settings['max_seq_length'])
    discrim_model = GraphDiscriminator(codec.grammar, drop_rate=drop_rate)
    if False and preload_file_root_name is not None:  # discriminator preloading currently disabled
        try:
            preload_path = full_path(disc_preload_file)
            discrim_model.load_state_dict(torch.load(preload_path), strict=False)
            print('Discriminator weights loaded successfully!')
        except Exception as e:
            print('failed to load discriminator weights ' + str(e))

    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE

    alt_reward_calc = AdjustedRewardCalculator(reward_fun_on,
                                               zinc_set,
                                               lookbacks,
                                               extra_repetition_penalty,
                                               discrim_wt,
                                               discrim_model=None)
    reward_fun = lambda x: adj_reward(discrim_wt,
                                      discrim_model,
                                      reward_fun_on,
                                      zinc_set,
                                      history_data,
                                      extra_repetition_penalty,
                                      x,
                                      alt_calc=alt_reward_calc)

    task = SequenceGenerationTask(molecules=molecules,
                                  grammar=grammar,
                                  reward_fun=reward_fun,
                                  batch_size=BATCH_SIZE,
                                  max_steps=max_steps,
                                  save_dataset=None)

    node_policy = SoftmaxRandomSamplePolicy(temperature=torch.tensor(1.0), eps=eps)
    rule_policy = SoftmaxRandomSamplePolicy(temperature=torch.tensor(2.0), eps=eps)

    model = get_decoder(molecules,
                        grammar,
                        z_size=settings['z_size'],
                        decoder_hidden_n=200,
                        feature_len=codec.feature_len(),
                        max_seq_length=max_steps,
                        drop_rate=drop_rate,
                        batch_size=BATCH_SIZE,
                        decoder_type=decoder_type,
                        reward_fun=reward_fun,
                        task=task,
                        node_policy=node_policy,
                        rule_policy=rule_policy,
                        priors=priors)[0]

    if preload_file_root_name is not None:
        try:
            preload_path = full_path(gen_preload_file)
            model.load_state_dict(torch.load(preload_path, map_location='cpu'),
                                  strict=False)
            print('Generator weights loaded successfully!')
        except Exception as e:
            print('failed to load generator weights ' + str(e))

    anchor_model = None

    # construct the callbacks that feed generated molecules to the discriminator
    def make_callback(data):
        def hc(inputs, model, outputs, loss_fn, loss):
            graphs = outputs['graphs']
            smiles = [g.to_smiles() for g in graphs]
            # only store unique instances of molecules so discriminator
            # can't guess on frequency
            for s in smiles:
                if s not in data:
                    data.append(s)
        return hc

    class TemperatureCallback:
        def __init__(self, policy, temperature_function):
            self.policy = policy
            self.counter = 0
            self.temp_fun = temperature_function

        def __call__(self, inputs, model, outputs, loss_fn, loss):
            self.counter += 1
            target_temp = self.temp_fun(self.counter)
            self.policy.set_temperature(target_temp)

    def get_rl_fitter(model,
                      loss_obj,
                      train_gen,
                      save_path,
                      fit_plot_prefix='',
                      model_process_fun=None,
                      lr=None,
                      lr_schedule=lr_schedule,
                      extra_callbacks=[],
                      loss_display_cap=float('inf'),
                      anchor_model=None,
                      anchor_weight=0):
        nice_params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = optim.Adam(nice_params, lr=lr, eps=1e-4)
        if lr_schedule is None:
            lr_schedule = lambda x: 1.0
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_schedule)
        if dashboard is not None:
            metric_monitor = MetricPlotter(
                plot_prefix=fit_plot_prefix,
                loss_display_cap=loss_display_cap,
                dashboard_name=dashboard,
                plot_ignore_initial=plot_ignore_initial,
                process_model_fun=model_process_fun,
                smooth_weight=reward_sm,
                save_location=os.path.dirname(save_path))
        else:
            metric_monitor = None
        checkpointer = Checkpointer(valid_batches_to_checkpoint=10,
                                    save_path=save_path,
                                    save_always=True)
        fitter = fit_rl(train_gen=train_gen,
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        epochs=EPOCHS,
                        loss_fn=loss_obj,
                        grad_clip=5,
                        half_float=half_float,
                        anchor_model=anchor_model,
                        anchor_weight=anchor_weight,
                        callbacks=[metric_monitor, checkpointer] + extra_callbacks)
        return fitter

    class GeneratorToIterable:
        def __init__(self, gen):
            self.gen = gen
            # we assume the generator is finite
            self.len = 0
            for _ in gen():
                self.len += 1

        def __len__(self):
            return self.len

        def __iter__(self):
            return self.gen()

    def my_gen():
        for _ in range(1000):
            yield to_gpu(torch.zeros(BATCH_SIZE, settings['z_size']))

    # the on-policy fitter
    gen_extra_callbacks = [make_callback(d) for d in history_data]

    if smiles_save_file is not None:
        smiles_save_path = os.path.realpath(root_location + 'pretrained/' +
                                            smiles_save_file)
        gen_extra_callbacks.append(MoleculeSaver(smiles_save_path, gzip=True))
        print('Saving SMILES to {}'.format(smiles_save_file))

    if node_temperature_schedule is not None:
        gen_extra_callbacks.append(
            TemperatureCallback(node_policy, node_temperature_schedule))
    if rule_temperature_schedule is not None:
        gen_extra_callbacks.append(
            TemperatureCallback(rule_policy, rule_temperature_schedule))

    fitter1 = get_rl_fitter(
        model,
        PolicyGradientLoss(on_policy_loss_type,
                           entropy_wgt=entropy_wgt),  # last_reward_wgt=reward_sm),
        GeneratorToIterable(my_gen),
        full_path(gen_save_file),
        plot_prefix + 'on-policy',
        model_process_fun=model_process_fun,  # assumed to come from module scope
        lr=lr_on,
        extra_callbacks=gen_extra_callbacks,
        anchor_model=anchor_model,
        anchor_weight=anchor_weight)

    # get existing molecule data to add to training
    pre_dataset = EvenlyBlendedDataset(2 * [history_data[0]] + history_data[1:],
                                       labels=False)  # a blend of 3 time horizons
    dataset = EvenlyBlendedDataset([pre_dataset, zinc_data], labels=True)
    discrim_loader = DataLoader(dataset, shuffle=True, batch_size=50)

    class MyLoss(nn.Module):
        def __init__(self):
            super().__init__()
            self.celoss = nn.CrossEntropyLoss()

        def forward(self, x):
            # tmp = discriminator_reward_mult(x['smiles'])
            # tmp2 = F.softmax(x['p_zinc'], dim=1)[:, 1].detach().cpu().numpy()
            # assert np.max(np.abs(tmp - tmp2)) < 1e-6
            return self.celoss(x['p_zinc'].to(device),
                               x['dataset_index'].to(device))

    fitter2 = get_rl_fitter(discrim_model,
                            MyLoss(),
                            IterableTransform(
                                discrim_loader,
                                lambda x: {'smiles': x['X'],
                                           'dataset_index': x['dataset_index']}),
                            full_path(disc_save_file),
                            plot_prefix + ' discriminator',
                            lr=lr_discrim,
                            model_process_fun=None)

    def on_policy_gen(fitter, model):
        while True:
            # model.policy = SoftmaxRandomSamplePolicy()  # bias=codec.grammar.get_log_frequencies())
            yield next(fitter)

    return model, fitter1, fitter2  # , on_policy_gen(fitter1, model)
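# Assumed driver (not from the source): one way to step the fitters returned by
# train_policy_gradient. The grammar cache, reward function, and file root name
# are placeholders; note save_file_root_name must be set, since the save paths
# are derived from it.
def _example_run_policy_gradient():
    import numpy as np
    model, fitter1, fitter2 = train_policy_gradient(
        molecules=True,
        grammar='hypergraph:tmp.pickle',           # assumed grammar cache
        EPOCHS=1,
        BATCH_SIZE=8,
        reward_fun_on=lambda x: np.zeros(len(x)),  # placeholder reward
        save_file_root_name='pg_demo',
        decoder_type='graph_conditional')
    next(fitter1)  # one on-policy generator batch
    next(fitter2)  # one discriminator batch
    return model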
def __init__(self,
             grammar,
             smiles_source='ZINC',
             BATCH_SIZE=None,
             reward_fun=None,
             max_steps=277,
             num_batches=100,
             lr=2e-4,
             entropy_wgt=1.0,
             lr_schedule=None,
             root_name=None,
             preload_file_root_name=None,
             save_location=None,
             plot_metrics=True,
             metric_smooth=0.0,
             decoder_type='graph_conditional',
             on_policy_loss_type='advantage_record',
             priors='conditional',
             rule_temperature_schedule=None,
             eps=0.0,
             half_float=False,
             extra_repetition_penalty=0.0):
    self.num_batches = num_batches
    self.save_location = save_location
    self.molecule_saver = MoleculeSaver(None, gzip=True)
    self.metric_monitor = None  # to be populated by self.set_root_name(...)

    zinc_data = get_smiles_from_database(source=smiles_source)
    zinc_set = set(zinc_data)
    lookbacks = [BATCH_SIZE, 10 * BATCH_SIZE, 100 * BATCH_SIZE]
    history_data = [deque(['O'], maxlen=lb) for lb in lookbacks]

    if root_name is not None:
        pass  # gen_save_file = root_name + '_gen.h5'
    if preload_file_root_name is not None:
        gen_preload_file = preload_file_root_name + '_gen.h5'

    settings = get_settings(molecules=True, grammar=grammar)
    codec = get_codec(True, grammar, settings['max_seq_length'])
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE

    self.alt_reward_calc = AdjustedRewardCalculator(reward_fun,
                                                    zinc_set,
                                                    lookbacks,
                                                    extra_repetition_penalty,
                                                    0,
                                                    discrim_model=None)
    self.reward_fun = lambda x: adj_reward(0,
                                           None,
                                           reward_fun,
                                           zinc_set,
                                           history_data,
                                           extra_repetition_penalty,
                                           x,
                                           alt_calc=self.alt_reward_calc)

    task = SequenceGenerationTask(molecules=True,
                                  grammar=grammar,
                                  reward_fun=self.alt_reward_calc,
                                  batch_size=BATCH_SIZE,
                                  max_steps=max_steps,
                                  save_dataset=None)

    if 'sparse' in decoder_type:
        rule_policy = SoftmaxRandomSamplePolicySparse()
    else:
        rule_policy = SoftmaxRandomSamplePolicy(temperature=torch.tensor(1.0),
                                                eps=eps)

    # TODO: strip this down to the normal call
    self.model = get_decoder(True,
                             grammar,
                             z_size=settings['z_size'],
                             decoder_hidden_n=200,
                             feature_len=codec.feature_len(),
                             max_seq_length=max_steps,
                             batch_size=BATCH_SIZE,
                             decoder_type=decoder_type,
                             reward_fun=self.alt_reward_calc,
                             task=task,
                             rule_policy=rule_policy,
                             priors=priors)[0]

    if preload_file_root_name is not None:
        try:
            preload_path = os.path.realpath(save_location + gen_preload_file)
            self.model.load_state_dict(torch.load(preload_path,
                                                  map_location='cpu'),
                                       strict=False)
            print('Generator weights loaded successfully!')
        except Exception as e:
            print('failed to load generator weights ' + str(e))

    # construct the callbacks that feed generated molecules to the history buffers
    def make_callback(data):
        def hc(inputs, model, outputs, loss_fn, loss):
            graphs = outputs['graphs']
            smiles = [g.to_smiles() for g in graphs]
            # only store unique instances of molecules so discriminator
            # can't guess on frequency
            for s in smiles:
                if s not in data:
                    data.append(s)
        return hc

    if plot_metrics:
        # TODO: save_file for rewards data goes here?
        self.metric_monitor_factory = lambda name: MetricPlotter(
            plot_prefix='',
            loss_display_cap=float('inf'),
            dashboard_name=name,
            save_location=save_location,
            process_model_fun=model_process_fun,
            smooth_weight=metric_smooth)
    else:
        self.metric_monitor_factory = lambda x: None

    # the on-policy fitter
    gen_extra_callbacks = [make_callback(d) for d in history_data]
    gen_extra_callbacks.append(self.molecule_saver)
    if rule_temperature_schedule is not None:
        gen_extra_callbacks.append(
            TemperatureCallback(rule_policy, rule_temperature_schedule))

    nice_params = filter(lambda p: p.requires_grad, self.model.parameters())
    self.optimizer = optim.Adam(nice_params, lr=lr, eps=1e-4)
    if lr_schedule is None:
        lr_schedule = lambda x: 1.0
    self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_schedule)
    self.loss = PolicyGradientLoss(on_policy_loss_type, entropy_wgt=entropy_wgt)
    self.fitter_factory = lambda: make_fitter(
        BATCH_SIZE,
        settings['z_size'],
        [self.metric_monitor] + gen_extra_callbacks,
        self)
    self.fitter = self.fitter_factory()
    self.set_root_name(root_name)
    print('Runner initialized!')
def get_decoder(
        molecules=True,
        grammar=True,
        z_size=200,
        decoder_hidden_n=200,
        feature_len=12,  # TODO: remove this
        max_seq_length=15,
        drop_rate=0.0,
        decoder_type='step',
        task=None,
        batch_size=None):
    codec = get_codec(molecules, grammar, max_seq_length)
    if decoder_type == 'old':
        stepper = ResettingRNNDecoder(z_size=z_size,
                                      hidden_n=decoder_hidden_n,
                                      feature_len=codec.feature_len(),
                                      max_seq_length=max_seq_length,
                                      steps=max_seq_length,
                                      drop_rate=drop_rate)
        stepper = OneStepDecoderContinuous(stepper)
    elif decoder_type == 'attn_graph':
        assert 'hypergraph' in grammar, \
            "Only the hypergraph grammar can be used with attn_graph decoder type"
        encoder = GraphEncoder(grammar=codec.grammar,
                               d_model=512,
                               drop_rate=drop_rate)
        model = MultipleOutputHead(
            model=encoder,
            output_spec={'node': 1,  # to be used to select next node to expand
                         'action': codec.feature_len()},  # to select the action for chosen node
            drop_rate=drop_rate)
        # don't support using this model in VAE-style scenarios yet
        model.init_encoder_output = lambda x: None
        mask_gen = HypergraphMaskGenerator(max_len=max_seq_length,
                                           grammar=codec.grammar)
        stepper = GraphDecoder(model=model, mask_gen=mask_gen)
    else:
        if decoder_type == 'step':
            stepper = SimpleRNNDecoder(z_size=z_size,
                                       hidden_n=decoder_hidden_n,
                                       feature_len=codec.feature_len(),
                                       max_seq_length=max_seq_length,
                                       drop_rate=drop_rate,
                                       use_last_action=False)
        elif decoder_type == 'action':
            stepper = SimpleRNNDecoder(z_size=z_size,  # + feature_len,
                                       hidden_n=decoder_hidden_n,
                                       feature_len=codec.feature_len(),
                                       max_seq_length=max_seq_length,
                                       drop_rate=drop_rate,
                                       use_last_action=True)
        elif decoder_type == 'action_resnet':
            stepper = ResNetRNNDecoder(z_size=z_size,  # + feature_len,
                                       hidden_n=decoder_hidden_n,
                                       feature_len=codec.feature_len(),
                                       max_seq_length=max_seq_length,
                                       drop_rate=drop_rate,
                                       use_last_action=True)
        elif decoder_type == 'attention':
            stepper = SelfAttentionDecoderStep(num_actions=codec.feature_len(),
                                               max_seq_len=max_seq_length,
                                               drop_rate=drop_rate,
                                               enc_output_size=z_size)
        elif decoder_type == 'random':
            stepper = RandomDecoder(feature_len=codec.feature_len(),
                                    max_seq_length=max_seq_length)
        else:
            raise NotImplementedError('Unknown decoder type: ' + str(decoder_type))
    if grammar is not False:
        # add a masking layer
        mask_gen = get_codec(molecules, grammar, max_seq_length).mask_gen
        stepper = MaskingHead(stepper, mask_gen)
    policy = SoftmaxRandomSamplePolicy()  # bias=codec.grammar.get_log_frequencies())
    decoder = to_gpu(
        SimpleDiscreteDecoderWithEnv(stepper, policy, task=task,
                                     batch_size=batch_size))  # , bypass_actions=True))
    return decoder, stepper
def train_policy_gradient_ppo(molecules=True,
                              grammar=True,
                              smiles_source='ZINC',
                              EPOCHS=None,
                              BATCH_SIZE=None,
                              reward_fun_on=None,
                              reward_fun_off=None,
                              max_steps=277,
                              lr_on=2e-4,
                              lr_discrim=1e-4,
                              discrim_wt=2,
                              p_thresh=0.5,
                              drop_rate=0.0,
                              plot_ignore_initial=0,
                              randomize_reward=False,
                              save_file_root_name=None,
                              reward_sm=0.0,
                              preload_file_root_name=None,
                              anchor_file=None,
                              anchor_weight=0.0,
                              decoder_type='action',
                              plot_prefix='',
                              dashboard='policy gradient',
                              smiles_save_file=None,
                              on_policy_loss_type='best',
                              priors=True,
                              node_temperature_schedule=lambda x: 1.0,
                              rule_temperature_schedule=lambda x: 1.0,
                              eps=0.0,
                              half_float=False,
                              extra_repetition_penalty=0.0):
    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../../'

    def full_path(x):
        return os.path.realpath(root_location + 'pretrained/' + x)

    if save_file_root_name is not None:
        gen_save_file = save_file_root_name + '_gen.h5'
        disc_save_file = save_file_root_name + '_disc.h5'
    if preload_file_root_name is not None:
        gen_preload_file = preload_file_root_name + '_gen.h5'
        disc_preload_file = preload_file_root_name + '_disc.h5'

    settings = get_settings(molecules=molecules, grammar=grammar)
    codec = get_codec(molecules, grammar, settings['max_seq_length'])
    discrim_model = GraphDiscriminator(codec.grammar, drop_rate=drop_rate)
    if False and preload_file_root_name is not None:  # discriminator preloading currently disabled
        try:
            preload_path = full_path(disc_preload_file)
            discrim_model.load_state_dict(torch.load(preload_path), strict=False)
            print('Discriminator weights loaded successfully!')
        except Exception as e:
            print('failed to load discriminator weights ' + str(e))

    zinc_data = get_smiles_from_database(source=smiles_source)
    zinc_set = set(zinc_data)
    lookbacks = [BATCH_SIZE, 10 * BATCH_SIZE, 100 * BATCH_SIZE]
    # need to have something there to begin with, else the DataLoader constructor barfs
    history_data = [deque(['O'], maxlen=lb) for lb in lookbacks]

    def originality_mult(smiles_list):
        out = []
        for s in smiles_list:
            if s in zinc_set:
                out.append(0.5)
            elif s in history_data[0]:
                out.append(0.5)
            elif s in history_data[1]:
                out.append(0.70)
            elif s in history_data[2]:
                out.append(0.85)
            else:
                out.append(1.0)
        return np.array(out)

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    def discriminator_reward_mult(smiles_list):
        orig_state = discrim_model.training
        discrim_model.eval()
        discrim_out_logits = discrim_model(smiles_list)['p_zinc']
        discrim_probs = F.softmax(discrim_out_logits, dim=1)
        prob_zinc = discrim_probs[:, 1].detach().cpu().numpy()
        if orig_state:
            discrim_model.train()
        return prob_zinc

    def apply_originality_penalty(x, orig_mult):
        assert x <= 1, "Reward must be no greater than 1"
        if x > 0.5:  # want to punish nearly-perfect scores less and less
            out = math.pow(x, 1 / orig_mult)
        else:  # continuous join at 0.5
            penalty = math.pow(0.5, 1 / orig_mult) - 0.5
            out = x + penalty
        out -= extra_repetition_penalty * (1 - 1 / orig_mult)
        return out

    def adj_reward(x):
        if discrim_wt > 1e-5:
            p = discriminator_reward_mult(x)
        else:
            p = 0
        rwd = np.array(reward_fun_on(x))
        orig_mult = originality_mult(x)
        # we assume the reward is <= 1: the first term would dominate for
        # reward < 0, the second for 0 < reward < 1
        # reward = np.minimum(rwd / orig_mult, np.power(np.abs(rwd), 1 / orig_mult))
        reward = np.array([apply_originality_penalty(r, om)
                           for r, om in zip(rwd, orig_mult)])
        out = reward + discrim_wt * p * orig_mult
        return out

    def adj_reward_old(x):
        p = discriminator_reward_mult(x)
        w = sigmoid(-(p - p_thresh) / 0.01)
        if randomize_reward:
            rand = np.random.uniform(size=p.shape)
            w *= rand
        reward = np.maximum(reward_fun_on(x), p_thresh)
        weighted_reward = w * p + (1 - w) * reward
        out = weighted_reward * originality_mult(x)
        return out

    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE

    task = SequenceGenerationTask(molecules=molecules,
                                  grammar=grammar,
                                  reward_fun=adj_reward,
                                  batch_size=BATCH_SIZE,
                                  max_steps=max_steps,
                                  save_dataset=None)

    node_policy = SoftmaxRandomSamplePolicy(temperature=torch.tensor(1.0), eps=eps)
    rule_policy = SoftmaxRandomSamplePolicy(temperature=torch.tensor(1.0), eps=eps)

    model = get_decoder(molecules,
                        grammar,
                        z_size=settings['z_size'],
                        decoder_hidden_n=200,
                        feature_len=codec.feature_len(),
                        max_seq_length=max_steps,
                        drop_rate=drop_rate,
                        batch_size=BATCH_SIZE,
                        decoder_type=decoder_type,
                        reward_fun=adj_reward,
                        task=task,
                        node_policy=node_policy,
                        rule_policy=rule_policy,
                        priors=priors)[0]

    if preload_file_root_name is not None:
        try:
            preload_path = full_path(gen_preload_file)
            model.load_state_dict(torch.load(preload_path, map_location='cpu'),
                                  strict=False)
            print('Generator weights loaded successfully!')
        except Exception as e:
            print('failed to load generator weights ' + str(e))

    anchor_model = None

    from generative_playground.molecules.rdkit_utils.rdkit_utils import NormalizedScorer
    import numpy as np
    scorer = NormalizedScorer()

    if reward_fun_off is None:
        reward_fun_off = reward_fun_on

    # construct the callbacks that feed generated molecules to the discriminator
    def make_callback(data):
        def hc(inputs, model, outputs, loss_fn, loss):
            graphs = outputs['graphs']
            smiles = [g.to_smiles() for g in graphs]
            # only store unique instances of molecules so discriminator
            # can't guess on frequency
            for s in smiles:
                if s not in data:
                    data.append(s)
        return hc

    class TemperatureCallback:
        def __init__(self, policy, temperature_function):
            self.policy = policy
            self.counter = 0
            self.temp_fun = temperature_function

        def __call__(self, inputs, model, outputs, loss_fn, loss):
            self.counter += 1
            target_temp = self.temp_fun(self.counter)
            self.policy.set_temperature(target_temp)

    def get_rl_fitter(model,
                      loss_obj,
                      train_gen,
                      save_path,
                      fit_plot_prefix='',
                      model_process_fun=None,
                      lr=None,
                      extra_callbacks=[],
                      loss_display_cap=float('inf'),
                      anchor_model=None,
                      anchor_weight=0):
        nice_params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = optim.Adam(nice_params, lr=lr, eps=1e-4)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.99)
        if dashboard is not None:
            metric_monitor = MetricPlotter(plot_prefix=fit_plot_prefix,
                                           loss_display_cap=loss_display_cap,
                                           dashboard_name=dashboard,
                                           plot_ignore_initial=plot_ignore_initial,
                                           process_model_fun=model_process_fun,
                                           smooth_weight=reward_sm)
        else:
            metric_monitor = None
        checkpointer = Checkpointer(valid_batches_to_checkpoint=1,
                                    save_path=save_path,
                                    save_always=True)
        fitter = fit_rl(train_gen=train_gen,
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        epochs=EPOCHS,
                        loss_fn=loss_obj,
                        grad_clip=5,
                        half_float=half_float,
                        anchor_model=anchor_model,
                        anchor_weight=anchor_weight,
                        callbacks=[metric_monitor, checkpointer] + extra_callbacks)
        return fitter

    def my_gen():
        for _ in range(1000):
            yield to_gpu(torch.zeros(BATCH_SIZE, settings['z_size']))

    # the on-policy fitter
    gen_extra_callbacks = [make_callback(d) for d in history_data]

    if smiles_save_file is not None:
        smiles_save_path = os.path.realpath(root_location + 'pretrained/' +
                                            smiles_save_file)
        gen_extra_callbacks.append(MoleculeSaver(smiles_save_path, gzip=True))
        print('Saving SMILES to {}'.format(smiles_save_file))

    if node_temperature_schedule is not None:
        gen_extra_callbacks.append(
            TemperatureCallback(node_policy, node_temperature_schedule))
    if rule_temperature_schedule is not None:
        gen_extra_callbacks.append(
            TemperatureCallback(rule_policy, rule_temperature_schedule))

    fitter1 = get_rl_fitter(
        model,
        PolicyGradientLoss(on_policy_loss_type),  # last_reward_wgt=reward_sm),
        GeneratorToIterable(my_gen),
        full_path(gen_save_file),
        plot_prefix + 'on-policy',
        model_process_fun=model_process_fun,
        lr=lr_on,
        extra_callbacks=gen_extra_callbacks,
        anchor_model=anchor_model,
        anchor_weight=anchor_weight)

    # get existing molecule data to add to training
    pre_dataset = EvenlyBlendedDataset(2 * [history_data[0]] + history_data[1:],
                                       labels=False)  # a blend of 3 time horizons
    dataset = EvenlyBlendedDataset([pre_dataset, zinc_data], labels=True)
    discrim_loader = DataLoader(dataset, shuffle=True, batch_size=50)

    class MyLoss(nn.Module):
        def __init__(self):
            super().__init__()
            self.celoss = nn.CrossEntropyLoss()

        def forward(self, x):
            # tmp = discriminator_reward_mult(x['smiles'])
            # tmp2 = F.softmax(x['p_zinc'], dim=1)[:, 1].detach().cpu().numpy()
            # assert np.max(np.abs(tmp - tmp2)) < 1e-6
            return self.celoss(x['p_zinc'].to(device),
                               x['dataset_index'].to(device))

    fitter2 = get_rl_fitter(discrim_model,
                            MyLoss(),
                            IterableTransform(
                                discrim_loader,
                                lambda x: {'smiles': x['X'],
                                           'dataset_index': x['dataset_index']}),
                            full_path(disc_save_file),
                            plot_prefix + ' discriminator',
                            lr=lr_discrim,
                            model_process_fun=None)

    def on_policy_gen(fitter, model):
        while True:
            # model.policy = SoftmaxRandomSamplePolicy()  # bias=codec.grammar.get_log_frequencies())
            yield next(fitter)

    return model, fitter1, fitter2  # , on_policy_gen(fitter1, model)
def test_custom_grammar_mask_gen(self):
    molecules = True
    grammar = 'new'
    codec = get_codec(molecules, grammar, max_seq_length)
    self.generate_and_validate(codec)
def check_codec(self, input, molecules, grammar):
    settings = get_settings(molecules, grammar)
    codec = get_codec(molecules, grammar, settings['max_seq_length'])
    actions = codec.strings_to_actions([input])
    re_input = codec.actions_to_strings(actions)
    self.assertEqual(input, re_input[0])
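# Stand-alone version of the round-trip check_codec performs, runnable outside
# the test class; benzene as the probe SMILES is an illustrative assumption.
def _example_codec_roundtrip():
    settings = get_settings(True, True)
    codec = get_codec(True, True, settings['max_seq_length'])
    actions = codec.strings_to_actions(['c1ccccc1'])
    assert codec.actions_to_strings(actions)[0] == 'c1ccccc1'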
def get_decoder(
        molecules=True,
        grammar=True,
        z_size=200,
        decoder_hidden_n=200,
        feature_len=12,  # TODO: remove this
        max_seq_length=15,
        drop_rate=0.0,
        decoder_type='step',
        task=None,
        node_policy=None,
        rule_policy=None,
        reward_fun=lambda x: -1 * np.ones(len(x)),
        batch_size=None,
        priors=True):
    codec = get_codec(molecules, grammar, max_seq_length)
    if decoder_type == 'old':
        stepper = ResettingRNNDecoder(z_size=z_size,
                                      hidden_n=decoder_hidden_n,
                                      feature_len=codec.feature_len(),
                                      max_seq_length=max_seq_length,
                                      steps=max_seq_length,
                                      drop_rate=drop_rate)
        stepper = OneStepDecoderContinuous(stepper)
    elif 'graph' in decoder_type and decoder_type not in ['attn_graph', 'rnn_graph']:
        return get_node_decoder(grammar, max_seq_length, drop_rate,
                                decoder_type, rule_policy, reward_fun,
                                batch_size, priors)
    elif decoder_type in ['attn_graph', 'rnn_graph']:  # deprecated
        assert 'hypergraph' in grammar, \
            "Only the hypergraph grammar can be used with attn_graph decoder type"
        if 'attn' in decoder_type:
            encoder = GraphEncoder(grammar=codec.grammar,
                                   d_model=512,
                                   drop_rate=drop_rate,
                                   model_type='transformer')
        elif 'rnn' in decoder_type:
            encoder = GraphEncoder(grammar=codec.grammar,
                                   d_model=512,
                                   drop_rate=drop_rate,
                                   model_type='rnn')
        model = MultipleOutputHead(
            model=encoder,
            output_spec={'node': 1,  # to be used to select next node to expand
                         'action': codec.feature_len()},  # to select the action for chosen node
            drop_rate=drop_rate)
        # don't support using this model in VAE-style models yet
        model.init_encoder_output = lambda x: None
        mask_gen = HypergraphMaskGenerator(max_len=max_seq_length,
                                           grammar=codec.grammar)
        mask_gen.priors = priors
        if node_policy is None:
            node_policy = SoftmaxRandomSamplePolicy()
        if rule_policy is None:
            rule_policy = SoftmaxRandomSamplePolicy()
        if 'node' in decoder_type:
            stepper = GraphDecoderWithNodeSelection(model,
                                                    node_policy=node_policy,
                                                    rule_policy=rule_policy)
            env = GraphEnvironment(mask_gen,
                                   reward_fun=reward_fun,
                                   batch_size=batch_size)
            decoder = DecoderWithEnvironmentNew(stepper, env)
        else:
            stepper = GraphDecoder(model=model, mask_gen=mask_gen)
            decoder = to_gpu(
                SimpleDiscreteDecoderWithEnv(stepper, rule_policy, task=task,
                                             batch_size=batch_size))
        return decoder, stepper
    else:
        if decoder_type == 'step':
            stepper = SimpleRNNDecoder(z_size=z_size,
                                       hidden_n=decoder_hidden_n,
                                       feature_len=codec.feature_len(),
                                       max_seq_length=max_seq_length,
                                       drop_rate=drop_rate,
                                       use_last_action=False)
        elif decoder_type == 'action':
            stepper = SimpleRNNDecoder(z_size=z_size,  # + feature_len,
                                       hidden_n=decoder_hidden_n,
                                       feature_len=codec.feature_len(),
                                       max_seq_length=max_seq_length,
                                       drop_rate=drop_rate,
                                       use_last_action=True)
        elif decoder_type == 'action_resnet':
            stepper = ResNetRNNDecoder(z_size=z_size,  # + feature_len,
                                       hidden_n=decoder_hidden_n,
                                       feature_len=codec.feature_len(),
                                       max_seq_length=max_seq_length,
                                       drop_rate=drop_rate,
                                       use_last_action=True)
        elif decoder_type == 'attention':
            stepper = SelfAttentionDecoderStep(num_actions=codec.feature_len(),
                                               max_seq_len=max_seq_length,
                                               drop_rate=drop_rate,
                                               enc_output_size=z_size)
        elif decoder_type == 'random':
            stepper = RandomDecoder(feature_len=codec.feature_len(),
                                    max_seq_length=max_seq_length)
        else:
            raise NotImplementedError('Unknown decoder type: ' + str(decoder_type))
        if grammar is not False and '_graph' not in decoder_type:
            # add a masking layer
            mask_gen = get_codec(molecules, grammar, max_seq_length).mask_gen
            stepper = MaskingHead(stepper, mask_gen)
        policy = SoftmaxRandomSamplePolicy()  # bias=codec.grammar.get_log_frequencies())
        decoder = to_gpu(
            SimpleDiscreteDecoderWithEnv(stepper, policy, task=task,
                                         batch_size=batch_size))  # , bypass_actions=True))
        return decoder, stepper
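# Hedged example: the 'graph_conditional' decoder_type (used by the runner
# __init__ above) is routed by this get_decoder to get_node_decoder. The
# grammar cache name and zero reward are illustrative assumptions.
def _example_graph_decoder():
    import numpy as np
    decoder, stepper = get_decoder(molecules=True,
                                   grammar='hypergraph:tmp.pickle',
                                   max_seq_length=30,
                                   decoder_type='graph_conditional',
                                   reward_fun=lambda x: np.zeros(len(x)),
                                   batch_size=2,
                                   priors='conditional')
    return decoder()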
def train_policy_gradient(molecules=True,
                          grammar=True,
                          EPOCHS=None,
                          BATCH_SIZE=None,
                          reward_fun_on=None,
                          reward_fun_off=None,
                          max_steps=277,
                          lr_on=2e-4,
                          lr_discrim=1e-4,
                          p_thresh=0.5,
                          drop_rate=0.0,
                          plot_ignore_initial=0,
                          randomize_reward=False,
                          save_file=None,
                          reward_sm=0.0,
                          preload_file=None,
                          anchor_file=None,
                          anchor_weight=0.0,
                          decoder_type='action',
                          plot_prefix='',
                          dashboard='policy gradient',
                          smiles_save_file=None,
                          on_policy_loss_type='best',
                          off_policy_loss_type='mean',
                          sanity_checks=True):
    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../../'
    gen_save_path = root_location + 'pretrained/gen_' + save_file
    disc_save_path = root_location + 'pretrained/disc_' + save_file

    if smiles_save_file is not None:
        smiles_save_path = root_location + 'pretrained/' + smiles_save_file
        save_dataset = IncrementingHDF5Dataset(smiles_save_path)
    else:
        save_dataset = None

    settings = get_settings(molecules=molecules, grammar=grammar)
    codec = get_codec(molecules, grammar, settings['max_seq_length'])
    discrim_model = GraphDiscriminator(codec.grammar, drop_rate=drop_rate)

    zinc_data = get_zinc_smiles()
    zinc_set = set(zinc_data)
    lookbacks = [BATCH_SIZE, 10 * BATCH_SIZE, 100 * BATCH_SIZE]
    # need to have something there to begin with, else the DataLoader constructor barfs
    history_data = [deque(['O'], maxlen=lb) for lb in lookbacks]

    def originality_mult(smiles_list):
        out = []
        for s in smiles_list:
            if s in zinc_set:
                out.append(0.5)
            elif s in history_data[0]:
                out.append(0.5)
            elif s in history_data[1]:
                out.append(0.70)
            elif s in history_data[2]:
                out.append(0.85)
            else:
                out.append(1.0)
        return np.array(out)

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    def discriminator_reward_mult(smiles_list):
        orig_state = discrim_model.training
        discrim_model.eval()
        discrim_out_logits = discrim_model(smiles_list)['p_zinc']
        discrim_probs = F.softmax(discrim_out_logits, dim=1)
        prob_zinc = discrim_probs[:, 1].detach().cpu().numpy()
        if orig_state:
            discrim_model.train()
        return prob_zinc

    def adj_reward(x):
        p = discriminator_reward_mult(x)
        reward = np.maximum(reward_fun_on(x), 0)
        out = reward * originality_mult(x) + 2 * p
        return out

    def adj_reward_old(x):
        p = discriminator_reward_mult(x)
        w = sigmoid(-(p - p_thresh) / 0.01)
        if randomize_reward:
            rand = np.random.uniform(size=p.shape)
            w *= rand
        reward = np.maximum(reward_fun_on(x), p_thresh)
        weighted_reward = w * p + (1 - w) * reward
        out = weighted_reward * originality_mult(x)
        return out

    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE

    task = SequenceGenerationTask(molecules=molecules,
                                  grammar=grammar,
                                  reward_fun=adj_reward,
                                  batch_size=BATCH_SIZE,
                                  max_steps=max_steps,
                                  save_dataset=save_dataset)

    model = get_decoder(molecules,
                        grammar,
                        z_size=settings['z_size'],
                        decoder_hidden_n=200,
                        feature_len=codec.feature_len(),
                        max_seq_length=max_steps,
                        drop_rate=drop_rate,
                        decoder_type=decoder_type,
                        task=task)[0]
    # TODO: really ugly, refactor! In fact this model doesn't need a MaskingHead at all!
    model.stepper.model.mask_gen.priors = True  # 'conditional' # use empirical priors for the mask gen

    # if preload_file is not None:
    #     try:
    #         preload_path = root_location + 'pretrained/' + preload_file
    #         model.load_state_dict(torch.load(preload_path))
    #     except:
    #         pass

    anchor_model = None

    from generative_playground.molecules.rdkit_utils.rdkit_utils import NormalizedScorer
    import rdkit.Chem.rdMolDescriptors as desc
    import numpy as np
    scorer = NormalizedScorer()

    def model_process_fun(model_out, visdom, n):
        # TODO: rephrase this to return a dict, instead of calling visdom directly
        from rdkit import Chem
        from rdkit.Chem.Draw import MolToFile
        # actions, logits, rewards, terminals, info = model_out
        smiles, valid = model_out['info']
        total_rewards = model_out['rewards'].sum(1)
        best_ind = torch.argmax(total_rewards).data.item()
        this_smile = smiles[best_ind]
        mol = Chem.MolFromSmiles(this_smile)
        pic_save_path = root_location + 'images/' + 'tmp.svg'
        if mol is not None:
            try:
                MolToFile(mol, pic_save_path, imageType='svg')
                with open(pic_save_path, 'r') as myfile:
                    data = myfile.read()
                data = data.replace('svg:', '')
                visdom.append('best molecule of batch', 'svg', svgstr=data)
            except Exception as e:
                print(e)
            scores, norm_scores = scorer.get_scores([this_smile])
            visdom.append(
                'score component',
                'line',
                X=np.array([n]),
                Y=np.array([[x for x in norm_scores[0]] +
                            [norm_scores[0].sum()] + [scores[0].sum()] +
                            [desc.CalcNumAromaticRings(mol)] +
                            [total_rewards[best_ind].item()]]),
                opts={'legend': ['logP', 'SA', 'cycle', 'norm_reward',
                                 'reward', 'Aromatic rings', 'eff_reward']})
            visdom.append('fraction valid',
                          'line',
                          X=np.array([n]),
                          Y=np.array([valid.mean().data.item()]))

    if reward_fun_off is None:
        reward_fun_off = reward_fun_on

    # construct the callbacks that feed generated molecules to the discriminator
    def make_callback(data):
        def hc(inputs, model, outputs, loss_fn, loss):
            graphs = outputs['graphs']
            smiles = [g.to_smiles() for g in graphs]
            # only store unique instances of molecules so discriminator
            # can't guess on frequency
            for s in smiles:
                if s not in data:
                    data.append(s)
        return hc

    def get_rl_fitter(model,
                      loss_obj,
                      train_gen,
                      save_path,
                      fit_plot_prefix='',
                      model_process_fun=None,
                      lr=None,
                      extra_callbacks=[],
                      loss_display_cap=float('inf'),
                      anchor_model=None,
                      anchor_weight=0):
        nice_params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = optim.Adam(nice_params, lr=lr)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.99)
        if dashboard is not None:
            metric_monitor = MetricPlotter(plot_prefix=fit_plot_prefix,
                                           loss_display_cap=loss_display_cap,
                                           dashboard_name=dashboard,
                                           plot_ignore_initial=plot_ignore_initial,
                                           process_model_fun=model_process_fun,
                                           smooth_weight=0.9)
        else:
            metric_monitor = None
        checkpointer = Checkpointer(valid_batches_to_checkpoint=1,
                                    save_path=save_path,
                                    save_always=True)
        fitter = fit_rl(train_gen=train_gen,
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        epochs=EPOCHS,
                        loss_fn=loss_obj,
                        grad_clip=5,
                        anchor_model=anchor_model,
                        anchor_weight=anchor_weight,
                        callbacks=[metric_monitor, checkpointer] + extra_callbacks)
        return fitter

    class GeneratorToIterable:
        def __init__(self, gen):
            self.gen = gen
            # we assume the generator is finite
            self.len = 0
            for _ in gen():
                self.len += 1

        def __len__(self):
            return self.len

        def __iter__(self):
            return self.gen()

    def my_gen():
        for _ in range(1000):
            yield to_gpu(torch.zeros(BATCH_SIZE, settings['z_size']))

    # the on-policy fitter
    history_callbacks = [make_callback(d) for d in history_data]
    fitter1 = get_rl_fitter(model,
                            PolicyGradientLoss(on_policy_loss_type,
                                               last_reward_wgt=reward_sm),
                            GeneratorToIterable(my_gen),
                            gen_save_path,
                            plot_prefix + 'on-policy',
                            model_process_fun=model_process_fun,
                            lr=lr_on,
                            extra_callbacks=history_callbacks,
                            anchor_model=anchor_model,
                            anchor_weight=anchor_weight)

    # get existing molecule data to add to training
    pre_dataset = EvenlyBlendedDataset(2 * [history_data[0]] + history_data[1:],
                                       labels=False)  # a blend of 3 time horizons
    dataset = EvenlyBlendedDataset([pre_dataset, zinc_data], labels=True)
    discrim_loader = DataLoader(dataset, shuffle=True, batch_size=50)
    celoss = nn.CrossEntropyLoss()

    def my_loss(x):
        # tmp = discriminator_reward_mult(x['smiles'])
        # tmp2 = F.softmax(x['p_zinc'], dim=1)[:, 1].detach().cpu().numpy()
        # assert np.max(np.abs(tmp - tmp2)) < 1e-6
        return celoss(x['p_zinc'].to(device), x['dataset_index'].to(device))

    fitter2 = get_rl_fitter(discrim_model,
                            my_loss,
                            IterableTransform(
                                discrim_loader,
                                lambda x: {'smiles': x['X'],
                                           'dataset_index': x['dataset_index']}),
                            disc_save_path,
                            plot_prefix + ' discriminator',
                            lr=lr_discrim,
                            model_process_fun=None)

    def on_policy_gen(fitter, model):
        while True:
            model.policy = SoftmaxRandomSamplePolicy()  # bias=codec.grammar.get_log_frequencies())
            yield next(fitter)

    return model, on_policy_gen(fitter1, model), fitter2