Beispiel #1
0
 def init_visdom(self):
     """Connect this object to a Visdom dashboard, if one is available.

     Reads ``self.dashboard_name`` and ``self.frequent_calls`` and sets
     ``self.vis`` (a ``Dashboard`` or ``None``) and ``self.have_visdom``.
     Visdom support is best-effort: if the visdom helper cannot be imported
     or the ``Dashboard`` cannot be constructed, plotting is simply disabled
     instead of raising.
     """
     self.vis = None
     self.have_visdom = False
     if getattr(self, 'dashboard_name', None) is None:
         # No dashboard requested, so there is nothing to connect to.
         return
     try:
         from generative_playground.utils.visdom_helper import Dashboard
         # call_every is 10 when frequent_calls is set, else 1 (presumably
         # throttles dashboard updates -- confirm in visdom_helper).
         self.vis = Dashboard(
             self.dashboard_name,
             call_every=10 if self.frequent_calls else 1)
         self.have_visdom = True
     except Exception:
         # Deliberately best-effort, but narrower than a bare ``except:`` so
         # KeyboardInterrupt / SystemExit still propagate.
         self.vis = None
         self.have_visdom = False
Beispiel #2
0
def a2c_sequence(name='a2c_sequence', task=None, body=None):
    """Configure and run an A2C agent on a sequence-generation task.

    :param name: not used by this function (kept for interface compatibility).
    :param task: object returned by ``config.task_fn``.
    :param body: network body passed to ``CategoricalActorCriticNet``.
    :return: None; runs ``run_iterations`` with a ``Dashboard('DeepRL')``.

    NOTE(review): ``batch_size``, ``mask_gen`` and ``invalid_value`` are
    neither parameters nor locals here; they must be module-level globals
    defined elsewhere, otherwise this raises NameError -- confirm.
    """
    config = Config()
    config.num_workers = batch_size  # same thing as batch size
    config.task_fn = lambda: task
    config.optimizer_fn = lambda params: torch.optim.RMSprop(params, lr=0.0007)
    config.network_fn = lambda state_dim, action_dim: to_gpu(
        CategoricalActorCriticNet(state_dim,
                                  action_dim,
                                  body,
                                  gpu=0,
                                  mask_gen=mask_gen))
    # config.policy_fn = SamplePolicy  # not used
    # Identity normalizers: states and rewards are passed through unchanged.
    config.state_normalizer = lambda x: x
    config.reward_normalizer = lambda x: x
    config.discount = 0.99
    # TODO: for now, MUST be False -- the RNN network is presumably not
    # compatible with GAE yet (original note was truncated; confirm).
    config.use_gae = False
    config.gae_tau = 0.97
    config.entropy_weight = 0.01
    config.rollout_length = 5
    config.gradient_clip = 0.5
    config.logger = logging.getLogger()  # get_logger(file_name='deep_rl_a2c', skip=True)
    config.logger.info('test')
    # iteration_log_interval is intentionally not overridden here: whatever
    # Config provides is used.
    config.max_steps = 100000
    dash_name = 'DeepRL'
    visdom = Dashboard(dash_name)
    run_iterations(MyA2CAgent(config), visdom, invalid_value=invalid_value)
