예제 #1
0
    def set_trainability(self, model):
        """Freeze or unfreeze layers of `model` as configured in `self.opts`.

        Layer names are tested with `re.match` against two pattern lists.
        Frozen patterns are applied first and trainable patterns second, so a
        layer matching both ends up trainable. Layers for which
        `is_input_layer` or `is_output_layer` is true, and layers without a
        `trainable` attribute, are never modified. The resulting settings are
        printed as a table. If no pattern is configured at all, the method
        returns without touching the model or printing anything.
        """
        opts = self.opts
        frozen_patterns = []
        trainable_patterns = []

        # Both modes start by freezing every visited layer; `train_models`
        # additionally re-enables layers whose name starts with '<name>/'.
        if opts.fine_tune:
            frozen_patterns.append('.*')
        elif opts.train_models:
            frozen_patterns.append('.*')
            trainable_patterns.extend(
                '%s/' % name for name in opts.train_models)
        if opts.freeze_filter:
            first_conv = mod.get_first_conv_layer(model.layers)
            frozen_patterns.append(first_conv.name)

        # The explicit pattern options are only used as a fallback when the
        # switches above produced no pattern of the respective kind.
        if not trainable_patterns and opts.trainable:
            trainable_patterns = opts.trainable
        if not frozen_patterns and opts.not_trainable:
            frozen_patterns = opts.not_trainable

        if not (trainable_patterns or frozen_patterns):
            return

        table = OrderedDict([('layer', []), ('trainable', [])])
        # Order matters: freezing runs before unfreezing.
        passes = ((frozen_patterns, False), (trainable_patterns, True))
        for layer in model.layers:
            if (is_input_layer(layer) or is_output_layer(layer, model)
                    or not hasattr(layer, 'trainable')):
                continue
            for patterns, flag in passes:
                for pattern in patterns:
                    if re.match(pattern, layer.name):
                        layer.trainable = flag
            table['layer'].append(layer.name)
            table['trainable'].append(layer.trainable)

        print('Layer trainability:')
        print(format_table(table))
        print()
예제 #2
0
    def set_trainability(self, model):
        """Apply the trainability options in `self.opts` to `model`'s layers.

        Two lists of regular expressions are assembled from the options: one
        for layers to freeze and one for layers to train. Every layer that is
        neither an input nor an output layer (per `is_input_layer` /
        `is_output_layer`) and has a `trainable` attribute is matched against
        the freeze list first and the train list second using `re.match`, so
        the train list wins on conflicts. A summary table is printed
        afterwards. With no patterns configured the method is a no-op.
        """
        options = self.opts
        unfreeze = []
        freeze = []

        if options.fine_tune:
            # Freeze everything the loop below visits.
            freeze.append('.*')
        elif options.train_models:
            # Freeze everything, then re-enable the '<name>/' prefixed layers.
            freeze.append('.*')
            unfreeze += ['%s/' % model_name
                         for model_name in options.train_models]
        if options.freeze_filter:
            freeze.append(mod.get_first_conv_layer(model.layers).name)

        # Fall back to the user-supplied pattern lists only where the
        # switches above left a list empty.
        if not unfreeze and options.trainable:
            unfreeze = options.trainable
        if not freeze and options.not_trainable:
            freeze = options.not_trainable

        if not unfreeze and not freeze:
            return

        layer_names = []
        layer_flags = []
        for layer in model.layers:
            skipped = is_input_layer(layer) or is_output_layer(layer, model)
            if not skipped and hasattr(layer, 'trainable'):
                # Freeze first, unfreeze second: the latter takes precedence.
                for pattern in freeze:
                    if re.match(pattern, layer.name):
                        layer.trainable = False
                for pattern in unfreeze:
                    if re.match(pattern, layer.name):
                        layer.trainable = True
                layer_names.append(layer.name)
                layer_flags.append(layer.trainable)

        table = OrderedDict()
        table['layer'] = layer_names
        table['trainable'] = layer_flags
        print('Layer trainability:')
        print(format_table(table))
        print()
