def run(params):
    """Train and evaluate the combined drug-response model described by ``params``.

    ``params`` is wrapped in ``Struct`` so options are read as attributes
    (``args.<name>``). The same dict is also updated in place with the
    parameter counts from ``candle.compute_trainable_params`` and handed to
    the CANDLE callbacks.

    Modes:
      * ``args.export_csv``: partition the data, write the train+val rows to a
        tab-separated file and return ``None`` without training.
      * ``args.export_data``: partition the data, append train/val batches to
        an HDF5 store and return ``None`` without training.
      * otherwise: train one model per cross-validation fold (a single fold
        when ``args.cv <= 1``), write validation predictions to
        ``<prefix>.predicted.tsv``, evaluate on each separate test source and
        return the Keras ``History`` of the last fold.

    NOTE(review): this file defines ``run`` a second time further down; at
    import time that later definition replaces this one.
    """
    args = Struct(**params)
    set_seed(args.rng_seed)
    ext = extension_from_parameters(args)
    verify_path(args.save_path)
    prefix = args.save_path + ext
    logfile = args.logfile if args.logfile else prefix + '.log'
    set_up_logger(logfile, args.verbose)
    logger.info('Params: {}'.format(params))

    # Only reconfigure the TF session when specific GPUs were requested.
    if args.gpus:
        import tensorflow as tf
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.gpu_options.visible_device_list = ",".join(map(str, args.gpus))
        K.set_session(tf.Session(config=config))

    loader = CombinedDataLoader(seed=args.rng_seed)
    loader.load(
        cache=args.cache,
        ncols=args.feature_subsample,
        agg_dose=args.agg_dose,
        cell_features=args.cell_features,
        drug_features=args.drug_features,
        drug_median_response_min=args.drug_median_response_min,
        drug_median_response_max=args.drug_median_response_max,
        use_landmark_genes=args.use_landmark_genes,
        use_filtered_genes=args.use_filtered_genes,
        cell_feature_subset_path=args.cell_feature_subset_path
        or args.feature_subset_path,
        drug_feature_subset_path=args.drug_feature_subset_path
        or args.feature_subset_path,
        preprocess_rnaseq=args.preprocess_rnaseq,
        single=args.single,
        train_sources=args.train_sources,
        test_sources=args.test_sources,
        embed_feature_source=not args.no_feature_source,
        encode_response_source=not args.no_response_source,
    )

    # Response column to predict: the aggregated-dose column when requested,
    # otherwise raw 'Growth'.
    target = args.agg_dose or 'Growth'
    val_split = args.validation_split
    train_split = 1 - val_split

    def partition_loader():
        """Partition the loaded data into train/val (and CV folds) per ``args``."""
        loader.partition_data(cv_folds=args.cv,
                              train_split=train_split,
                              val_split=val_split,
                              cell_types=args.cell_types,
                              by_cell=args.by_cell,
                              by_drug=args.by_drug,
                              cell_subset_path=args.cell_subset_path,
                              drug_subset_path=args.drug_subset_path)

    def export_generators():
        """Build the (fold-less) train/val generators used by the export modes."""
        train_gen = CombinedDataGenerator(loader,
                                          batch_size=args.batch_size,
                                          shuffle=args.shuffle)
        val_gen = CombinedDataGenerator(loader,
                                        partition='val',
                                        batch_size=args.batch_size,
                                        shuffle=args.shuffle)
        return train_gen, val_gen

    if args.export_csv:
        fname = args.export_csv
        partition_loader()
        train_gen, val_gen = export_generators()

        x_train_list, y_train = train_gen.get_slice(size=train_gen.size,
                                                    dataframe=True,
                                                    single=args.single)
        x_val_list, y_val = val_gen.get_slice(size=val_gen.size,
                                              dataframe=True,
                                              single=args.single)
        df_train = pd.concat([y_train] + x_train_list, axis=1)
        df_val = pd.concat([y_val] + x_val_list, axis=1)
        df = pd.concat([df_train, df_val]).reset_index(drop=True)
        if args.growth_bins > 1:
            df = uno_data.discretize(df, 'Growth', bins=args.growth_bins)
        df.to_csv(fname, sep='\t', index=False, float_format="%.3g")
        return

    if args.export_data:
        fname = args.export_data
        partition_loader()
        train_gen, val_gen = export_generators()

        config_min_itemsize = {'Sample': 30, 'Drug1': 10}
        if not args.single:
            config_min_itemsize['Drug2'] = 10

        # The context manager closes the store even if an append fails.
        with pd.HDFStore(fname, complevel=9, complib='blosc:snappy') as store:
            for partition in ['train', 'val']:
                gen = train_gen if partition == 'train' else val_gen
                for i in range(gen.steps):
                    x_list, y = gen.get_slice(size=args.batch_size,
                                              dataframe=True,
                                              single=args.single)

                    for j, input_feature in enumerate(x_list):
                        input_feature.columns = [''] * len(input_feature.columns)
                        store.append('x_{}_{}'.format(partition, j),
                                     input_feature.astype('float32'),
                                     format='table',
                                     data_column=True)
                    store.append('y_{}'.format(partition),
                                 y.astype({target: 'float32'}),
                                 format='table',
                                 data_column=True,
                                 min_itemsize=config_min_itemsize)
                    # 1-based so the last message reads "steps / steps".
                    logger.info('Generating {} dataset. {} / {}'.format(
                        partition, i + 1, gen.steps))
        logger.info('Completed generating {}'.format(fname))
        return

    partition_loader()

    model = build_model(loader, args)
    logger.info('Combined model:')
    model.summary(print_fn=logger.info)
    # plot_model(model, to_file=prefix+'.model.png', show_shapes=True)

    if args.cp:
        model_json = model.to_json()
        with open(prefix + '.model.json', 'w') as f:
            print(model_json, file=f)

    def warmup_scheduler(epoch):
        """Ramp the learning rate linearly from ``base_lr`` over epochs 0-5.

        ``base_lr`` is read from the enclosing scope; it is assigned inside
        the fold loop before training starts.
        """
        lr = args.learning_rate or base_lr * args.batch_size / 100
        if epoch <= 5:
            K.set_value(model.optimizer.lr,
                        (base_lr * (5 - epoch) + lr * epoch) / 5)
        logger.debug('Epoch {}: lr={:.5g}'.format(
            epoch, K.get_value(model.optimizer.lr)))
        return K.get_value(model.optimizer.lr)

    df_pred_list = []

    cv_ext = ''
    cv = args.cv if args.cv > 1 else 1

    for fold in range(cv):
        if args.cv > 1:
            logger.info('Cross validation fold {}/{}:'.format(fold + 1, cv))
            cv_ext = '.cv{}'.format(fold + 1)

        template_model = build_model(loader, args, silent=True)
        if args.initial_weights:
            logger.info("Loading weights from {}".format(args.initial_weights))
            template_model.load_weights(args.initial_weights)

        if args.gpus and len(args.gpus) > 1:
            from keras.utils import multi_gpu_model
            gpu_count = len(args.gpus)
            logger.info("Multi GPU with {} gpus".format(gpu_count))
            model = multi_gpu_model(template_model,
                                    cpu_merge=False,
                                    gpus=gpu_count)
        else:
            model = template_model

        optimizer = optimizers.deserialize({
            'class_name': args.optimizer,
            'config': {}
        })
        base_lr = args.base_lr or K.get_value(optimizer.lr)
        if args.learning_rate:
            K.set_value(optimizer.lr, args.learning_rate)

        model.compile(loss=args.loss, optimizer=optimizer, metrics=[mae, r2])

        # calculate trainable and non-trainable params
        params.update(candle.compute_trainable_params(model))

        candle_monitor = candle.CandleRemoteMonitor(params=params)
        timeout_monitor = candle.TerminateOnTimeOut(params['timeout'])

        reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                                      factor=0.5,
                                      patience=5,
                                      min_lr=0.00001)
        warmup_lr = LearningRateScheduler(warmup_scheduler)
        checkpointer = MultiGPUCheckpoint(prefix + cv_ext + '.model.h5',
                                          save_best_only=True)
        tensorboard = TensorBoard(
            log_dir="tb/{}{}{}".format(args.tb_prefix, ext, cv_ext))
        history_logger = LoggingCallback(logger.debug)

        callbacks = [candle_monitor, timeout_monitor, history_logger]
        if args.reduce_lr:
            callbacks.append(reduce_lr)
        if args.warmup_lr:
            callbacks.append(warmup_lr)
        if args.cp:
            callbacks.append(checkpointer)
        if args.tb:
            callbacks.append(tensorboard)
        if args.save_weights:
            callbacks.append(
                SimpleWeightSaver(args.save_path + '/' + args.save_weights))

        if args.use_exported_data is not None:
            train_gen = DataFeeder(filename=args.use_exported_data,
                                   batch_size=args.batch_size,
                                   shuffle=args.shuffle,
                                   single=args.single,
                                   agg_dose=args.agg_dose)
            val_gen = DataFeeder(partition='val',
                                 filename=args.use_exported_data,
                                 batch_size=args.batch_size,
                                 shuffle=args.shuffle,
                                 single=args.single,
                                 agg_dose=args.agg_dose)
        else:
            train_gen = CombinedDataGenerator(loader,
                                              fold=fold,
                                              batch_size=args.batch_size,
                                              shuffle=args.shuffle,
                                              single=args.single)
            val_gen = CombinedDataGenerator(loader,
                                            partition='val',
                                            fold=fold,
                                            batch_size=args.batch_size,
                                            shuffle=args.shuffle,
                                            single=args.single)

        df_val = val_gen.get_response(copy=True)
        y_val = df_val[target].values
        # Baseline: score of a random permutation of the validation targets.
        y_shuf = np.random.permutation(y_val)
        log_evaluation(evaluate_prediction(y_val, y_shuf),
                       description='Between random pairs in y_val:')

        if args.no_gen:
            x_train_list, y_train = train_gen.get_slice(size=train_gen.size,
                                                        single=args.single)
            x_val_list, y_val = val_gen.get_slice(size=val_gen.size,
                                                  single=args.single)
            history = model.fit(x_train_list,
                                y_train,
                                batch_size=args.batch_size,
                                epochs=args.epochs,
                                callbacks=callbacks,
                                validation_data=(x_val_list, y_val))
        else:
            logger.info('Data points per epoch: train = %d, val = %d',
                        train_gen.size, val_gen.size)
            logger.info('Steps per epoch: train = %d, val = %d',
                        train_gen.steps, val_gen.steps)
            history = model.fit_generator(train_gen,
                                          train_gen.steps,
                                          epochs=args.epochs,
                                          callbacks=callbacks,
                                          validation_data=val_gen,
                                          validation_steps=val_gen.steps)

        if args.no_gen:
            y_val_pred = model.predict(x_val_list, batch_size=args.batch_size)
        else:
            val_gen.reset()
            y_val_pred = model.predict_generator(val_gen, val_gen.steps + 1)
            # Drop predictions beyond the real number of validation rows.
            y_val_pred = y_val_pred[:val_gen.size]

        y_val_pred = y_val_pred.flatten()

        scores = evaluate_prediction(y_val, y_val_pred)
        log_evaluation(scores)

        df_val['Predicted' + target] = y_val_pred
        df_val[target + 'Error'] = y_val_pred - y_val
        df_pred_list.append(df_val)

        # Keras records per-epoch metrics in ``history.history``; the History
        # object itself never has ``loss``/``r2`` attributes, so checking
        # hasattr() would silently skip plotting.
        if 'loss' in history.history:
            plot_history(prefix, history, 'loss')
        if 'r2' in history.history:
            plot_history(prefix, history, 'r2')

    pred_fname = prefix + '.predicted.tsv'
    df_pred = pd.concat(df_pred_list)
    if args.agg_dose:
        if args.single:
            df_pred.sort_values(['Sample', 'Drug1', target], inplace=True)
        else:
            df_pred.sort_values(['Source', 'Sample', 'Drug1', 'Drug2', target],
                                inplace=True)
    else:
        if args.single:
            df_pred.sort_values(['Sample', 'Drug1', 'Dose1', 'Growth'],
                                inplace=True)
        else:
            df_pred.sort_values(
                ['Sample', 'Drug1', 'Drug2', 'Dose1', 'Dose2', 'Growth'],
                inplace=True)
    df_pred.to_csv(pred_fname, sep='\t', index=False, float_format='%.4g')

    if args.cv > 1:
        scores = evaluate_prediction(df_pred[target],
                                     df_pred['Predicted' + target])
        log_evaluation(scores, description='Combining cross validation folds:')

    # Separate test sources are scored with the model from the last fold.
    for test_source in loader.test_sep_sources:
        test_gen = CombinedDataGenerator(loader,
                                         partition='test',
                                         batch_size=args.batch_size,
                                         source=test_source)
        df_test = test_gen.get_response(copy=True)
        y_test = df_test[target].values
        n_test = len(y_test)
        if n_test == 0:
            continue
        if args.no_gen:
            x_test_list, y_test = test_gen.get_slice(size=test_gen.size,
                                                     single=args.single)
            y_test_pred = model.predict(x_test_list,
                                        batch_size=args.batch_size)
        else:
            y_test_pred = model.predict_generator(
                test_gen.flow(single=args.single), test_gen.steps)
            y_test_pred = y_test_pred[:test_gen.size]
        y_test_pred = y_test_pred.flatten()
        scores = evaluate_prediction(y_test, y_test_pred)
        log_evaluation(scores,
                       description='Testing on data from {} ({})'.format(
                           test_source, n_test))

    if K.backend() == 'tensorflow':
        K.clear_session()

    logger.handlers = []

    return history