Beispiel #3
0
    def __init__(self,
                 plot_prefix='',
                 save_file=None,
                 loss_display_cap=4,
                 dashboard_name=None,
                 plot_ignore_initial=0,
                 process_model_fun=None,
                 extra_metric_fun=None,
                 smooth_weight=0.0,
                 frequent_calls=False):
        """Set up metric-plotting state and an optional Visdom dashboard.

        :param plot_prefix: prefix for plot titles; with spaces replaced by
            underscores it also names the default save file.
        :param save_file: file name for the pickled stats, defaulting to
            ``<plot_prefix>.pkz``. It is always placed under
            ``<dir of this file>/../molecules/train/vae/data/``.
        :param loss_display_cap: stored as ``self.loss_display_cap``.
        :param dashboard_name: Visdom dashboard name; None disables Visdom.
        :param plot_ignore_initial: stored as ``self.plot_ignore_initial``.
        :param process_model_fun: optional callback, stored as-is.
        :param extra_metric_fun: optional callback, stored as-is.
        :param smooth_weight: stored as ``self.smooth_weight``.
        :param frequent_calls: if True, the Dashboard gets ``call_every=10``.
        """
        self.plot_prefix = plot_prefix
        self.process_model_fun = process_model_fun
        self.extra_metric_fun = extra_metric_fun
        if save_file is None:
            save_file = plot_prefix.replace(' ', '_') + '.pkz'

        my_location = os.path.dirname(
            os.path.abspath(inspect.getfile(inspect.currentframe())))
        root_location = my_location + '/../'
        # TODO: pass that through in the constructor!
        self.save_file = root_location + 'molecules/train/vae/data/' + save_file
        self.plot_ignore_initial = plot_ignore_initial
        self.loss_display_cap = loss_display_cap
        self.plot_counter = 0
        self.stats = pd.DataFrame(
            columns=['batch', 'timestamp', 'gpu_usage', 'train', 'loss'])
        self.smooth_weight = smooth_weight
        self.smooth = {}

        self.vis = None
        self.have_visdom = False
        if dashboard_name is not None:
            try:
                from generative_playground.utils.visdom_helper import Dashboard
                self.vis = Dashboard(dashboard_name,
                                     call_every=10 if frequent_calls else 1)
                self.have_visdom = True
            except Exception:
                # Visdom is optional: a missing helper or a failed dashboard
                # construction disables plotting instead of aborting.
                self.vis = None
                self.have_visdom = False