예제 #3
0
    def set_trainability(self, model):
        """Set the `trainable` flag of `model`'s layers from `self.opts`.

        Patterns are matched against layer names with `re.match`. Patterns
        that freeze layers are applied before patterns that unfreeze them, so
        a layer matching both ends up trainable. Layers listed in
        `model.input_layers` or `model.output_layers`, and layers without a
        `trainable` attribute, are left untouched. A table of the resulting
        settings is printed. If no pattern is configured, the method returns
        immediately without accessing the model.
        """
        opts = self.opts
        trainable = []
        not_trainable = []
        if opts.fine_tune:
            # Freeze every layer visited below (input/output layers are
            # skipped by the loop and therefore keep their setting).
            not_trainable.append('.*')
        elif opts.train_models:
            # Freeze everything, then re-enable layers named '<name>/...'.
            not_trainable.append('.*')
            for name in opts.train_models:
                trainable.append('%s/' % name)
        if opts.freeze_filter:
            # Exclude the first convolutional layer from training.
            not_trainable.append(mod.get_first_conv_layer(model.layers).name)
        # Explicit pattern lists are only a fallback for an empty list.
        if not trainable and opts.trainable:
            trainable = opts.trainable
        if not not_trainable and opts.not_trainable:
            not_trainable = opts.not_trainable

        if not trainable and not not_trainable:
            return

        # Built once instead of concatenating two lists for every layer.
        excluded_layers = model.input_layers + model.output_layers

        # OrderedDict keeps the column order of the printed table.
        table = OrderedDict()
        table['layer'] = []
        table['trainable'] = []
        for layer in model.layers:
            if layer in excluded_layers:
                continue
            if not hasattr(layer, 'trainable'):
                continue
            for regex in not_trainable:
                if re.match(regex, layer.name):
                    layer.trainable = False
            for regex in trainable:
                if re.match(regex, layer.name):
                    layer.trainable = True
            table['layer'].append(layer.name)
            table['trainable'].append(layer.trainable)
        print('Layer trainability:')
        print(format_table(table))
        print()