def run(params):
    """Train and evaluate the combined Growth-response model described by ``params``.

    ``params`` is wrapped in ``Struct`` so options are read as attributes
    (``args.<name>``). The same dict is also updated in place with the
    parameter counts from ``candle.compute_trainable_params`` and passed to
    the CANDLE callbacks.

    If ``args.export_data`` is set, the partitioned train+val rows are written
    to that path as tab-separated text and ``None`` is returned. Otherwise, one
    model is trained per cross-validation fold (a single fold when
    ``args.cv <= 1``). When ``args.cp`` is set, the checkpointed best weights
    are reloaded before predicting. Validation predictions are written to
    ``<prefix>.predicted.tsv``, each separate test source is evaluated, and
    the Keras ``History`` of the last fold is returned.
    """
    args = Struct(**params)
    set_seed(args.rng_seed)
    ext = extension_from_parameters(args)
    verify_path(args.save)
    prefix = args.save + ext
    logfile = args.logfile if args.logfile else prefix + '.log'
    set_up_logger(logfile, args.verbose)
    logger.info('Params: {}'.format(params))

    loader = CombinedDataLoader(seed=args.rng_seed)
    loader.load(
        cache=args.cache,
        ncols=args.feature_subsample,
        cell_features=args.cell_features,
        drug_features=args.drug_features,
        drug_median_response_min=args.drug_median_response_min,
        drug_median_response_max=args.drug_median_response_max,
        use_landmark_genes=args.use_landmark_genes,
        use_filtered_genes=args.use_filtered_genes,
        preprocess_rnaseq=args.preprocess_rnaseq,
        single=args.single,
        train_sources=args.train_sources,
        test_sources=args.test_sources,
        embed_feature_source=not args.no_feature_source,
        encode_response_source=not args.no_response_source,
    )

    val_split = args.validation_split
    train_split = 1 - val_split

    if args.export_data:
        fname = args.export_data
        loader.partition_data(cv_folds=args.cv,
                              train_split=train_split,
                              val_split=val_split,
                              cell_types=args.cell_types,
                              by_cell=args.by_cell,
                              by_drug=args.by_drug)
        train_gen = CombinedDataGenerator(loader,
                                          batch_size=args.batch_size,
                                          shuffle=args.shuffle)
        val_gen = CombinedDataGenerator(loader,
                                        partition='val',
                                        batch_size=args.batch_size,
                                        shuffle=args.shuffle)
        x_train_list, y_train = train_gen.get_slice(size=train_gen.size,
                                                    dataframe=True,
                                                    single=args.single)
        x_val_list, y_val = val_gen.get_slice(size=val_gen.size,
                                              dataframe=True,
                                              single=args.single)
        df_train = pd.concat([y_train] + x_train_list, axis=1)
        df_val = pd.concat([y_val] + x_val_list, axis=1)
        df = pd.concat([df_train, df_val]).reset_index(drop=True)
        if args.growth_bins > 1:
            df = uno_data.discretize(df, 'Growth', bins=args.growth_bins)
        df.to_csv(fname, sep='\t', index=False, float_format="%.3g")
        return

    loader.partition_data(cv_folds=args.cv,
                          train_split=train_split,
                          val_split=val_split,
                          cell_types=args.cell_types,
                          by_cell=args.by_cell,
                          by_drug=args.by_drug)

    model = build_model(loader, args)
    logger.info('Combined model:')
    # model.summary(print_fn=logger.info)
    # plot_model(model, to_file=prefix+'.model.png', show_shapes=True)

    if args.cp:
        model_json = model.to_json()
        with open(prefix + '.model.json', 'w') as f:
            print(model_json, file=f)

    def warmup_scheduler(epoch):
        """Ramp the learning rate linearly from ``base_lr`` over epochs 0-5.

        ``base_lr`` is read from the enclosing scope; it is assigned inside
        the fold loop before training starts.
        """
        lr = args.learning_rate or base_lr * args.batch_size / 100
        if epoch <= 5:
            K.set_value(model.optimizer.lr,
                        (base_lr * (5 - epoch) + lr * epoch) / 5)
        logger.debug('Epoch {}: lr={:.5g}'.format(
            epoch, K.get_value(model.optimizer.lr)))
        return K.get_value(model.optimizer.lr)

    df_pred_list = []

    cv_ext = ''
    cv = args.cv if args.cv > 1 else 1

    for fold in range(cv):
        if args.cv > 1:
            logger.info('Cross validation fold {}/{}:'.format(fold + 1, cv))
            cv_ext = '.cv{}'.format(fold + 1)

        model = build_model(loader, args, silent=True)

        optimizer = optimizers.deserialize({
            'class_name': args.optimizer,
            'config': {}
        })
        base_lr = args.base_lr or K.get_value(optimizer.lr)
        if args.learning_rate:
            K.set_value(optimizer.lr, args.learning_rate)

        model.compile(loss=args.loss, optimizer=optimizer, metrics=[mae, r2])

        # calculate trainable and non-trainable params
        params.update(candle.compute_trainable_params(model))

        candle_monitor = candle.CandleRemoteMonitor(params=params)
        timeout_monitor = candle.TerminateOnTimeOut(params['timeout'])

        reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                                      factor=0.5,
                                      patience=5,
                                      min_lr=0.00001)
        warmup_lr = LearningRateScheduler(warmup_scheduler)
        checkpointer = ModelCheckpoint(prefix + cv_ext + '.weights.h5',
                                       save_best_only=True,
                                       save_weights_only=True)
        tensorboard = TensorBoard(log_dir="tb/tb{}{}".format(ext, cv_ext))
        history_logger = LoggingCallback(logger.debug)
        model_recorder = ModelRecorder()

        # candle_monitor is the only remote monitor; registering a second
        # CandleRemoteMonitor would duplicate every report.
        callbacks = [
            candle_monitor, timeout_monitor, history_logger, model_recorder
        ]
        if args.reduce_lr:
            callbacks.append(reduce_lr)
        if args.warmup_lr:
            callbacks.append(warmup_lr)
        if args.cp:
            callbacks.append(checkpointer)
        if args.tb:
            callbacks.append(tensorboard)

        train_gen = CombinedDataGenerator(loader,
                                          fold=fold,
                                          batch_size=args.batch_size,
                                          shuffle=args.shuffle)
        val_gen = CombinedDataGenerator(loader,
                                        partition='val',
                                        fold=fold,
                                        batch_size=args.batch_size,
                                        shuffle=args.shuffle)

        df_val = val_gen.get_response(copy=True)
        y_val = df_val['Growth'].values
        # Baseline: score of a random permutation of the validation targets.
        y_shuf = np.random.permutation(y_val)
        log_evaluation(evaluate_prediction(y_val, y_shuf),
                       description='Between random pairs in y_val:')

        if args.no_gen:
            x_train_list, y_train = train_gen.get_slice(size=train_gen.size,
                                                        single=args.single)
            x_val_list, y_val = val_gen.get_slice(size=val_gen.size,
                                                  single=args.single)
            history = model.fit(x_train_list,
                                y_train,
                                batch_size=args.batch_size,
                                epochs=args.epochs,
                                callbacks=callbacks,
                                validation_data=(x_val_list, y_val))
        else:
            logger.info('Data points per epoch: train = %d, val = %d',
                        train_gen.size, val_gen.size)
            logger.info('Steps per epoch: train = %d, val = %d',
                        train_gen.steps, val_gen.steps)
            history = model.fit_generator(
                train_gen.flow(single=args.single),
                train_gen.steps,
                epochs=args.epochs,
                callbacks=callbacks,
                validation_data=val_gen.flow(single=args.single),
                validation_steps=val_gen.steps)

        # Predict with the best checkpointed weights rather than the last epoch.
        if args.cp:
            model.load_weights(prefix + cv_ext + '.weights.h5')
        # model = model_recorder.best_model

        if args.no_gen:
            y_val_pred = model.predict(x_val_list, batch_size=args.batch_size)
        else:
            val_gen.reset()
            y_val_pred = model.predict_generator(
                val_gen.flow(single=args.single), val_gen.steps)
            # Drop predictions beyond the real number of validation rows.
            y_val_pred = y_val_pred[:val_gen.size]

        y_val_pred = y_val_pred.flatten()

        scores = evaluate_prediction(y_val, y_val_pred)
        log_evaluation(scores)

        df_val = df_val.assign(PredictedGrowth=y_val_pred,
                               GrowthError=y_val_pred - y_val)
        df_pred_list.append(df_val)

        plot_history(prefix, history, 'loss')
        plot_history(prefix, history, 'r2')

    pred_fname = prefix + '.predicted.tsv'
    df_pred = pd.concat(df_pred_list)
    # Single-drug responses may lack Drug2/Dose2 (or Source). Sort only by the
    # keys that are present so sort_values does not raise KeyError.
    sort_keys = [
        col for col in
        ['Source', 'Sample', 'Drug1', 'Drug2', 'Dose1', 'Dose2', 'Growth']
        if col in df_pred.columns
    ]
    df_pred.sort_values(sort_keys, inplace=True)
    df_pred.to_csv(pred_fname, sep='\t', index=False, float_format='%.4g')

    if args.cv > 1:
        scores = evaluate_prediction(df_pred['Growth'],
                                     df_pred['PredictedGrowth'])
        log_evaluation(scores, description='Combining cross validation folds:')

    # Separate test sources are scored with the model from the last fold.
    for test_source in loader.test_sep_sources:
        test_gen = CombinedDataGenerator(loader,
                                         partition='test',
                                         batch_size=args.batch_size,
                                         source=test_source)
        df_test = test_gen.get_response(copy=True)
        y_test = df_test['Growth'].values
        n_test = len(y_test)
        if n_test == 0:
            continue
        if args.no_gen:
            x_test_list, y_test = test_gen.get_slice(size=test_gen.size,
                                                     single=args.single)
            y_test_pred = model.predict(x_test_list,
                                        batch_size=args.batch_size)
        else:
            y_test_pred = model.predict_generator(
                test_gen.flow(single=args.single), test_gen.steps)
            y_test_pred = y_test_pred[:test_gen.size]
        y_test_pred = y_test_pred.flatten()
        scores = evaluate_prediction(y_test, y_test_pred)
        log_evaluation(scores,
                       description='Testing on data from {} ({})'.format(
                           test_source, n_test))

    if K.backend() == 'tensorflow':
        K.clear_session()

    logger.handlers = []

    return history