Beispiel #4
0
class MetricPlotter:
    """Callable training hook that records per-batch metrics and plots them.

    Each call appends a row (loss-function metrics plus train flag, GPU usage,
    loss, batch index, timestamp and best reward) to ``self.stats``, a pandas
    DataFrame. If a save location was given, the DataFrame is gzip-pickled to
    ``self.save_file``. When a Visdom dashboard is connected, the call also
    plots memory usage, loss, loss metrics, batch timing and reward stats.
    """

    def __init__(self,
                 plot_prefix='',
                 save_file=None,
                 loss_display_cap=4,
                 dashboard_name='main',
                 save_location=None,
                 plot_ignore_initial=0,
                 process_model_fun=None,
                 extra_metric_fun=None,
                 smooth_weight=0.0,
                 frequent_calls=False):
        """
        :param plot_prefix: prefix for the loss plot titles.
        :param save_file: file name appended to ``save_location``; defaults to
            ``'/<dashboard_name>_metrics.zip'``.
        :param loss_display_cap: plotted loss is ``min(loss_display_cap, loss)``.
        :param dashboard_name: Visdom dashboard name ('#' is replaced by ':');
            None disables Visdom.
        :param save_location: directory prefix for the stats file; None
            disables saving.
        :param plot_ignore_initial: plotting starts only once ``plot_counter``
            exceeds this value.
        :param process_model_fun: optional ``f(model_out, vis, plot_counter)``
            called after plotting.
        :param extra_metric_fun: stored; its call site in ``__call__`` is
            currently commented out.
        :param smooth_weight: weight passed to ``smooth_data`` for smoothed plots.
        :param frequent_calls: if True, the Dashboard gets ``call_every=10``.
        """
        self.plot_prefix = plot_prefix
        self.process_model_fun = process_model_fun
        self.extra_metric_fun = extra_metric_fun

        if save_location is not None:
            if save_file is None:
                save_file = '/' + dashboard_name + '_metrics.zip'
            self.save_file = save_location + save_file
        else:
            self.save_file = None
        self.plot_ignore_initial = plot_ignore_initial
        self.loss_display_cap = loss_display_cap
        self.plot_counter = 0
        self.reward_calc = MetricStats()
        self.stats = pd.DataFrame(
            columns=['batch', 'timestamp', 'gpu_usage', 'train', 'loss'])
        self.smooth_weight = smooth_weight
        self.smooth = {}
        self.last_timestamp = datetime.datetime.now()
        # for some reason, Visdom doesn't like hashes in dashboard names
        self.dashboard_name = (None if dashboard_name is None
                               else dashboard_name.replace('#', ':'))
        self.frequent_calls = frequent_calls
        self.vis = None
        self.have_visdom = False
        self.init_visdom()

    def init_visdom(self):
        """Connect to the Visdom dashboard named ``self.dashboard_name``.

        Sets ``self.vis`` and ``self.have_visdom``. This is best-effort: if
        the helper cannot be imported or the Dashboard cannot be built,
        plotting is disabled instead of raising.
        """
        self.vis = None
        self.have_visdom = False
        if getattr(self, 'dashboard_name', None) is None:
            return
        try:
            from generative_playground.utils.visdom_helper import Dashboard
            self.vis = Dashboard(
                self.dashboard_name,
                call_every=10 if self.frequent_calls else 1)
            self.have_visdom = True
        except Exception:
            # Narrower than a bare ``except:`` so Ctrl-C still propagates.
            self.vis = None
            self.have_visdom = False

    def __getstate__(self):
        """Pickle everything except the live dashboard connection.

        The dashboard's ``plots`` are stored under ``'plots'`` so they can be
        restored. The value is None when no dashboard is connected.
        """
        state = {key: value for key, value in self.__dict__.items() if key != 'vis'}
        state['plots'] = None if self.vis is None else self.vis.plots
        return state

    def __setstate__(self, state):
        """Restore from pickle, reconnect to Visdom and re-attach saved plots."""
        state = dict(state)
        plots = state.pop('plots', None)
        self.__dict__.update(state)
        self.init_visdom()
        if self.vis is not None and plots is not None:
            self.vis.plots = plots
        # Advance the counter on restore (presumably so the next point does
        # not reuse the last plotted X value -- confirm).
        self.plot_counter += 1

    @staticmethod
    def _extract_rewards(outputs):
        """Return ``outputs['rewards']`` as a numpy array, or None if unavailable.

        A 2-D reward array is first summed along axis 1.
        """
        try:
            rewards = outputs['rewards']
            if len(rewards.shape) == 2:
                rewards = rewards.sum(1)
            return to_numpy(rewards)
        except Exception:
            return None

    def __call__(
            self,
            _,  # inputs
            model,
            outputs,
            loss_fn,
            loss):
        '''
        Plot the results of the latest batch and append them to ``self.stats``.

        :param _: batch inputs (unused).
        :param model: the model. Its ``training`` attribute (default True)
            selects the train or val loss label.
        :param outputs: model outputs. ``outputs['rewards']`` and
            ``outputs['info'][0]`` are used for reward statistics when present.
        :param loss_fn: loss function. Its optional ``metrics`` dict is plotted
            and logged.
        :param loss: latest loss (a torch tensor, a number, or None).
        :return: None
        '''
        print('calling metric monitor...')
        train = getattr(model, 'training', True)

        if hasattr(loss, 'device'):  # torch tensor -> Python number
            loss = loss.data.item()

        metrics_from_loss = getattr(loss_fn, 'metrics', None)
        model_out = outputs

        if train:
            loss_name = self.plot_prefix + ' train_loss'
        else:
            loss_name = self.plot_prefix + ' val_loss'

        # show intermediate results
        gpu_usage = get_gpu_memory_map()
        cpu_usage = get_free_ram()
        print(loss_name, loss, self.plot_counter, gpu_usage)
        self.plot_counter += 1

        # Needed for the stats row below whether or not anything gets plotted.
        rewards = self._extract_rewards(outputs)

        if self.vis is not None and self.plot_counter > self.plot_ignore_initial and self.have_visdom:
            all_metrics = {}
            all_metrics['memory_usage'] = self.entry_from_dict({
                'gpu': gpu_usage[0],
                'cpu': cpu_usage['used'],
            })
            if loss is not None:
                all_metrics[loss_name] = {
                    'type': 'line',
                    'X': np.array([self.plot_counter]),
                    'Y': np.array([min(self.loss_display_cap, loss)]),
                    'smooth': self.smooth_weight,
                }

            if metrics_from_loss:
                for key, value in metrics_from_loss.items():
                    if isinstance(value, dict):  # dict of dicts, so multiple plots
                        all_metrics[key] = self.entry_from_dict(value)
                    else:  # just one dict with data, old-style
                        all_metrics[loss_name + ' metrics'] = self.entry_from_dict(
                            metrics_from_loss)
                        break

            now = datetime.datetime.now()
            if self.last_timestamp is not None:
                batch_duration = (now - self.last_timestamp).total_seconds()
                all_metrics['seconds_per_batch'] = {
                    'type': 'line',
                    'X': np.array([self.plot_counter]),
                    'Y': np.array([batch_duration]),
                }
            self.last_timestamp = now

            if rewards is not None:
                try:
                    smiles = outputs['info'][0]
                except Exception:
                    smiles = None
                try:
                    reward_dict = self.reward_calc(rewards, smiles)
                    all_metrics['reward_stats'] = self.entry_from_dict(reward_dict)
                except Exception:
                    # Reward statistics are best-effort; skip that plot on failure.
                    pass
            # if self.extra_metric_fun is not None:
            #     all_metrics.update(self.extra_metric_fun(inputs, targets, model_out, train, self.plot_counter))

            # Exponential-style smoothing for entries that request it; others
            # are stored as-is.
            for title, metric in all_metrics.items():
                if 'smooth' not in metric:
                    self.smooth[title] = metric
                else:
                    previous = self.smooth.get(title, metric)
                    self.smooth[title] = smooth_data(previous, metric, metric['smooth'])

            self.vis.plot_metric_dict({
                title: value
                for title, value in self.smooth.items()
                if title in all_metrics
            })

            # TODO: factor this out
            if self.process_model_fun is not None:
                self.process_model_fun(model_out, self.vis, self.plot_counter)

        row = {} if metrics_from_loss is None else copy.copy(metrics_from_loss)
        row['train'] = train
        row['gpu_usage'] = gpu_usage[0]
        row['loss'] = loss
        row['batch'] = self.plot_counter
        row['timestamp'] = datetime.datetime.now()
        row['best_reward'] = (np.max(rewards)
                              if rewards is not None and np.size(rewards) > 0
                              else 0)

        # DataFrame.append was removed in pandas 2.0; concat is the replacement.
        self.stats = pd.concat([self.stats, pd.DataFrame([row])], ignore_index=True)

        if self.save_file is not None:
            with gzip.open(self.save_file, 'wb') as f:
                pickle.dump(self.stats, f)

    def entry_from_dict(self, metrics):
        """Build a Visdom line-plot entry with one series per key of ``metrics``."""
        return {
            'type': 'line',
            'X': np.array([self.plot_counter]),
            'Y': np.array([list(metrics.values())]),
            'opts': {
                'legend': list(metrics.keys())
            },
            'smooth': self.smooth_weight,
        }