예제 #4
0
    def main(self, name, opts):
        """Build, train and save a model as configured by `opts`.

        Steps, in order: configure logging and random seeds, build the model
        via `self.build_model`, apply `self.set_trainability`, optionally
        initialise the first convolutional layer's filters, write
        `model.json`, compute and print per-output statistics (and class
        weights unless `opts.no_class_weights`), compile the model, create
        the training/validation data generators, train with
        `model.fit_generator`, print the logged performance tables, reload
        `model_weights_val.h5` if that file exists, and save `model.h5`.

        Stores `log` and `opts` on `self` and sets `self.metrics`.

        Returns:
            0 on completion.
        """
        logging.basicConfig(filename=opts.log_file,
                            format='%(levelname)s (%(asctime)s): %(message)s')
        log = logging.getLogger(name)
        log.setLevel(logging.DEBUG if opts.verbose else logging.INFO)

        if opts.seed is not None:
            np.random.seed(opts.seed)
            random.seed(opts.seed)

        self.log = log
        self.opts = opts

        make_dir(opts.out_dir)

        log.info('Building model ...')
        model = self.build_model()

        model.summary()
        self.set_trainability(model)
        if opts.filter_weights:
            first_conv = mod.get_first_conv_layer(model.layers)
            log.info('Initializing filters of %s ...' % first_conv.name)
            self.init_filter_weights(opts.filter_weights, first_conv)
        mod.save_model(model, os.path.join(opts.out_dir, 'model.json'))

        log.info('Computing output statistics ...')
        output_names = [layer.name for layer in model.output_layers]

        output_stats = OrderedDict()
        # `None` disables class weighting altogether.
        class_weights = None if opts.no_class_weights else OrderedDict()

        for output_name in output_names:
            targets = hdf.read(opts.train_files,
                               'outputs/%s' % output_name,
                               nb_sample=opts.nb_train_sample)
            # Only the first value of the returned mapping is used.
            targets = list(targets.values())[0]
            output_stats[output_name] = get_output_stats(targets)
            if class_weights is not None:
                class_weights[output_name] = get_output_class_weights(
                    output_name, targets)

        self.print_output_stats(output_stats)
        if class_weights:
            self.print_class_weights(class_weights)

        output_weights = None
        if opts.output_weights:
            log.info('Initializing output weights ...')
            output_weights = get_output_weights(output_names,
                                                opts.output_weights)
            print('Output weights:')
            for output_name in output_names:
                if output_name not in output_weights:
                    continue
                weight = output_weights[output_name]
                print('%s: %.2f' % (output_name, weight))
            print()

        self.metrics = {output_name: get_metrics(output_name)
                        for output_name in output_names}

        model.compile(optimizer=Adam(lr=opts.learning_rate),
                      loss=mod.get_objectives(output_names),
                      loss_weights=output_weights,
                      metrics=self.metrics)

        log.info('Loading data ...')
        replicate_names = dat.get_replicate_names(opts.train_files[0],
                                                  regex=opts.replicate_names,
                                                  nb_key=opts.nb_replicate)
        data_reader = mod.data_reader_from_model(
            model, replicate_names=replicate_names)
        nb_train_sample = dat.get_nb_sample(opts.train_files,
                                            opts.nb_train_sample)
        train_data = data_reader(opts.train_files,
                                 class_weights=class_weights,
                                 batch_size=opts.batch_size,
                                 nb_sample=nb_train_sample,
                                 shuffle=True,
                                 loop=True)

        # Validation is optional; both stay None without validation files.
        val_data = None
        nb_val_sample = None
        if opts.val_files:
            nb_val_sample = dat.get_nb_sample(opts.val_files,
                                              opts.nb_val_sample)
            val_data = data_reader(opts.val_files,
                                   batch_size=opts.batch_size,
                                   nb_sample=nb_val_sample,
                                   shuffle=False,
                                   loop=True)

        log.info('Initializing callbacks ...')
        callbacks = self.get_callbacks()

        log.info('Training model ...')
        print()
        print('Training samples: %d' % nb_train_sample)
        if nb_val_sample:
            print('Validation samples: %d' % nb_val_sample)
        model.fit_generator(train_data,
                            nb_train_sample,
                            opts.nb_epoch,
                            callbacks=callbacks,
                            validation_data=val_data,
                            nb_val_samples=nb_val_sample,
                            max_q_size=opts.data_q_size,
                            nb_worker=opts.data_nb_worker,
                            verbose=0)

        # NOTE(review): `self.perf_logger` is not assigned in this method;
        # presumably `get_callbacks` creates it -- confirm.
        perf_logger = self.perf_logger
        print('\nTraining set performance:')
        print(format_table(perf_logger.epoch_logs, precision=LOG_PRECISION))

        if perf_logger.val_epoch_logs:
            print('\nValidation set performance:')
            print(format_table(perf_logger.val_epoch_logs,
                               precision=LOG_PRECISION))

        # Restore model with highest validation performance
        best_weights = os.path.join(opts.out_dir, 'model_weights_val.h5')
        if os.path.isfile(best_weights):
            model.load_weights(best_weights)

        # Delete metrics since they cause problems when loading the model
        # from HDF5 file. Metrics can be loaded from json + weights file.
        model.metrics = None
        model.metrics_names = None
        model.metrics_tensors = None
        model.save(os.path.join(opts.out_dir, 'model.h5'))

        log.info('Done!')

        return 0
