def launch_experiment(settings):
  u.dump_dict(settings, 'CONVNET EXPERIMENT PARAMETERS')

  metadata = d.Metadata(train_folder=settings['train_folder'], test_folder=settings['test_folder'],
                        train_metadata=settings['train_metadata'], test_metadata=settings['test_metadata'],
                        ngram_metadata=settings['ngram_metadata'], vocab=settings['vocab'], decay_step=settings['decay_step'],
                        batch_size=settings['batch_size'], create_mask=settings['create_mask'], subset=settings['subset'],
                        percent=settings['percent'], size_limits=settings['size_limits'], loss=settings['loss'])

  model = instanciate_model(enc_input_dim=settings['enc_input_dim'], dec_input_dim=settings['dec_input_dim'],
                            enc_max_seq_len=settings['enc_max_seq_len'], dec_max_seq_len=settings['dec_max_seq_len'],
                            enc_layers=settings['enc_layers'], dec_layers=settings['dec_layers'],
                            enc_kernel_size=settings['enc_kernel_size'], dec_kernel_size=settings['dec_kernel_size'],
                            emb_dim=settings['emb_dim'], hid_dim=settings['hid_dim'],
                            enc_dropout=settings['enc_dropout'], dec_dropout=settings['dec_dropout'],
                            output_size=metadata.output_size, reduce_dim=settings['reduce_dim'], 
                            device=metadata.device, pad_idx=metadata.pad_idx,
                            p_enc_layers=settings['p_enc_layers'], p_enc_kernel_size=settings['p_enc_kernel_size'],
                            p_enc_dropout=settings['p_enc_dropout'], p_dec_layers=settings['p_dec_layers'],
                            p_dec_kernel_size=settings['p_dec_kernel_size'], p_dec_dropout=settings['p_dec_dropout'])

  logging.info(f'The model has {u.count_trainable_parameters(model):,} trainable parameters')

  if settings['weight_decay'] == 0:
    optimizer = optim.Adam(model.parameters(), lr=settings['lr'])
  else:
    optimizer = opt.AdamW(model.parameters(), lr=settings['lr'], weight_decay=settings['weight_decay'])

  if settings['load_model']:
    u.load_model(model, f"{settings['save_path']}convnet_feedback.pt", map_location=None, restore_only_similars=True)

  memory_word_acc = 0

  for epoch in tqdm(range(settings['max_epochs'])):
    epoch_losses = train_pass(model, optimizer, metadata, settings, epoch)

    if epoch % settings['train_acc_step'] == 0: 
      threading.Thread(target=compute_plot_scores_async, args=(epoch, metadata, plotter)).start()

    metadata.loss.step(epoch)  # anneals the KLD loss when settings['loss'] == 'both'

    plotter.line_plot('loss', 'train', 'Loss', epoch, epoch_losses)

    metadata.SM.reset_feed()

    if epoch % settings['eval_step'] == 0:
      _, eval_word_acc, _ = eval_model(model, metadata, settings, epoch, only_loss=False)

      if eval_word_acc > memory_word_acc:  # keep only the best word-accuracy checkpoint
        memory_word_acc = eval_word_acc
        u.save_checkpoint(model, optimizer, settings['save_path'] + 'convnet.pt')
    else:
      eval_model(model, metadata, settings, epoch)
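A minimal sketch of how this launcher might be invoked. The keys mirror the lookups `launch_experiment` performs above; every concrete value below is an illustrative placeholder, not a default taken from the original project.

# Hypothetical settings -- keys come from the lookups above, values are placeholders only.
settings = {
    'train_folder': 'data/train/', 'test_folder': 'data/test/',
    'train_metadata': 'train_metadata.pk', 'test_metadata': 'test_metadata.pk',
    'ngram_metadata': 'ngram_metadata.pk', 'vocab': 'vocab.pk',
    'decay_step': 0.01, 'batch_size': 32, 'create_mask': True,
    'subset': False, 'percent': 1.0, 'size_limits': None, 'loss': 'cross-entropy',
    'enc_input_dim': 400, 'dec_input_dim': 32, 'enc_max_seq_len': 1500, 'dec_max_seq_len': 600,
    'enc_layers': 10, 'dec_layers': 10, 'enc_kernel_size': 3, 'dec_kernel_size': 3,
    'emb_dim': 256, 'hid_dim': 512, 'enc_dropout': 0.25, 'dec_dropout': 0.25, 'reduce_dim': False,
    'p_enc_layers': 2, 'p_enc_kernel_size': 3, 'p_enc_dropout': 0.25,
    'p_dec_layers': 2, 'p_dec_kernel_size': 3, 'p_dec_dropout': 0.25,
    'weight_decay': 0, 'lr': 1e-4, 'load_model': False, 'save_path': 'checkpoints/',
    'max_epochs': 500, 'train_acc_step': 5, 'eval_step': 1,
}

launch_experiment(settings)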
Example #2
def launch_experiment(settings):
    u.dump_dict(settings, 'TRANSFORMER EXPERIMENT PARAMETERS')

    metadata = d.Metadata(
        train_folder=settings['train_folder'],
        test_folder=settings['test_folder'],
        train_metadata=settings['train_metadata'],
        test_metadata=settings['test_metadata'],
        ngram_metadata=settings['ngram_metadata'],
        vocab=settings['vocab'],
        decay_step=settings['decay_step'],
        batch_size=settings['batch_size'],
        create_mask=settings['create_mask'],
        subset=settings['subset'],
        percent=settings['percent'],
        size_limits=settings['size_limits'],
        loss=settings['loss'],
        device=torch.device(f"cuda:{settings['num_device']}"))

    model = instanciate_model(settings, metadata)

    optimizer = opt.RAdam(model.parameters(), lr=settings['lr'])

    memory_word_acc = 0

    for epoch in tqdm(range(settings['max_epochs'])):
        epoch_losses = train_pass(model, optimizer, metadata, settings)

        plotter.line_plot('loss', 'train', 'Loss', epoch, epoch_losses)

        metadata.SM.reset_feed()

        if epoch % settings['eval_step'] == 0:
            _, eval_word_acc, _ = eval_model(model,
                                             metadata,
                                             settings,
                                             epoch,
                                             only_loss=False)

            if eval_word_acc > memory_word_acc:  # keep only the best word-accuracy checkpoint
                memory_word_acc = eval_word_acc
                u.save_checkpoint(model, optimizer,
                                  settings['save_path'] + 'transformer.pt')
        else:
            eval_model(model, metadata, settings, epoch)
Example #3
    def compile(self, elem):
        self.data["name"] = elem.get("name")
        self.data["uid"] = os.path.basename(os.path.dirname(
            self.compiler.path))

        for attr in ["repeat", "repeat_delay"]:
            val = elem.get(attr)

            if val:
                self.data[attr] = int(val)

        BaseTagCompiler.compile(self, elem)

        self.compiler.quests.write("{uid} = {d}",
                                   uid=self.data["uid"],
                                   d=utils.dump_dict(self.data))
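For context, a sketch of the kind of XML element this `compile` method consumes; the tag name and attribute values are assumptions for illustration, and only `name`, `repeat` and `repeat_delay` are actually read.

import xml.etree.ElementTree as ET

# Hypothetical quest element -- only the attributes accessed above matter.
elem = ET.fromstring('<quest name="rescue_villager" repeat="3" repeat_delay="60"/>')

Since `Element.get` returns strings, the method casts `repeat` and `repeat_delay` to `int` before storing them.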
Example #4
    def __init__(self,
                 device=None,
                 logfile='_logs/_logs_CTC.txt',
                 save_metadata='_Data_metadata.pk',
                 batch_size=64,
                 lr=1e-4,
                 load_model=True,
                 n_epochs=500,
                 eval_step=1,
                 config={},
                 save_name_model='convnet/ctc_model.pt',
                 lr_scheduling=True,
                 lr_scheduler_type='plateau',
                 **kwargs):
        if not os.path.isdir(os.path.dirname(logfile)):
            os.makedirs(os.path.dirname(logfile))
        logging.basicConfig(filename=logfile,
                            filemode='a',
                            level=logging.INFO,
                            format='%(asctime)s - %(levelname)s - %(message)s')
        self.device = device if device is not None else torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.eval_step = eval_step
        self.save_name_model = save_name_model
        self.lr_scheduling = lr_scheduling
        self.kwargs = kwargs

        self.set_metadata(save_metadata)
        self.set_data_loader()

        self.config = {
            'output_dim': len(self.idx_to_tokens),
            'emb_dim': 512,
            'd_model': 512,
            'n_heads': 8,
            'd_ff': 1024,
            'kernel_size': 3,
            'n_blocks': 6,
            'n_blocks_strided': 2,
            'dropout': 0.,
            'block_type': 'dilated'
        }
        self.config = {**self.config, **config}
        self.model = self.instanciate_model(**self.config)
        self.model = nn.DataParallel(self.model)

        u.dump_dict(
            {
                'batch_size': batch_size,
                'lr': lr,
                'lr_scheduling': lr_scheduling,
                'lr_scheduler_type': lr_scheduler_type,
                'load_model': load_model
            }, 'CTCModel Hyperparameters')
        u.dump_dict(self.config, 'CTCModel PARAMETERS')
        logging.info(self.model)
        logging.info(
            f'The model has {u.count_trainable_parameters(self.model):,} trainable parameters'
        )

        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.criterion = nn.CTCLoss(zero_infinity=True)

        if self.lr_scheduling:
            if lr_scheduler_type == 'cosine':
                self.lr_scheduler = CosineAnnealingWarmUpRestarts(
                    self.optimizer,
                    T_0=150,
                    T_mult=2,
                    eta_max=1e-2,
                    T_up=50,
                    gamma=0.5)
            else:
                patience = kwargs.get('patience', 50)
                min_lr = kwargs.get('min_lr', 1e-5)
                threshold = kwargs.get('lr_threshold', 0.003)
                self.lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                    self.optimizer,
                    mode='max',
                    factor=0.1,
                    patience=patience,
                    verbose=True,
                    min_lr=min_lr,
                    threshold_mode='abs',
                    threshold=threshold)

        if load_model:
            u.load_model(self.model,
                         self.save_name_model,
                         restore_only_similars=True)
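A possible instantiation, assuming the enclosing class is named `CTCModel` (the name is suggested by the dump titles above but is not shown in the excerpt); the `config` argument is merged over the defaults built in `__init__`.

# Hypothetical usage -- 'CTCModel' is inferred from the dump titles, not shown in the excerpt.
ctc = CTCModel(batch_size=32,
               lr=3e-4,
               load_model=False,
               lr_scheduler_type='cosine',
               config={'n_blocks': 4, 'dropout': 0.1})  # overrides merged into self.config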
Example #5
        config['seed'] = seed
        output_dir = os.path.join(base_output_dir, 'seed' + str(seed))
        # directory removal warning!
        create_directory(output_dir, remove_curr=True)

        tree_model = DendroFCNN(data_root=data_root,
                                use_cuda=USE_CUDA,
                                device=device)
        model_leaves = tree_model.leaf_list

        simple_model = SimpleFCNN(input_dim=num_features)
        if USE_CUDA:
            tree_model.to(device)
            simple_model.to(device)

        dump_dict(config, output_dir)
        experiment = RegressionExperiment(tree_model,
                                          simple_model,
                                          config,
                                          model_leaves,
                                          data_root,
                                          leaves,
                                          use_test=False)
        # train_leaf_starting_weights = list()
        # for i in range(3):
        #     train_leaf_starting_weights.append(model_leaves[2].weight_list[i].detach().numpy())
        delta_losses, prediction_losses, validation_losses, _ = experiment.train_dendronet()
        simple_prediction_losses, simple_validation_losses, _ = experiment.train_simple_model()
        # todo: omitted a block of code that collects that targets and final predictions for every model at every seed
Example #6
    def __init__(
            self,
            device=None,
            logfile='_logs/_logs_experiment.txt',
            save_name_model='convnet/convnet_experiment.pt',
            readers=[],
            metadata_file='_Data_metadata_letters.pk',
            dump_config=True,
            encoding_fn=Data.letters_encoding,
            score_fn=F.softmax,
            list_files_fn=Data.get_openslr_files,
            process_file_fn=Data.read_and_slice_signal,
            signal_type='window-sliced',
            slice_fn=Data.window_slicing_signal,
            multi_head=False,
            d_keys_values=64,
            lr=1e-4,
            smoothing_eps=0.1,
            n_epochs=500,
            batch_size=32,
            decay_factor=1,
            decay_step=0.01,
            create_enc_mask=False,
            eval_step=10,
            scorer=Data.compute_accuracy,
            convnet_config={},
            relu=False,
            pin_memory=True,
            train_folder='../../../datasets/openslr/LibriSpeech/train-clean-100/',
            test_folder='../../../datasets/openslr/LibriSpeech/test-clean/',
            scores_step=5,
            **kwargs):
        '''
        Params:
          * device (optional) : torch.device
          * logfile (optional) : str, filename for logs dumping
          * save_name_model (optional) : str, filename of model saving
          * readers (optional) : list of str
          * metadata_file (optional) : str, filename where metadata from Data are saved
          * dump_config (optional) : bool, True to dump convnet configuration to logging file
          * encoding_fn (optional) : function, handles text encoding
          * score_fn (optional) : function, handles energy computation for attention mechanism
          * list_files_fn (optional) : function, handles files list retrieval
          * process_file_fn (optional) : function, reads and processes audio files
          * signal_type (optional) : str
          * slice_fn (optional) : function, handles audio raw signal framing
          * multi_head (optional) : bool, True to use a MultiHeadAttention mechanism
          * d_keys_values (optional) : int, key/value dimension for the multi-head attention
          * lr (optional) : float, learning rate passed to the optimizer
          * smoothing_eps (optional) : float, Label-Smoothing epsilon
          * n_epochs (optional) : int
          * batch_size (optional) : int
          * decay_factor (optional) : int, 0 for cross-entropy loss only, 1 for Attention-CrossEntropy loss
          * decay_step (optional) : float, decreasing step of the Attention loss
          * create_enc_mask (optional) : bool
          * eval_step (optional) : int, computes accuracies on the test set when epoch % eval_step == 0
          * scorer (optional) : function, computes training and testing metrics
          * convnet_config (optional) : dict
          * relu (optional) : bool, True to use the ReLU version of ConvEncoder&Decoder
          * pin_memory (optional) : bool, passed to DataLoader
          * train_folder (optional) : str
          * test_folder (optional) : str
          * scores_step (optional) : int
          * kwargs (optional) : arguments passed to process_file_fn
        '''
        # [logging.root.removeHandler(handler) for handler in logging.root.handlers[:]]
        logging.basicConfig(filename=logfile,
                            filemode='a',
                            level=logging.INFO,
                            format='%(asctime)s - %(levelname)s - %(message)s')

        self.device = device if device is not None else torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.logfile = logfile
        self.save_name_model = save_name_model
        self.readers = readers
        self.metadata_file = metadata_file
        self.dump_config = dump_config
        self.encoding_fn = encoding_fn
        self.score_fn = score_fn
        self.list_files_fn = list_files_fn
        self.process_file_fn = process_file_fn
        self.signal_type = signal_type
        self.slice_fn = slice_fn
        self.multi_head = multi_head
        self.d_keys_values = d_keys_values
        self.lr = lr
        self.smoothing_eps = smoothing_eps
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.decay_factor = decay_factor
        self.decay_step = decay_step
        self.create_enc_mask = create_enc_mask
        self.eval_step = eval_step
        self.scorer = scorer
        self.relu = relu
        self.pin_memory = pin_memory
        self.train_folder = train_folder
        self.test_folder = test_folder
        self.scores_step = scores_step
        self.process_file_fn_args = {**kwargs, **{'slice_fn': slice_fn}}

        self.set_data()
        self.sos_idx = self.data.tokens_to_idx['<sos>']
        self.eos_idx = self.data.tokens_to_idx['<eos>']
        self.pad_idx = self.data.tokens_to_idx['<pad>']

        self.convnet_config = {
            'enc_input_dim': self.data.n_signal_feats,
            'enc_max_seq_len': self.data.max_signal_len,
            'dec_input_dim': len(self.data.idx_to_tokens),
            'dec_max_seq_len': self.data.max_source_len,
            'output_size': len(self.data.idx_to_tokens),
            'pad_idx': self.pad_idx,
            'score_fn': score_fn,
            'enc_layers': 10,
            'dec_layers': 10,
            'enc_kernel_size': 3,
            'dec_kernel_size': 3,
            'enc_dropout': 0.25,
            'dec_dropout': 0.25,
            'emb_dim': 256,
            'hid_dim': 512,
            'reduce_dim': False,
            'multi_head': multi_head,
            'd_keys_values': d_keys_values
        }
        self.convnet_config = {**self.convnet_config, **convnet_config}
        self.model = self.instanciate_model(**self.convnet_config)

        if dump_config:
            u.dump_dict(self.convnet_config, 'ENCODER-DECODER PARAMETERS')
            logging.info(
                f'The model has {u.count_trainable_parameters(self.model):,} trainable parameters'
            )

        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.criterion = u.AttentionLoss(self.pad_idx,
                                         self.device,
                                         decay_step=decay_step,
                                         decay_factor=decay_factor)

        self.set_data_loader()
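As a usage sketch (the class owning this `__init__` is not visible in the excerpt, so `ConvnetExperiment` below is an assumed name), note that `convnet_config` entries override the defaults assembled above before the model is instantiated:

# Hypothetical usage -- 'ConvnetExperiment' stands in for the unseen class name.
exp = ConvnetExperiment(batch_size=16,
                        multi_head=True,
                        d_keys_values=64,
                        convnet_config={'enc_layers': 8, 'dec_layers': 8, 'emb_dim': 128})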
Example #7
def run(args):
    dirpath = Path(args['dirpath'])
    # dname = args['dname']
    # src_names = args['src_names']

    # Target
    target_name = args['target_name']

    # Data split
    cv_folds = args['cv_folds']

    # Features
    cell_fea = args['cell_features']
    drug_fea = args['drug_features']
    other_fea = args['other_features']
    fea_list = cell_fea + drug_fea + other_fea

    # NN params
    epochs = args['epochs']
    batch_size = args['batch_size']
    dr_rate = args['dr_rate']

    # Optimizer
    opt_name = args['opt']
    clr_keras_kwargs = {
        'mode': args['clr_mode'],
        'base_lr': args['clr_base_lr'],
        'max_lr': args['clr_max_lr'],
        'gamma': args['clr_gamma']
    }

    # Learning curve
    n_shards = args['n_shards']

    # Other params
    # framework = args['framework']
    model_name = args['model_name']
    n_jobs = args['n_jobs']

    # ML type ('reg' or 'cls')
    if 'reg' in model_name:
        mltype = 'reg'
    elif 'cls' in model_name:
        mltype = 'cls'
    else:
        raise ValueError("model_name must contain 'reg' or 'cls'.")

    # Define metrics
    # metrics = {'r2': 'r2',
    #            'neg_mean_absolute_error': 'neg_mean_absolute_error', #sklearn.metrics.neg_mean_absolute_error,
    #            'neg_median_absolute_error': 'neg_median_absolute_error', #sklearn.metrics.neg_median_absolute_error,
    #            'neg_mean_squared_error': 'neg_mean_squared_error', #sklearn.metrics.neg_mean_squared_error,
    #            'reg_auroc_score': utils.reg_auroc_score}

    # ========================================================================
    #       Load data and pre-proc
    # ========================================================================
    dfs = {}

    def get_file(fpath):
        return pd.read_csv(
            fpath, header=None).squeeze().values if fpath.is_file() else None

    def read_data_file(fpath, file_format='csv'):
        fpath = Path(fpath)
        if fpath.is_file():
            if file_format == 'csv':
                df = pd.read_csv(fpath)
            elif file_format == 'parquet':
                df = pd.read_parquet(fpath)
        else:
            df = None
        return df

    if dirpath is not None:
        xdata = read_data_file(dirpath / 'xdata.parquet', 'parquet')
        meta = read_data_file(dirpath / 'meta.parquet', 'parquet')
        ydata = meta[[target_name]]

        tr_id = pd.read_csv(dirpath / f'{cv_folds}fold_tr_id.csv')
        vl_id = pd.read_csv(dirpath / f'{cv_folds}fold_vl_id.csv')

        # tr_ids_list = get_file( dirpath/f'{cv_folds}fold_tr_id.csv' )
        # vl_ids_list = get_file( dirpath/f'{cv_folds}fold_vl_id.csv' )
        # te_ids_list = get_file( dirpath/f'{cv_folds}fold_te_id.csv' )

        src = dirpath.name.split('_')[0]
        dfs[src] = (ydata, xdata, tr_id, vl_id)

    elif dname == 'combined':
        # TODO: this is not used anymore (probably won't work)
        DATADIR = file_path / '../../data/processed/data_splits'
        DATAFILENAME = 'data.parquet'
        dirs = glob(str(DATADIR / '*'))

        for src in src_names:
            print(f'\n{src} ...')
            subdir = f'{src}_cv_{cv_method}'
            if str(DATADIR / subdir) in dirs:
                # Get the CV indexes
                tr_id = pd.read_csv(DATADIR / subdir /
                                    f'{cv_folds}fold_tr_id.csv')
                vl_id = pd.read_csv(DATADIR / subdir /
                                    f'{cv_folds}fold_vl_id.csv')

                # Get the data
                datapath = DATADIR / subdir / DATAFILENAME
                data = pd.read_parquet(datapath)
                xdata, _, meta, _ = break_src_data(
                    data, target=None, scaler=None)  # logger=lg.logger
                ydata = meta[[target_name]]

                dfs[src] = (ydata, xdata, tr_id, vl_id)
                del data, xdata, ydata, tr_id, vl_id, src

    for src, data in dfs.items():
        ydata, xdata, tr_id, vl_id = data[0], data[1], data[2], data[3]

        # Scale (only when a scaler is requested, so fit_transform is never called on None)
        scaler = args['scaler']
        if scaler is not None:
            if scaler == 'stnd':
                scaler = StandardScaler()
            elif scaler == 'minmax':
                scaler = MinMaxScaler()
            elif scaler == 'rbst':
                scaler = RobustScaler()

            cols = xdata.columns
            xdata = pd.DataFrame(scaler.fit_transform(xdata),
                                 columns=cols,
                                 dtype=np.float32)

        # -----------------------------------------------
        #       Create outdir and logger
        # -----------------------------------------------
        run_outdir = create_outdir(OUTDIR, args, src)
        lg = Logger(run_outdir / 'logfile.log')
        lg.logger.info(f'File path: {file_path}')
        lg.logger.info(f'\n{pformat(args)}')

        # Dump args to file
        utils.dump_dict(args, outpath=run_outdir / 'args.txt')

        # -----------------------------------------------
        #      ML model configs
        # -----------------------------------------------
        if model_name == 'lgb_reg':
            framework = 'lightgbm'
            init_kwargs = {
                'n_jobs': n_jobs,
                'random_state': SEED,
                'logger': lg.logger
            }
            fit_kwargs = {'verbose': False}
        elif model_name == 'nn_reg':
            framework = 'keras'
            init_kwargs = {
                'input_dim': xdata.shape[1],
                'dr_rate': dr_rate,
                'opt_name': opt_name,
                'attn': attn,
                'logger': lg.logger
            }
            fit_kwargs = {
                'batch_size': batch_size,
                'epochs': epochs,
                'verbose': 1
            }
        elif model_name in ('nn_reg0', 'nn_reg1', 'nn_reg2'):
            framework = 'keras'
            init_kwargs = {
                'input_dim': xdata.shape[1],
                'dr_rate': dr_rate,
                'opt_name': opt_name,
                'logger': lg.logger
            }
            fit_kwargs = {
                'batch_size': batch_size,
                'epochs': epochs,
                'verbose': 1
            }  # 'validation_split': 0.1
        elif model_name in ('nn_reg3', 'nn_reg4'):
            framework = 'keras'
            init_kwargs = {
                'in_dim_rna': None,
                'in_dim_dsc': None,
                'dr_rate': dr_rate,
                'opt_name': opt_name,
                'logger': lg.logger
            }
            fit_kwargs = {
                'batch_size': batch_size,
                'epochs': epochs,
                'verbose': 1
            }  # 'validation_split': 0.1

        # -----------------------------------------------
        #      Learning curve
        # -----------------------------------------------
        lg.logger.info('\n\n{}'.format('=' * 50))
        lg.logger.info(f'Learning curves {src} ...')
        lg.logger.info('=' * 50)

        t0 = time()
        lc = LearningCurve(X=xdata,
                           Y=ydata,
                           cv=None,
                           cv_lists=(tr_id, vl_id),
                           n_shards=n_shards,
                           shard_step_scale='log10',
                           args=args,
                           logger=lg.logger,
                           outdir=run_outdir)

        lrn_crv_scores = lc.trn_learning_curve(
            framework=framework,
            mltype=mltype,
            model_name=model_name,
            init_kwargs=init_kwargs,
            fit_kwargs=fit_kwargs,
            clr_keras_kwargs=clr_keras_kwargs,
            n_jobs=n_jobs,
            random_state=SEED)

        lg.logger.info('Runtime: {:.1f} hrs'.format((time() - t0) / 3600))

        # -------------------------------------------------
        # Learning curve (sklearn method)
        # Problem! cannot log multiple metrics.
        # -------------------------------------------------
        """
        lg.logger.info('\nStart learning curve (sklearn method) ...')
        # Define params
        metric_name = 'neg_mean_absolute_error'
        base = 10
        train_sizes_frac = np.logspace(0.0, 1.0, lc_ticks, endpoint=True, base=base)/base

        # Run learning curve
        t0 = time()
        lrn_curve_scores = learning_curve(
            estimator=model.model, X=xdata, y=ydata,
            train_sizes=train_sizes_frac, cv=cv, groups=groups,
            scoring=metric_name,
            n_jobs=n_jobs, exploit_incremental_learning=False,
            random_state=SEED, verbose=1, shuffle=False)
        lg.logger.info('Runtime: {:.1f} mins'.format( (time()-t0)/60) )

        # Dump results
        # lrn_curve_scores = utils.cv_scores_to_df(lrn_curve_scores, decimals=3, calc_stats=False) # this func won't work
        # lrn_curve_scores.to_csv(os.path.join(run_outdir, 'lrn_curve_scores_auto.csv'), index=False)

        # Plot learning curves
        lrn_crv.plt_learning_curve(rslt=lrn_curve_scores, metric_name=metric_name,
            title='Learning curve (target: {}, data: {})'.format(target_name, tr_sources_name),
            path=os.path.join(run_outdir, 'auto_learning_curve_' + target_name + '_' + metric_name + '.png'))
        """

        lg.kill_logger()
        del xdata, ydata

    print('Done.')
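A sketch of the `args` dictionary `run` expects; the keys mirror the lookups above, and every value is a placeholder rather than a setting from the original experiments.

# Hypothetical args -- keys taken from the lookups in run(), values are placeholders.
args = {
    'dirpath': 'data/gdsc_split', 'target_name': 'AUC', 'cv_folds': 5,
    'cell_features': ['rna'], 'drug_features': ['dsc'], 'other_features': [],
    'epochs': 100, 'batch_size': 64, 'dr_rate': 0.2,
    'opt': 'sgd', 'clr_mode': None, 'clr_base_lr': 1e-4, 'clr_max_lr': 1e-3, 'clr_gamma': 0.999,
    'n_shards': 8, 'model_name': 'nn_reg0', 'n_jobs': 8, 'scaler': 'stnd',
}

run(args)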
Example #8
                # pd.DataFrame.from_records(train_y).to_csv(name + '_labels.csv')
                #
                # """
                # end temp dump
                # """

                par_dendro_losses, par_prediction_losses, par_validation_losses, par_validation_aucs = baseline_experiment.train_dendronet()
                one_hot_prediction_losses, one_hot_validation_losses, one_hot_validation_aucs = baseline_experiment.train_simple_model()

            tree_model = TreeModelLogReg(data_tree_root=data_root, leaves=leaves, layer_shape=layer_shape)
            simple_model = LogRegModel(layer_shape=layer_shape)
            input_shape = (weights_dim, 1)
            simple_model.build(input_shape)


            dump_dict(config, output_dir)
            experiment = LogRegExperiment(tree_model, simple_model, config, data_root, leaves)
            dendronet_losses, prediction_losses, validation_losses, validation_aucs = experiment.train_dendronet()
            simple_prediction_losses, simple_validation_losses, simple_validation_aucs = experiment.train_simple_model()

            """
            saving the tree graph
            """
            # tree_fig_model_root, tree_fig_data_root = search_for_tree_starting_point(tree_model, data_root, feature_index=FEATURE_INDEX)
            # print_tree_model(tree_fig_model_root, root=tree_fig_data_root, lifestyle=ls, feature_index=FEATURE_INDEX)
            tree_labels_file_name = 'OB_fig_labels.csv'
            dendro_predictions = tree_model.call(experiment.all_x)[0].numpy()
            simple_predictions = [simple_model.call(leaf.x).numpy() for leaf in leaves]
            fig_targets = [leaf.y for leaf in leaves]
            fig_species_names = [leaf.name for leaf in leaves]
            fig_dict = {
Example #9
        if args.baselines:
            parsimony_model = MultifurcatingTreeModelLogReg(
                data_tree_root=data_root, leaves=leaves, layer_shape=(1, 2))

            baseline_experiment = LogRegExperiment(parsimony_model,
                                                   None,
                                                   config,
                                                   data_root,
                                                   leaves,
                                                   baselines=args.baselines,
                                                   expanded_x=None,
                                                   use_test=True)
            par_dendro_losses, par_prediction_losses, par_validation_losses, par_validation_aucs = baseline_experiment.train_dendronet()

        dump_dict(config, output_dir)

        x_label = str(config['validation_interval']) + "s of steps"
        if args.baselines:
            plot_file = os.path.join(output_dir, 'parsimony_losses.png')
            plot_losses(plot_file, [
                par_dendro_losses, par_prediction_losses, par_validation_losses
            ], ['mutation', 'training', 'validation'], x_label)

        if args.baselines and len(par_validation_aucs) >= 1:
            aucs['parsimony_best'].append(max(par_validation_aucs[1:]))
            aucs['parsimony_final'].append(par_validation_aucs[-1])

    auc_list = list()
    log_auc_names = list()
Example #10
from utils import dump_dict
from data_structures.entangled_data_simulation import EntangledTree, store_tree_and_leaves

seeds = [0, 1, 2, 3, 4]
base_folder = 'tree_storage/seed_'

config = {
    'depth': 9,
    'mutation_rate': 0.1,
    'num_leaves': 1,
    'low': 0,
    'high': 5,
}

print('Generating trees with seeds ' + str(seeds))
print('Using config ' + str(config))
for seed in seeds:
    data_tree = EntangledTree(seed=seed,
                              depth=config['depth'],
                              mutation_rate=config['mutation_rate'],
                              num_leaves=config['num_leaves'],
                              low=config['low'],
                              high=config['high'])
    folder_name = base_folder + str(seed)
    store_tree_and_leaves(data_tree.tree, data_tree.leaves, folder_name)
    dump_dict(config, folder_name)
    print('Stored a tree in folder ' + folder_name)
print('All trees created and stored')
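A closing note: the snippets above use several differently shaped helpers that share the name dump_dict. `u.dump_dict(d, title)` dumps a titled dictionary to the logging file (per the `dump_config` docstring in Example #6), `utils.dump_dict(args, outpath=...)` writes it to a file, Example #3's `utils.dump_dict(self.data)` returns a serialized form, and the DendroNet-style `dump_dict(config, output_dir)` stores the config next to the experiment outputs. Below is a rough sketch of the logging variant only, written as an assumption about its behaviour rather than the actual implementation:

import logging

def dump_dict(d, title=''):
    # Assumed behaviour: log a title line, then one 'key = value' line per entry.
    logging.info(title)
    for k, v in d.items():
        logging.info(f'{k} = {v}')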