def run_genetic_opt(
        top_N=10,
        p_mutate=0.2,
        mutate_num_best=64,
        mutate_use_total_probs=False,
        p_crossover=0.2,
        num_batches=100,
        batch_size=30,
        snapshot_dir=None,
        entropy_wgt=0.0,
        root_name=None,
        obj_num=None,
        ver='v2',
        lr=0.01,
        num_runs=100,
        num_explore=5,
        plot_single_runs=True,
        steps_with_no_improvement=10,
        reward_aggregation=np.median,
        attempt='',  # only used for disambiguating plotting
        max_steps=90,
        past_runs_graph_file=None):
    """Run a genetic search over PolicyGradientRunner models in a process pool.

    Runs with index below ``num_explore`` start from fresh models. With
    ``past_runs_graph_file``, their parameters are drawn from a
    ``ParameterSampler`` fitted on past runs. Later runs start from a model
    picked among the ``top_N`` best in the snapshot cache and get a new root
    name. They may then be crossed over (probability ``p_crossover``) and
    mutated (probability ``p_mutate``). Lineage is kept in a networkx
    DiGraph and pickled to ``<snapshot_dir>/<root_name>_lineage.pkl``.
    Max/median/min of each finished worker's ``best_rewards`` are plotted to
    a Visdom dashboard.

    The loop stops once ``reward_aggregation(best_rewards)`` has not improved
    for ``steps_with_no_improvement`` finished runs. It also requires that the
    finished run index exceed ``num_explore + steps_with_no_improvement``.

    :param obj_num: index into ``guacamol_goal_scoring_functions(ver)``.
    :param root_name: base root name; its first '_'-separated part plus
        'Stats' (and ``attempt``) names the dashboard.
    :param snapshot_dir: where models, cached results and lineage live.
    :return: ``extract_best(data_cache, 1)``.

    NOTE(review): ``num_runs`` is not used anywhere in this function.
    """
    # Pool size; this many runs are launched together on the first pass.
    num_workers = 4

    manager = mp.Manager()
    queue = manager.Queue()

    relationships = nx.DiGraph()
    grammar_cache = 'hyper_grammar_guac_10k_with_clique_collapse.pickle'  # 'hyper_grammar.pickle'
    grammar = 'hypergraph:' + grammar_cache

    reward_funs = guacamol_goal_scoring_functions(ver)
    reward_fun = reward_funs[obj_num]

    split_name = root_name.split('_')
    split_name[0] += 'Stats'
    dash_name = '_'.join(split_name) + attempt
    vis = Dashboard(dash_name, call_every=1)

    first_runner_factory = lambda: PolicyGradientRunner(
        grammar,
        BATCH_SIZE=batch_size,
        reward_fun=reward_fun,
        max_steps=max_steps,
        num_batches=num_batches,
        lr=lr,
        entropy_wgt=entropy_wgt,
        # lr_schedule=shifted_cosine_schedule,
        root_name=root_name,
        preload_file_root_name=None,
        plot_metrics=plot_single_runs,
        save_location=snapshot_dir,
        metric_smooth=0.0,
        decoder_type='graph_conditional_sparse',
        # 'graph_conditional',  # 'rnn_graph',# 'attention',
        on_policy_loss_type='advantage_record',
        rule_temperature_schedule=None,
        # lambda x: toothy_exp_schedule(x, scale=num_batches),
        eps=0.0,
        priors='conditional',
    )

    init_thresh = 50
    pca_dim = 10
    if past_runs_graph_file:
        params, rewards = extract_params_rewards(past_runs_graph_file)
        sampler = ParameterSampler(params,
                                   rewards,
                                   init_thresh=init_thresh,
                                   pca_dim=pca_dim)
    else:
        sampler = None
    data_cache = {}
    best_so_far = float('-inf')
    steps_since_best = 0

    initial = True
    should_stop = False
    run = 0

    with mp.Pool(num_workers) as p:
        while not should_stop:
            data_cache = populate_data_cache(snapshot_dir, data_cache)
            if run < num_explore:
                model = first_runner_factory()
                if sampler:
                    model.params = sampler.sample()
            elif data_cache:
                model = pick_model_to_run(data_cache,
                                          PolicyGradientRunner,
                                          snapshot_dir,
                                          num_best=top_N)
            else:
                model = first_runner_factory()

            orig_name = model.root_name
            model.set_root_name(generate_root_name(orig_name, data_cache))

            # NOTE(review): the model launched at run == num_explore is picked
            # from the cache but its lineage is not recorded -- confirm
            # whether ``>=`` was intended.
            if run > num_explore:
                relationships.add_edge(orig_name, model.root_name)

                if random.random() < p_crossover and len(data_cache) > 1:
                    second_model = pick_model_for_crossover(
                        data_cache, model, PolicyGradientRunner, snapshot_dir)
                    model = classic_crossover(model, second_model)
                    relationships.add_edge(second_model.root_name,
                                           model.root_name)

                mutated = random.random() < p_mutate
                if mutated:
                    model = mutate(model,
                                   pick_best=mutate_num_best,
                                   total_probs=mutate_use_total_probs)
                # Graph.node was removed in networkx 2.4; Graph.nodes[n] is the
                # node-attribute dict.
                relationships.nodes[model.root_name]['mutated'] = mutated

                with open(
                        snapshot_dir + '/' + model.root_name + '_lineage.pkl',
                        'wb') as f:
                    pickle.dump(relationships, f)

            model.save()

            # Fill every worker slot on the first pass; afterwards start one
            # new run for each run that finishes.
            # NOTE(review): all first-pass runs share model.root_name -- confirm.
            launches = num_workers if initial else 1
            initial = False
            for _ in range(launches):
                print('Starting {}'.format(run))
                p.apply_async(run_model,
                              (queue, model.root_name, run, snapshot_dir))
                run += 1

            finished_run, finished_root_name = queue.get(block=True)
            print('Finished: {}'.format(finished_root_name))

            data_cache = populate_data_cache(snapshot_dir, data_cache)
            my_rewards = data_cache[finished_root_name]['best_rewards']
            metrics = {
                'max': my_rewards.max(),
                'median': np.median(my_rewards),
                'min': my_rewards.min(),
            }
            metric_dict = {
                'type': 'line',
                'X': np.array([finished_run]),
                'Y': np.array([list(metrics.values())]),
                'opts': {
                    'legend': list(metrics.keys())
                },
            }

            vis.plot_metric_dict({'worker rewards': metric_dict})

            this_agg_reward = reward_aggregation(my_rewards)
            if this_agg_reward > best_so_far:
                best_so_far = this_agg_reward
                steps_since_best = 0
            else:
                steps_since_best += 1

            should_stop = (
                steps_since_best >= steps_with_no_improvement
                and finished_run > num_explore + steps_with_no_improvement)

        # Stop any still-running workers (the with-block would also do this).
        p.terminate()

    return extract_best(data_cache, 1)