예제 #5
0
    def main(self, name, opts):
        """Compute first-convolutional-layer activations and store them.

        Loads the model from `opts.model_files`, evaluates the activation
        layer returned by `mod.get_first_conv_layer` on the 'dna' input for
        every batch read from `opts.data_files`, and writes the results to
        the HDF5 file `opts.out_file`:

        * `weights/weights`, `weights/bias`: first two weight arrays of the
          weight layer.
        * `act`: activations, optionally cut to a centred window of
          `opts.act_wlen` positions along axis 1 and reduced over axis 1
          with `opts.act_fun` ('mean', 'wmean' or 'max').
        * `inputs/<name>`, `outputs/<name>`, `preds/<name>`: written only if
          `opts.store_inputs`, `opts.store_outputs`, `opts.store_preds` are
          set.
        * One dataset per key yielded by the meta reader (requested keys:
          'chromo' and 'pos').

        Raises:
            ValueError: if `opts.model_files` is empty or `opts.act_fun` has
                an unsupported value.
            IOError: if the model has no input named 'dna'.

        Returns:
            0 on completion.
        """
        logging.basicConfig(filename=opts.log_file,
                            format='%(levelname)s (%(asctime)s): %(message)s')
        log = logging.getLogger(name)
        if opts.verbose:
            log.setLevel(logging.DEBUG)
        else:
            log.setLevel(logging.INFO)
        # Must run after the level is chosen and outside the `else` branch:
        # a debug record emitted right after setLevel(INFO) is always dropped.
        log.debug(opts)

        if opts.seed is not None:
            np.random.seed(opts.seed)

        if not opts.model_files:
            raise ValueError('No model files provided!')

        log.info('Loading model ...')
        K.set_learning_phase(0)
        model = mod.load_model(opts.model_files, log=log.info)

        weight_layer, act_layer = mod.get_first_conv_layer(model.layers, True)
        log.info('Using activation layer "%s"' % act_layer.name)
        log.info('Using weight layer "%s"' % weight_layer.name)

        try:
            dna_idx = model.input_names.index('dna')
        except ValueError:
            # list.index raises ValueError for a missing item; anything
            # broader would also swallow KeyboardInterrupt/SystemExit.
            raise IOError('Model is not a valid DNA model!')

        fun_outputs = to_list(act_layer.output)
        if opts.store_preds:
            fun_outputs += to_list(model.output)
        fun = K.function([to_list(model.input)[dna_idx]], fun_outputs)

        log.info('Reading data ...')
        if opts.store_outputs or opts.store_preds:
            output_names = model.output_names
        else:
            output_names = None
        data_reader = mod.DataReader(
            output_names=output_names,
            use_dna=True,
            dna_wlen=to_list(model.input_shape)[dna_idx][1]
        )
        nb_sample = dat.get_nb_sample(opts.data_files, opts.nb_sample)
        # NOTE(review): the data reader honours `opts.shuffle` while the meta
        # reader below never shuffles; if shuffling reorders samples, the
        # stored 'chromo'/'pos' rows would no longer line up with 'act' --
        # confirm against the reader implementations.
        data_reader = data_reader(opts.data_files,
                                  nb_sample=nb_sample,
                                  batch_size=opts.batch_size,
                                  loop=False,
                                  shuffle=opts.shuffle)

        meta_reader = hdf.reader(opts.data_files, ['chromo', 'pos'],
                                 nb_sample=nb_sample,
                                 batch_size=opts.batch_size,
                                 loop=False,
                                 shuffle=False)

        out_file = h5.File(opts.out_file, 'w')
        try:
            out_group = out_file

            layer_weights = weight_layer.get_weights()
            out_group['weights/weights'] = layer_weights[0]
            out_group['weights/bias'] = layer_weights[1]

            def h5_dump(path, data, idx, dtype=None, compression='gzip'):
                # Create the dataset lazily with room for all `nb_sample`
                # rows, then write this batch at row offset `idx`.
                if path not in out_group:
                    if dtype is None:
                        dtype = data.dtype
                    out_group.create_dataset(
                        name=path,
                        shape=[nb_sample] + list(data.shape[1:]),
                        dtype=dtype,
                        compression=compression
                    )
                out_group[path][idx:idx+len(data)] = data

            log.info('Computing activations')
            progbar = ProgressBar(nb_sample, log.info)
            idx = 0
            for data in data_reader:
                if isinstance(data, tuple):
                    # The third tuple element is not used here.
                    inputs, outputs, _ = data
                else:
                    inputs = data
                if isinstance(inputs, dict):
                    inputs = list(inputs.values())
                batch_size = len(inputs[0])
                progbar.update(batch_size)

                if opts.store_inputs:
                    for i, name in enumerate(model.input_names):
                        h5_dump('inputs/%s' % name,
                                dna.onehot_to_int(inputs[i]), idx)

                if opts.store_outputs:
                    for name, output in six.iteritems(outputs):
                        h5_dump('outputs/%s' % name, output, idx)

                fun_eval = fun(inputs)
                act = fun_eval[0]

                if opts.act_wlen:
                    delta = opts.act_wlen // 2
                    ctr = act.shape[1] // 2
                    # Clamp at 0: a negative start would index from the end
                    # and silently select the wrong positions.
                    start = max(0, ctr - delta)
                    act = act[:, start:(ctr+delta+1)]

                if opts.act_fun:
                    if opts.act_fun == 'mean':
                        act = act.mean(axis=1)
                    elif opts.act_fun == 'wmean':
                        window_weights = linear_weights(act.shape[1])
                        act = np.average(act, axis=1, weights=window_weights)
                    elif opts.act_fun == 'max':
                        act = act.max(axis=1)
                    else:
                        raise ValueError(
                            'Invalid function "%s"!' % (opts.act_fun))

                h5_dump('act', act, idx)

                if opts.store_preds:
                    # Everything after the activation output is a prediction.
                    preds = fun_eval[1:]
                    for i, name in enumerate(model.output_names):
                        h5_dump('preds/%s' % name, preds[i].squeeze(), idx)

                for name, value in six.iteritems(next(meta_reader)):
                    h5_dump(name, value, idx)

                idx += batch_size
            progbar.close()
        finally:
            # Close the output file even if a batch fails mid-way.
            out_file.close()
        log.info('Done!')

        return 0