class MetricPlotter:
    """Callable training hook that records per-batch metrics and plots them.

    Each call appends a row (loss-function metrics plus train flag, GPU usage,
    loss, batch index and timestamp) to ``self.stats``, a pandas DataFrame.
    The DataFrame is gzip-pickled to ``self.save_file`` on every call. When a
    Visdom dashboard is connected, GPU usage, loss and loss metrics are also
    plotted.
    """

    def __init__(self,
                 plot_prefix='',
                 save_file=None,
                 loss_display_cap=4,
                 dashboard_name=None,
                 plot_ignore_initial=0,
                 process_model_fun=None,
                 extra_metric_fun=None,
                 smooth_weight=0.0,
                 frequent_calls=False):
        """
        :param plot_prefix: prefix for plot titles; with spaces replaced by
            underscores it also names the default save file.
        :param save_file: file name for the stats, defaulting to
            ``<plot_prefix>.pkz``. It is always placed under
            ``<dir of this file>/../molecules/train/vae/data/``.
        :param loss_display_cap: plotted loss is ``min(loss_display_cap, loss)``.
        :param dashboard_name: Visdom dashboard name; None disables Visdom.
        :param plot_ignore_initial: plotting starts only once ``plot_counter``
            exceeds this value.
        :param process_model_fun: optional ``f(model_out, vis, plot_counter)``.
        :param extra_metric_fun: optional
            ``f(inputs, targets, model_out, train, plot_counter)`` returning
            extra plot entries.
        :param smooth_weight: weight passed to ``smooth_data``.
        :param frequent_calls: if True, the Dashboard gets ``call_every=10``.
        """
        self.plot_prefix = plot_prefix
        self.process_model_fun = process_model_fun
        self.extra_metric_fun = extra_metric_fun
        if save_file is None:
            save_file = plot_prefix.replace(' ', '_') + '.pkz'

        my_location = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
        root_location = my_location + '/../'
        # TODO: pass that through in the constructor!
        self.save_file = root_location + 'molecules/train/vae/data/' + save_file
        self.plot_ignore_initial = plot_ignore_initial
        self.loss_display_cap = loss_display_cap
        self.plot_counter = 0
        self.stats = pd.DataFrame(columns=['batch', 'timestamp', 'gpu_usage', 'train', 'loss'])
        self.smooth_weight = smooth_weight
        self.smooth = {}

        # Always define these, so __call__ works even without a dashboard.
        self.vis = None
        self.have_visdom = False
        if dashboard_name is not None:
            try:
                from generative_playground.utils.visdom_helper import Dashboard
                self.vis = Dashboard(dashboard_name,
                                     call_every=10 if frequent_calls else 1)
                self.have_visdom = True
            except Exception:
                # Visdom is optional: disable plotting instead of aborting.
                self.vis = None
                self.have_visdom = False

    def __call__(self,
                 inputs,
                 model,
                 outputs,
                 loss_fn,
                 loss):
        '''
        Plot the results of the latest batch and append them to ``self.stats``.

        :param inputs: batch inputs, forwarded to ``extra_metric_fun``.
        :param model: the model. Its ``training`` attribute (default True)
            selects the train or val loss label.
        :param outputs: model outputs, forwarded to ``process_model_fun`` and
            ``extra_metric_fun``.
        :param loss_fn: loss function. Its optional ``metrics`` dict is plotted
            and logged.
        :param loss: latest loss (a torch tensor, a number, or None).
        :return: None
        '''
        print('calling metric monitor...')
        train = getattr(model, 'training', True)
        if hasattr(loss, 'device'):  # torch tensor -> Python number
            loss = loss.data.item()
        metrics = getattr(loss_fn, 'metrics', None)
        model_out = outputs

        if train:
            loss_name = self.plot_prefix + ' train_loss'
        else:
            loss_name = self.plot_prefix + ' val_loss'

        # show intermediate results
        gpu_usage = get_gpu_memory_map()
        print(loss_name, loss, self.plot_counter, gpu_usage)
        self.plot_counter += 1
        if self.vis is not None and self.plot_counter > self.plot_ignore_initial and self.have_visdom:
            all_metrics = {}
            all_metrics['gpu_usage'] = {'type': 'line',
                                        'X': np.array([self.plot_counter]),
                                        'Y': np.array([gpu_usage[0]])}
            if loss is not None:
                all_metrics[loss_name] = {'type': 'line',
                                          'X': np.array([self.plot_counter]),
                                          'Y': np.array([min(self.loss_display_cap, loss)]),
                                          'smooth': self.smooth_weight}
            if metrics:
                all_metrics[loss_name + ' metrics'] = {
                    'type': 'line',
                    'X': np.array([self.plot_counter]),
                    'Y': np.array([list(metrics.values())]),
                    'opts': {'legend': list(metrics.keys())},
                    'smooth': self.smooth_weight}

            if self.extra_metric_fun is not None:
                # NOTE(review): targets are not among __call__'s arguments, so
                # None is passed in that slot; extra_metric_fun must tolerate it.
                all_metrics.update(self.extra_metric_fun(inputs, None, model_out, train, self.plot_counter))

            # Smooth entries that request it once a previous value exists.
            for title, metric in all_metrics.items():
                if title not in self.smooth or 'smooth' not in metric:
                    self.smooth[title] = metric
                else:
                    self.smooth[title] = smooth_data(self.smooth[title], metric, metric['smooth'])

            self.vis.plot_metric_dict({title: value for title, value in self.smooth.items()
                                       if title in all_metrics})

            # TODO: factor this out
            if self.process_model_fun is not None:
                self.process_model_fun(model_out, self.vis, self.plot_counter)

        row = {} if metrics is None else copy.copy(metrics)
        row['train'] = train
        row['gpu_usage'] = gpu_usage[0]
        row['loss'] = loss
        row['batch'] = self.plot_counter
        row['timestamp'] = datetime.datetime.now()

        # DataFrame.append was removed in pandas 2.0; concat is the replacement.
        self.stats = pd.concat([self.stats, pd.DataFrame([row])], ignore_index=True)

        # Saved on every call, not only during validation.
        with gzip.open(self.save_file, 'wb') as f:
            pickle.dump(self.stats, f)
from generative_playground.utils.visdom_helper import Dashboard
from generative_playground.molecules.train.vae.main_train_vae import train_vae
from generative_playground.molecules.train import train_validity
from generative_playground.molecules.model_settings import get_settings
from generative_playground.codec.grammar_codec import ZincGrammarModel
from generative_playground.molecules.rdkit_utils.rdkit_utils import fraction_valid
from generative_playground.models.simple_models import DenseHead
import numpy as np
from generative_playground.utils.gpu_utils import to_gpu

# Script: train a grammar-based molecule VAE, then build an encode/decode
# wrapper and a dense head on top of its encoder.
molecules = True
grammar = True
settings = get_settings(molecules, grammar)

dash_name = 'test'
# NOTE(review): ``visdom`` is not referenced again in the visible code;
# train_vae receives the dashboard by name (dashboard=dash_name).
visdom = Dashboard(dash_name)
# NOTE(review): plot_prefix says 'lr 1e-4' but lr=5e-4 is passed -- confirm
# which one is intended. molecules/grammar are passed as literals that
# match the flags defined above.
model, fitter, main_dataset = train_vae(molecules=True,
                                        grammar=True,
                                        BATCH_SIZE=150,
                                        drop_rate=0.3,
                                        sample_z=True,
                                        save_file='next_gen.h5',
                                        encoder_type=False,
                                        lr=5e-4,
                                        plot_prefix='RNN enc lr 1e-4',
                                        dashboard=dash_name,
                                        preload_weights=False)