예제 #6
0
    def main(self, name, opts):
        """Compute first-convolutional-layer activations and store them.

        Loads the model from `opts.model_files`, evaluates the activation
        layer returned by `mod.get_first_conv_layer` on the 'dna' input for
        every batch read (unshuffled) from `opts.data_files`, and writes the
        results to the HDF5 file `opts.out_file`:

        * `weights/weights`, `weights/bias`: first two weight arrays of the
          weight layer.
        * `act`: activations, optionally cut to a centred window of
          `opts.act_wlen` positions along axis 1 and reduced over axis 1
          with `opts.act_fun` ('mean', 'wmean' or 'max').
        * `inputs/<name>`, `outputs/<name>`, `preds/<name>`: written only if
          `opts.store_inputs`, `opts.store_outputs`, `opts.store_preds` are
          set.
        * One dataset per key yielded by the meta reader (requested keys:
          'chromo' and 'pos').

        Raises:
            ValueError: if `opts.model_files` is empty or `opts.act_fun` has
                an unsupported value.
            IOError: if the model has no input named 'dna'.

        Returns:
            0 on completion.
        """
        logging.basicConfig(filename=opts.log_file,
                            format='%(levelname)s (%(asctime)s): %(message)s')
        log = logging.getLogger(name)
        if opts.verbose:
            log.setLevel(logging.DEBUG)
        else:
            log.setLevel(logging.INFO)
        # Must run after the level is chosen and outside the `else` branch:
        # a debug record emitted right after setLevel(INFO) is always dropped.
        log.debug(opts)

        if opts.seed is not None:
            np.random.seed(opts.seed)

        if not opts.model_files:
            raise ValueError('No model files provided!')

        log.info('Loading model ...')
        K.set_learning_phase(0)
        model = mod.load_model(opts.model_files, log=log.info)

        weight_layer, act_layer = mod.get_first_conv_layer(model.layers, True)
        log.info('Using activation layer "%s"' % act_layer.name)
        log.info('Using weight layer "%s"' % weight_layer.name)

        try:
            dna_idx = model.input_names.index('dna')
        except ValueError:
            # list.index raises ValueError for a missing item; anything
            # broader would also swallow KeyboardInterrupt/SystemExit.
            raise IOError('Model is not a valid DNA model!')

        fun_outputs = to_list(act_layer.output)
        if opts.store_preds:
            fun_outputs += to_list(model.output)
        fun = K.function([to_list(model.input)[dna_idx]], fun_outputs)

        log.info('Reading data ...')
        if opts.store_outputs or opts.store_preds:
            output_names = model.output_names
        else:
            output_names = None
        data_reader = mod.DataReader(output_names=output_names,
                                     use_dna=True,
                                     dna_wlen=to_list(
                                         model.input_shape)[dna_idx][1])
        nb_sample = dat.get_nb_sample(opts.data_files, opts.nb_sample)
        # Both readers are created with identical sample count, batch size
        # and shuffle=False; the loop below consumes them in lock-step.
        data_reader = data_reader(opts.data_files,
                                  nb_sample=nb_sample,
                                  batch_size=opts.batch_size,
                                  loop=False,
                                  shuffle=False)

        meta_reader = hdf.reader(opts.data_files, ['chromo', 'pos'],
                                 nb_sample=nb_sample,
                                 batch_size=opts.batch_size,
                                 loop=False,
                                 shuffle=False)

        out_file = h5.File(opts.out_file, 'w')
        try:
            out_group = out_file

            layer_weights = weight_layer.get_weights()
            out_group['weights/weights'] = layer_weights[0]
            out_group['weights/bias'] = layer_weights[1]

            def h5_dump(path, data, idx, dtype=None, compression='gzip'):
                # Create the dataset lazily with room for all `nb_sample`
                # rows, then write this batch at row offset `idx`.
                if path not in out_group:
                    if dtype is None:
                        dtype = data.dtype
                    out_group.create_dataset(name=path,
                                             shape=[nb_sample] +
                                             list(data.shape[1:]),
                                             dtype=dtype,
                                             compression=compression)
                out_group[path][idx:idx + len(data)] = data

            log.info('Computing activations')
            progbar = ProgressBar(nb_sample, log.info)
            idx = 0
            for data in data_reader:
                if isinstance(data, tuple):
                    # The third tuple element is not used here.
                    inputs, outputs, _ = data
                else:
                    inputs = data
                if isinstance(inputs, dict):
                    inputs = list(inputs.values())
                batch_size = len(inputs[0])
                progbar.update(batch_size)

                if opts.store_inputs:
                    for i, name in enumerate(model.input_names):
                        h5_dump('inputs/%s' % name,
                                dna.onehot_to_int(inputs[i]), idx)

                if opts.store_outputs:
                    for name, output in six.iteritems(outputs):
                        h5_dump('outputs/%s' % name, output, idx)

                fun_eval = fun(inputs)
                act = fun_eval[0]

                if opts.act_wlen:
                    delta = opts.act_wlen // 2
                    ctr = act.shape[1] // 2
                    # Clamp at 0: a negative start would index from the end
                    # and silently select the wrong positions.
                    start = max(0, ctr - delta)
                    act = act[:, start:(ctr + delta + 1)]

                if opts.act_fun:
                    if opts.act_fun == 'mean':
                        act = act.mean(axis=1)
                    elif opts.act_fun == 'wmean':
                        window_weights = linear_weights(act.shape[1])
                        act = np.average(act, axis=1, weights=window_weights)
                    elif opts.act_fun == 'max':
                        act = act.max(axis=1)
                    else:
                        raise ValueError(
                            'Invalid function "%s"!' % (opts.act_fun))

                h5_dump('act', act, idx)

                if opts.store_preds:
                    # Everything after the activation output is a prediction.
                    preds = fun_eval[1:]
                    for i, name in enumerate(model.output_names):
                        h5_dump('preds/%s' % name, preds[i].squeeze(), idx)

                for name, value in six.iteritems(next(meta_reader)):
                    h5_dump(name, value, idx)

                idx += batch_size
            progbar.close()
        finally:
            # Close the output file even if a batch fails mid-way.
            out_file.close()
        log.info('Done!')

        return 0