# this is a wrapper for encoding/decoding
grammar_model = ZincGrammarModel(model=model)
# DenseHead over the trained encoder, with body_out_dim taken from
# settings['z_size']; presumably a validity predictor (per its name) -- confirm.
validity_model = to_gpu(
    DenseHead(model.encoder, body_out_dim=settings['z_size'], drop_rate=0.3))
def fit(train_gen=None,
        valid_gen=None,
        model=None,
        optimizer=None,
        scheduler=None,
        epochs=None,
        loss_fn=None,
        save_path=None,
        dashboard=None,
        ignore_initial=10,
        exp_smooth=0.9,
        save_every=0):
    """Train ``model`` for ``epochs`` epochs, validating after each epoch.

    Each epoch steps ``scheduler`` (if given) and runs every batch of
    ``train_gen`` with gradient updates. It then runs ``valid_gen`` without
    updates, stopping after 51 validation batches. The model is saved to
    ``save_path`` whenever the latest running-mean validation loss improves.
    Training ends early once that loss drops below 1e-10.

    :param train_gen: iterable of ``(inputs, targets)`` training batches.
    :param valid_gen: iterable of ``(inputs, targets)`` validation batches.
    :param model: module with ``train()``/``eval()``; ``reset_hidden()`` is
        called after every batch if available.
    :param optimizer: optimizer used for training batches.
    :param scheduler: optional LR scheduler, stepped at the start of each epoch.
    :param epochs: number of epochs to run.
    :param loss_fn: ``loss_fn(outputs, targets)`` returning a scalar tensor.
    :param save_path: where the best model is saved.
    :param dashboard: optional Visdom dashboard name for per-batch loss plots.
    :param ignore_initial: number of initial batches not plotted.
    :param exp_smooth: not used by this function.
    :param save_every: if > 0, also call ``save_model(model)`` every
        ``save_every`` training batches.
    """
    best_valid_loss = float('inf')
    # Latest validation loss. It stays inf if valid_gen yields no batches, so
    # the end-of-epoch checks never see an unbound name.
    valid_loss = float('inf')

    vis = Dashboard(dashboard) if dashboard is not None else None
    plot_counter = 0

    for epoch in range(epochs):
        print('epoch ', epoch)
        if scheduler is not None:
            scheduler.step()
        for train, data_gen in ((True, train_gen), (False, valid_gen)):
            loss_ = 0
            count_ = 0
            if train:
                model.train()
                loss_name = 'training_loss'
            else:
                model.eval()
                loss_name = 'validation_loss'

            for inputs_, targets_ in data_gen:
                inputs = to_variable(inputs_)
                targets = to_variable(targets_)
                outputs = model(inputs)
                loss = loss_fn(outputs, targets)
                if train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                try:
                    model.reset_hidden()
                except Exception:
                    # Best-effort: models without a usable reset_hidden are skipped.
                    pass
                # .item() works on 0-dim tensors; .data[0] raises IndexError
                # there on PyTorch >= 0.4.
                this_loss = loss.data.item()
                loss_ += this_loss
                count_ += 1
                plot_counter += 1
                if not train:
                    valid_loss = loss_ / count_
                    if count_ > 50:
                        break
                elif save_every and count_ % save_every == 0:
                    # NOTE(review): unlike the best-model save below, no
                    # save_path is passed here -- confirm intended.
                    save_model(model)
                # show intermediate results
                print(loss_name, loss_ / count_, count_, get_gpu_memory_map())
                if vis is not None and plot_counter > ignore_initial:
                    try:
                        vis.append(loss_name,
                                   'line',
                                   X=np.array([plot_counter]),
                                   Y=np.array([this_loss]))
                    except Exception:
                        print(
                            'Please start a visdom server with python -m visdom.server!'
                        )
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            print("we're improving!", best_valid_loss)
            # spell_out:
            save_model(model, save_path)

        if valid_loss < 1e-10:
            break