예제 #7
0
    def main(self, name, opts):
        """Build, train and save a model as configured by `opts`.

        Steps, in order: configure logging and random seeds, build the model
        via `self.build_model`, apply `self.set_trainability`, optionally
        initialise the first convolutional layer's filters, write
        `model.json`, compute and print per-output statistics (and class
        weights unless `opts.no_class_weights`), compile the model, create
        the training/validation data generators, train with
        `model.fit_generator`, print the logged performance tables, reload
        `model_weights_val.h5` if that file exists, and save `model.h5`.

        Validation is optional: without `opts.val_files` the model is
        trained with no validation data and no validation steps.

        Stores `log` and `opts` on `self` and sets `self.metrics`.

        Returns:
            0 on completion.
        """
        logging.basicConfig(filename=opts.log_file,
                            format='%(levelname)s (%(asctime)s): %(message)s')
        log = logging.getLogger(name)
        if opts.verbose:
            log.setLevel(logging.DEBUG)
        else:
            log.setLevel(logging.INFO)

        if opts.seed is not None:
            np.random.seed(opts.seed)
            random.seed(opts.seed)

        self.log = log
        self.opts = opts

        make_dir(opts.out_dir)

        log.info('Building model ...')
        model = self.build_model()

        model.summary()
        self.set_trainability(model)
        if opts.filter_weights:
            conv_layer = mod.get_first_conv_layer(model.layers)
            log.info('Initializing filters of %s ...' % conv_layer.name)
            self.init_filter_weights(opts.filter_weights, conv_layer)
        mod.save_model(model, os.path.join(opts.out_dir, 'model.json'))

        log.info('Computing output statistics ...')
        output_names = model.output_names

        output_stats = OrderedDict()

        # `None` disables class weighting altogether.
        if opts.no_class_weights:
            class_weights = None
        else:
            class_weights = OrderedDict()

        for name in output_names:
            output = hdf.read(opts.train_files, 'outputs/%s' % name,
                              nb_sample=opts.nb_train_sample)
            # Only the first value of the returned mapping is used.
            output = list(output.values())[0]
            output_stats[name] = get_output_stats(output)
            if class_weights is not None:
                class_weights[name] = get_output_class_weights(name, output)

        self.print_output_stats(output_stats)
        if class_weights:
            self.print_class_weights(class_weights)

        output_weights = None
        if opts.output_weights:
            log.info('Initializing output weights ...')
            output_weights = get_output_weights(output_names,
                                                opts.output_weights)
            print('Output weights:')
            for output_name in output_names:
                if output_name in output_weights:
                    print('%s: %.2f' % (output_name,
                                        output_weights[output_name]))
            print()

        self.metrics = dict()
        for output_name in output_names:
            self.metrics[output_name] = get_metrics(output_name)

        optimizer = Adam(lr=opts.learning_rate)
        model.compile(optimizer=optimizer,
                      loss=mod.get_objectives(output_names),
                      loss_weights=output_weights,
                      metrics=self.metrics)

        log.info('Loading data ...')
        replicate_names = dat.get_replicate_names(
            opts.train_files[0],
            regex=opts.replicate_names,
            nb_key=opts.nb_replicate)
        data_reader = mod.data_reader_from_model(
            model, replicate_names=replicate_names)
        nb_train_sample = dat.get_nb_sample(opts.train_files,
                                            opts.nb_train_sample)
        train_data = data_reader(opts.train_files,
                                 class_weights=class_weights,
                                 batch_size=opts.batch_size,
                                 nb_sample=nb_train_sample,
                                 shuffle=True,
                                 loop=True)

        if opts.val_files:
            nb_val_sample = dat.get_nb_sample(opts.val_files,
                                              opts.nb_val_sample)
            val_data = data_reader(opts.val_files,
                                   batch_size=opts.batch_size,
                                   nb_sample=nb_val_sample,
                                   shuffle=False,
                                   loop=True)
        else:
            val_data = None
            nb_val_sample = None

        # Without validation files `nb_val_sample` is None, and
        # `None // batch_size` would raise TypeError before training starts.
        if nb_val_sample is None:
            validation_steps = None
        else:
            validation_steps = nb_val_sample // opts.batch_size

        log.info('Initializing callbacks ...')
        callbacks = self.get_callbacks()

        log.info('Training model ...')
        print()
        print('Training samples: %d' % nb_train_sample)
        if nb_val_sample:
            print('Validation samples: %d' % nb_val_sample)
        model.fit_generator(
            train_data,
            steps_per_epoch=nb_train_sample // opts.batch_size,
            epochs=opts.nb_epoch,
            callbacks=callbacks,
            validation_data=val_data,
            validation_steps=validation_steps,
            max_queue_size=opts.data_q_size,
            workers=opts.data_nb_worker,
            verbose=0)

        # NOTE(review): `self.perf_logger` is not assigned in this method;
        # presumably `get_callbacks` creates it -- confirm.
        print('\nTraining set performance:')
        print(format_table(self.perf_logger.epoch_logs,
                           precision=LOG_PRECISION))

        if self.perf_logger.val_epoch_logs:
            print('\nValidation set performance:')
            print(format_table(self.perf_logger.val_epoch_logs,
                               precision=LOG_PRECISION))

        # Restore model with highest validation performance
        filename = os.path.join(opts.out_dir, 'model_weights_val.h5')
        if os.path.isfile(filename):
            model.load_weights(filename)

        # Delete metrics since they cause problems when loading the model
        # from HDF5 file. Metrics can be loaded from json + weights file.
        model.metrics = None
        model.metrics_names = None
        model.metrics_tensors = None
        model.save(os.path.join(opts.out_dir, 'model.h5'))

        log.info('Done!')

        return 0