Beispiel #1
0
    def fit(self,
            dataloader,
            nb_iter=None,
            nb_epoch=None,
            iter_per_epoch=None,
            callbacks=None,
            verbose=0):
        """Trains the underlying Keras model.

        Runs ``nb_iter`` calls of ``keras_model.train_on_batch`` and fires the
        callback hooks around them; an epoch is a run of ``iter_per_epoch``
        consecutive iterations.

        Args:
            dataloader (StandardDataLoader): Manages the loading of data to
                model. The result of ``get_training_batch()`` is unpacked
                into ``keras_model.train_on_batch``.
            nb_iter (int): The number of iterations to train the model.
            nb_epoch (int): The number of epochs to train the model.
            iter_per_epoch (int): Defines the number of iterations per epoch.
            callbacks (list): List of Keras callbacks to run during training.
                Defaults to no callbacks; the list passed in is copied and
                never modified.
            verbose (int): If greater than 0, draw a text progress bar for
                the current epoch on stdout.

        Raises:
            KeyboardInterrupt: Re-raised after printing an abort message;
                ``on_train_end`` is not called in that case.
        """
        import time

        # _get_iterations returns the effective (nb_iter, iter_per_epoch)
        # from the three optional arguments.
        nb_iter, iter_per_epoch = self._get_iterations(nb_iter, nb_epoch,
                                                       iter_per_epoch)
        # Copy into a fresh list so the caller's list can never be mutated
        # through the CallbackList.
        callbacks = CallbackList(list(callbacks or []))
        callbacks._set_model(self)
        callbacks.on_train_begin()

        try:
            epoch = 0
            # Checked after every batch; callbacks get a reference to this
            # object via _set_model and may set it to stop training early.
            self.stop_training = False
            for i in xrange(nb_iter):
                # Begin epoch
                if i % iter_per_epoch == 0:
                    callbacks.on_epoch_begin(epoch)

                # Execution
                callbacks.on_batch_begin(i)

                if verbose > 0:
                    # NOTE(review): short pause kept from the original code;
                    # its purpose is not evident here.
                    time.sleep(0.001)
                    # Progress within the current epoch, in whole percent.
                    j = i % iter_per_epoch
                    perc = int(100 * (j + 1) // iter_per_epoch)
                    # One '=' per 2% fills the 50-character bar; '//' keeps
                    # the repeat count an int on Python 2 and Python 3.
                    prog = '=' * (perc // 2)
                    sys.stdout.write("[{:50s}] {:3d}%\r".format(prog, perc))
                    sys.stdout.flush()

                losses = self.keras_model.train_on_batch(
                    *dataloader.get_training_batch())

                callbacks.on_batch_end(i)

                # End epoch
                if (i + 1) % iter_per_epoch == 0:
                    callbacks.on_epoch_end(epoch, logs={'losses': losses})
                    epoch += 1
                if self.stop_training:
                    break
        except KeyboardInterrupt:
            print("\n[BayesNet] Abort: KeyboardInterrupt")
            raise

        callbacks.on_train_end()
Beispiel #2
0
    def fit(self, dataloader, nb_iter=None, nb_epoch=None, iter_per_epoch=None,
            callbacks=None, verbose=0):
        """Trains the underlying Keras model.

        Runs ``nb_iter`` calls of ``keras_model.train_on_batch`` and fires the
        callback hooks around them; an epoch is a run of ``iter_per_epoch``
        consecutive iterations.

        Args:
            dataloader (StandardDataLoader): Manages the loading of data to
                model. The result of ``get_training_batch()`` is unpacked
                into ``keras_model.train_on_batch``.
            nb_iter (int): The number of iterations to train the model.
            nb_epoch (int): The number of epochs to train the model.
            iter_per_epoch (int): Defines the number of iterations per epoch.
            callbacks (list): List of Keras callbacks to run during training.
                Defaults to no callbacks; the list passed in is copied and
                never modified.
            verbose (int): If greater than 0, draw a text progress bar for
                the current epoch on stdout.

        Raises:
            KeyboardInterrupt: Re-raised after printing an abort message;
                ``on_train_end`` is not called in that case.
        """
        import time

        # _get_iterations returns the effective (nb_iter, iter_per_epoch)
        # from the three optional arguments.
        nb_iter, iter_per_epoch = self._get_iterations(
            nb_iter, nb_epoch, iter_per_epoch)
        # Copy into a fresh list so the caller's list can never be mutated
        # through the CallbackList.
        callbacks = CallbackList(list(callbacks or []))
        callbacks._set_model(self)
        callbacks.on_train_begin()

        try:
            epoch = 0
            # Checked after every batch; callbacks get a reference to this
            # object via _set_model and may set it to stop training early.
            self.stop_training = False
            for i in xrange(nb_iter):
                # Begin epoch
                if i % iter_per_epoch == 0:
                    callbacks.on_epoch_begin(epoch)

                # Execution
                callbacks.on_batch_begin(i)

                if verbose > 0:
                    # NOTE(review): short pause kept from the original code;
                    # its purpose is not evident here.
                    time.sleep(0.001)
                    # Progress within the current epoch, in whole percent.
                    j = i % iter_per_epoch
                    perc = int(100 * (j + 1) // iter_per_epoch)
                    # One '=' per 2% fills the 50-character bar; '//' keeps
                    # the repeat count an int on Python 2 and Python 3.
                    prog = '=' * (perc // 2)
                    sys.stdout.write("[{:50s}] {:3d}%\r".format(prog, perc))
                    sys.stdout.flush()

                losses = self.keras_model.train_on_batch(
                    *dataloader.get_training_batch())

                callbacks.on_batch_end(i)

                # End epoch
                if (i + 1) % iter_per_epoch == 0:
                    callbacks.on_epoch_end(epoch, logs={'losses': losses})
                    epoch += 1
                if self.stop_training:
                    break
        except KeyboardInterrupt:
            print("\n[BayesNet] Abort: KeyboardInterrupt")
            raise

        callbacks.on_train_end()
Beispiel #3
0
            os.system("rm -rf *.weights")


    # EV 10-Jan-2021: Broadcast initial variable states from rank 0 to all other processes
    # EV 06-Feb-2021: add hvd.callbacks.MetricAverageCallback()
    #
    # One CallbackList per model (generator, discriminator, combined); each
    # list is built with its own fresh Horovod callback instances, so no
    # callback object is shared between the three models.
    
    gcb = CallbackList([hvd.callbacks.BroadcastGlobalVariablesCallback(0), hvd.callbacks.MetricAverageCallback()])
    dcb = CallbackList([hvd.callbacks.BroadcastGlobalVariablesCallback(0), hvd.callbacks.MetricAverageCallback()])
    ccb = CallbackList([hvd.callbacks.BroadcastGlobalVariablesCallback(0), hvd.callbacks.MetricAverageCallback()])

    gcb.set_model( generator )
    dcb.set_model( discriminator )
    ccb.set_model( combined )


    # Only the train-begin hook is invoked in this section.
    # NOTE(review): no on_epoch_*/on_batch_* calls on gcb/dcb/ccb are visible
    # here; confirm (e.g. inside train_gan) that MetricAverageCallback gets
    # the epoch-end hook it presumably needs to average metrics.
    gcb.on_train_begin()
    dcb.on_train_begin()
    ccb.on_train_begin()

    logger.info('commencing training')

    # Epoch indices continue after last_epoch (presumably so a previous run
    # can be resumed); exactly nb_epochs epochs are run.
    for epoch in range(last_epoch+1, nb_epochs+last_epoch+1):

        # Logs epoch + 1 out of nb_epochs + last_epoch + 1, i.e. the same
        # numbering applied to the final epoch of the range.
        logger.info('Epoch {} of {}'.format(epoch + 1, nb_epochs+last_epoch+1))
        # Number of full batches per epoch: int() drops any remainder.
        # Loop-invariant, although it is recomputed every epoch.
        nb_batches = int(first.shape[0] / batch_size)
        
        train_gan(epoch, nb_batches)

        # Save weights: every epoch when save_all_epochs is set, otherwise
        # only after the final epoch of the range.
        # EV 10-Jan-2021: this needs to be done only on one process. Otherwise each worker is writing it.
        if ((hvd.rank()==0) or (not process0)) and (save_all_epochs or epoch==nb_epochs+last_epoch):
Beispiel #4
0
def evaluate(model, save_path, num_outputs, liver_only=False, **kwargs):
    """Evaluate a model on the validation generator and print its metrics.

    The model, its callbacks and the data generators come from
    ``prepare_model``. A single backend function returns, for each
    validation batch, the model outputs followed by the total loss and the
    metric tensors. The per-batch metric values are fed to a callback list
    (``BaseLogger`` plus the Dice callbacks), and every aggregated metric is
    printed at the end.

    Args:
        model: Passed through to ``prepare_model``, which returns the model
            that is actually evaluated.
        save_path: Directory under which a ``predictions`` sub-directory is
            created if missing.
        num_outputs: Number of prediction outputs. With 1, only the first
            function output is a prediction; otherwise the first two are.
        liver_only: If True, skip the lesion Dice callbacks and always use
            the liver Dice callback.
        **kwargs: Forwarded to ``prepare_model``.
    """
    model, callbacks, gen = prepare_model(model=model,
                                          save_path=save_path,
                                          num_outputs=num_outputs,
                                          liver_only=liver_only,
                                          **kwargs)

    print(' > Evaluating the model...')

    # Create directory, if needed. NOTE: prediction images are currently not
    # written into it (image saving is disabled in this function).
    save_predictions_to = os.path.join(save_path, "predictions")
    if not os.path.exists(save_predictions_to):
        os.makedirs(save_predictions_to)

    # Initialize callbacks
    val_callback_list = [BaseLogger()]
    if not liver_only:
        val_callback_list.extend(
            [callbacks['dice_lesion'], callbacks['dice_lesion_inliver']])
    if len(model.outputs) == 2 or liver_only:
        val_callback_list.append(callbacks['dice_liver'])
    val_callbacks = CallbackList(val_callback_list)
    val_callbacks.set_params({
        'nb_epoch': 0,
        'nb_sample': 0,
        'verbose': False,
        'do_validation': True,
        'metrics': model.metrics_names
    })
    val_callbacks.on_train_begin()
    val_callbacks.on_epoch_begin(0)

    # Whether the learning phase must be fed explicitly (it is not fixed to
    # an int); the same answer applies when building and calling the function.
    feed_learning_phase = (model.uses_learning_phase and
                           not isinstance(K.learning_phase(), int))

    # Create backend function: outputs are predictions, then total loss,
    # then metric tensors.
    inputs = model.inputs + model.targets + model.sample_weights
    if feed_learning_phase:
        inputs += [K.learning_phase()]
    predict_and_test_function = K.function(
        inputs,
        model.outputs + [model.total_loss] + model.metrics_tensors,
        updates=model.state_updates)

    # Number of leading function outputs that are predictions; the rest
    # line up with model.metrics_names.
    n_predictions = 1 if num_outputs == 1 else 2

    # Initialized so an empty flow does not leave val_logs unbound below.
    val_logs = OrderedDict()

    # Loop through batches, applying function and callbacks
    flow = repeat_flow(gen['valid_callback'].flow(), num_outputs=num_outputs)
    for batch_num, batch in enumerate(flow):
        x, y, sample_weights = model._standardize_user_data(batch[0], batch[1])
        ins = x + y + sample_weights
        if feed_learning_phase:
            ins += [0.]  # Keras learning phase 0 = test/inference
        outputs = predict_and_test_function(ins)
        val_metrics = outputs[n_predictions:]

        # Update metrics
        val_logs = OrderedDict(zip(model.metrics_names, val_metrics))
        val_logs.update({'batch': batch_num, 'size': len(batch[0])})
        val_callbacks.on_batch_end(batch_num, val_logs)

    # Update metrics (callbacks may write aggregated values into val_logs)
    val_callbacks.on_epoch_end(0, val_logs)

    # Output metrics
    for m in val_logs:
        if m not in ['batch', 'size']:
            print("{}: {}".format(m, val_logs[m]))
Beispiel #5
0
    def fit_generator(self,
                      generator,
                      epochs=1,
                      validation_data=None,
                      callbacks=None,
                      verbose=True):
        """Fit the wrapped Keras model with a SciPy-style optimizer.

        All weights are gathered into one vector and optimized by ``minimize``
        with ``self._fun_generator`` as the objective (``jac=True``, so it
        returns both value and gradient). Each optimizer iteration is
        reported to the callbacks as one epoch.

        Args:
            generator: Training data, forwarded to ``_fun_generator``.
            epochs: Iteration budget (``maxiter``); ``maxfun`` is ten times
                this value.
            validation_data: A ``keras.utils.Sequence`` or an ``(x, y)``
                tuple (wrapped in ``GeneratorWrapper``); any other value
                disables validation.
            callbacks: Extra callbacks, run after ``BaseLogger`` and before
                the ``History`` callback.
            verbose: Stored in the shared state under ``'verbose'``.

        Returns:
            The ``History`` callback that recorded the run.
        """
        opt_method = self._model.optimizer.method
        start_point = self._collect_weights()
        history = History()

        cb_list = CallbackList(
            [BaseLogger(stateful_metrics=self._model.metrics_names)]
            + (callbacks or [])
            + [history])
        cb_list.set_model(self._model)
        cb_list.set_params(dict(
            epochs=epochs,
            verbose=False,
            metrics=list(self._model.metrics_names),
        ))

        # Mutable state shared with _fun_generator / _validate and the
        # per-iteration hook below.
        state = dict(
            epoch=0,
            verbose=verbose,
            callbacks=cb_list,
            in_epoch=False,
            epoch_logs={},
        )
        opt_options = dict(
            maxiter=epochs,
            maxfun=epochs * 10,
            maxcor=50,
            maxls=50,
            ftol=np.finfo(float).eps,
            gtol=1e-10,
            eps=1e-8,
        )

        val_source = None
        if validation_data is not None:
            is_pair = (isinstance(validation_data, tuple)
                       and len(validation_data) == 2)
            if isinstance(validation_data, keras.utils.Sequence):
                val_source = validation_data
            elif is_pair:
                val_source = GeneratorWrapper(*validation_data)

        def _end_of_iteration(xk):
            # Invoked by minimize after every iteration with the current point.
            epoch_callbacks = state['callbacks']
            if val_source is not None:
                self._validate(xk, val_source, state)
            epoch_callbacks.on_epoch_end(state['epoch'], state['epoch_logs'])
            state['epoch'] = state['epoch'] + 1
            state.update(in_epoch=False, epoch_logs={})

        cb_list.on_train_begin()
        opt_result = minimize(
            self._fun_generator,
            start_point,
            method=opt_method,
            jac=True,
            options=opt_options,
            callback=_end_of_iteration,
            args=(generator, state),
        )
        self._update_weights(opt_result['x'])
        cb_list.on_train_end()
        return history
Beispiel #6
0
    def fit(self,
            x,
            y,
            batch_size,
            n_epochs=1,
            callbacks=None,
            validation_data=None):
        """Trains the network on the given data for a fixed number of epochs

        Training stops early once ``self.stop_training`` becomes true. The
        flag is checked before every epoch and after every batch, and is
        reset at the start of each call, so an earlier early-stopped run does
        not turn this call into a no-op.

        :param x: input data to train on
        :type x: torch.Tensor
        :param y: target data to train on; when the network has more than one
         output, a sequence holding one target array per output
        :type y: torch.Tensor
        :param batch_size: number of samples to use per forward and backward
         pass
        :type batch_size: int
        :param n_epochs: number of epochs (iterations of the dataset) to train
         the model
        :type n_epochs: int
        :param callbacks: callbacks to be used during training
        :type callbacks: list[object]
        :param validation_data: data on which to evaluate the loss and metrics
         at the end of each epoch
        :type validation_data: tuple(numpy.ndarray)
        """

        default_callbacks = self._load_default_callbacks()
        default_callbacks.append(ProgbarLogger(count_mode='samples'))
        if callbacks:
            default_callbacks.extend(callbacks)
        callbacks = CallbackList(default_callbacks)

        self._assert_compiled()

        if self.device:
            self.network.to(self.device)

        # Names reported to the callbacks, in display order.
        metrics = ['loss']
        if self.n_outputs > 1:
            for idx_output in range(1, self.n_outputs + 1):
                metrics.append('loss{}'.format(idx_output))
        if validation_data is not None:
            metrics.append('val_loss')
            if self.n_outputs > 1:
                for idx_output in range(1, self.n_outputs + 1):
                    metrics.append('val_loss{}'.format(idx_output))
        for metric_name in self.metric_names:
            metrics.append(metric_name)
            if validation_data is not None:
                metrics.append('val_{}'.format(metric_name))

        index_array = np.arange(x.shape[0])

        callbacks.set_params({
            'batch_size': batch_size,
            'epochs': n_epochs,
            'metrics': metrics,
            'steps': None,
            'samples': x.shape[0],
            'verbose': True
        })
        callbacks.set_model(self)

        # Reset before on_train_begin so a callback may still request a
        # stop from there.
        self.stop_training = False
        callbacks.on_train_begin()
        for idx_epoch in range(n_epochs):
            if self.stop_training:
                break

            epoch_logs = {}
            callbacks.on_epoch_begin(idx_epoch)

            np.random.shuffle(index_array)
            batches = make_batches(len(index_array), batch_size)
            for idx_batch, (idx_start, idx_end) in enumerate(batches):
                batch_logs = {'batch': idx_batch, 'size': idx_end - idx_start}
                callbacks.on_batch_begin(idx_batch, batch_logs)

                batch_indices = index_array[idx_start:idx_end]
                inputs = x[batch_indices]
                if self.n_outputs > 1:
                    targets = [y[idx_output][batch_indices]
                               for idx_output in range(self.n_outputs)]
                else:
                    targets = y[batch_indices]
                train_outputs = self.train_on_batch(inputs, targets)

                self._outputs_to_logs(train_outputs, batch_logs)
                callbacks.on_batch_end(idx_batch, batch_logs)

                if self.stop_training:
                    break

            # `is not None` matches the checks used to build `metrics` above
            # and avoids ambiguous truth values for array-like inputs.
            if validation_data is not None:
                val_outputs = self.evaluate(validation_data[0],
                                            validation_data[1], batch_size)
                self._outputs_to_logs(val_outputs, epoch_logs, prefix='val_')
            callbacks.on_epoch_end(idx_epoch, epoch_logs)
        callbacks.on_train_end()

    def _outputs_to_logs(self, outputs, logs, prefix=''):
        """Copy loss and metric values from ``outputs`` into ``logs``.

        ``outputs[0]`` is stored as ``loss``. When there is more than one
        output, ``outputs[1..n_outputs]`` are stored as ``loss1..lossN``.
        The remaining entries are paired with ``self.metric_names``. Every
        key is prefixed with ``prefix`` (e.g. ``'val_'``).
        """
        logs[prefix + 'loss'] = outputs[0]
        if self.n_outputs > 1:
            for idx_output in range(1, self.n_outputs + 1):
                logs['{}loss{}'.format(prefix, idx_output)] = (
                    outputs[idx_output])

        # Metric values follow the loss entries.
        idx_metric_values = (1 if self.n_outputs == 1 else
                             self.n_outputs + 1)
        for metric_name, value in zip(self.metric_names,
                                      outputs[idx_metric_values:]):
            logs[prefix + metric_name] = value
Beispiel #7
0
def main(args):
    """Train the VAE-GAN text generation model configured by a JSON file.

    Options parsed from ``args``:
        ``-c``/``--config NAME``: config file, resolved as
            ``configurations/NAME`` (required).
        ``-s``: open the log file in append mode instead of write mode.

    First, ``vae_vanilla_train_model`` is trained with ``fit`` for
    ``config['pretrain_epochs']`` epochs. Then ``vae_model_train`` is trained
    by a manual loop. For each batch, that loop makes one VAE update with the
    discriminator frozen, followed by several discriminator updates.
    """
    try:
        opts, args = getopt.getopt(args, "c:s", ["config="])
    except getopt.GetoptError:
        print('usage: -c config.json')
        sys.exit(2)

    config_fname = None
    start_from_model = False
    for opt, arg in opts:
        if opt in ("-c", "--config"):
            config_fname = os.path.join('configurations', arg)
        elif opt == '-s':
            start_from_model = True

    if config_fname is None:
        # The config is mandatory: fail fast, before creating the log
        # directory, instead of hitting an unbound name later on.
        print('usage: -c config.json')
        sys.exit(2)

    filemode = 'a' if start_from_model else 'w'

    log_path = 'logging/vae_nlg_{}'.format(int(round(time.time() * 1000)))
    os.mkdir(log_path)

    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s',
                        level=logging.INFO,
                        filename='{}/evolution.log'.format(log_path),
                        filemode=filemode)

    with open(config_fname, 'r') as json_data:
        config_data = json.load(json_data)

    batch_size = config_data['batch_size']
    epochs = config_data['nb_epochs']
    discriminator_iterations = config_data['discriminator_iterations']
    tweets_path = config_data['tweets_path']
    vocab_path = config_data['vocab_path']
    # NOTE(security): unpickling can execute arbitrary code; only load
    # vocabulary files from trusted locations.
    with open(join(vocab_path, 'vocabulary.pkl'), 'rb') as vocab_file:
        vocab = cPickle.load(vocab_file)

    #== == == == == == =
    # Load all the Data
    #== == == == == == =
    delimiter = ''
    noutputs = 11

    logging.info('Load Training Data')
    train_input, train_output, train_weights, train_lex = load_text_gen_data(
        join(tweets_path, 'trainset.csv'),
        config_data,
        vocab,
        noutputs,
        word_based=False)
    logging.info('Load Validation Data')
    valid_input, valid_output, _, valid_lex = load_text_gen_data(
        join(tweets_path, 'devset.csv'),
        config_data,
        vocab,
        noutputs,
        word_based=False)
    logging.info('Load Output Validation Data')
    valid_dev_input, valid_dev_output, _, valid_dev_lex = load_text_gen_data(
        join(tweets_path, 'devset_reduced.csv'),
        config_data,
        vocab,
        noutputs,
        random_output=False,
        word_based=False)

    step = K.variable(1., name='step_varialbe')
    # Batches per epoch, rounded up so a trailing partial batch counts.
    # float() prevents Python 2 floor division from turning ceil() into a
    # no-op.
    steps_per_epoch = int(ceil(train_output[0].shape[0] / float(batch_size)))
    # == == == == == == == == == == =
    # Define and load the CNN model
    # == == == == == == == == == == =
    (vae_model_train, vae_model_test, vae_vanilla_train_model,
     vae_vanilla_test_model, discriminator_model, decoder_test,
     discriminator) = get_vae_gan_model(config_data, vocab, step)
    with open(os.path.join(log_path, 'models.txt'), 'wt') as fh:
        fh.write('VAE Model Train\n')
        fh.write('---------\n')
        vae_model_train.summary(print_fn=lambda x: fh.write(x + '\n'))
        fh.write('VAE Model Test\n')
        fh.write('--------------\n')
        vae_model_test.summary(print_fn=lambda x: fh.write(x + '\n'))
        fh.write('VAE Model Pretrain\n')
        fh.write('---------------------------\n')
        vae_vanilla_train_model.summary(
            print_fn=lambda x: fh.write(x + '\n'))
        fh.write('VAE Model Pretrain Test\n')
        fh.write('---------------------------\n')
        vae_vanilla_test_model.summary(
            print_fn=lambda x: fh.write(x + '\n'))
        fh.write('Decoder Test\n')
        fh.write('-------------------\n')
        decoder_test.summary(print_fn=lambda x: fh.write(x + '\n'))
        fh.write('Discriminator Models\n')
        fh.write('-------------------\n')
        discriminator_model.summary(print_fn=lambda x: fh.write(x + '\n'))

    # Pre-training of the vanilla VAE.
    terminate_on_nan = TerminateOnNaN()
    output_callback = LexOutputCallback(
        vae_vanilla_test_model,
        valid_dev_input,
        valid_dev_lex,
        1,
        vocab,
        delimiter,
        fname='{}/test_output'.format(log_path))

    vae_vanilla_train_model.fit(
        x=train_input,
        y=train_output[:2],
        epochs=config_data['pretrain_epochs'],
        batch_size=batch_size,
        validation_data=(valid_input, valid_output[:2]),
        sample_weight=train_weights[:2],
        callbacks=[output_callback, terminate_on_nan])

    # Callbacks and metric labels for the manual VAE-GAN training loop.
    terminate_on_nan = TerminateOnNaN()
    model_checkpoint = ModelCheckpoint(
        'models/vae_model/weights.{epoch:02d}.hdf5',
        period=10,
        save_weights_only=True)

    out_labels = [
        'enc_' + s for s in vae_model_train._get_deduped_metrics_names()
    ]
    out_labels += [
        'dis_' + s
        for s in discriminator_model._get_deduped_metrics_names()
    ]

    callback_metrics = out_labels + ['val_' + n for n in out_labels]

    tensorboard = TensorBoard(log_dir='logging/tensorboard',
                              histogram_freq=0,
                              write_grads=True,
                              write_images=True)
    step_callback = StepCallback(step, steps_per_epoch)
    output_callback = LexOutputCallback(
        vae_vanilla_test_model,
        valid_dev_input,
        valid_dev_lex,
        1,
        vocab,
        delimiter,
        fname='{}/test_output'.format(log_path))
    output_callback_full = LexOutputCallback(
        vae_vanilla_test_model,
        valid_input,
        valid_lex,
        5,
        vocab,
        delimiter,
        fname='{}/test_valid_output'.format(log_path))
    callbacks = CallbackList([
        BaseLogger(),
        ProgbarLogger(count_mode='steps'), step_callback, tensorboard,
        output_callback, output_callback_full, model_checkpoint,
        terminate_on_nan
    ])

    callbacks.set_model(vae_model_train)
    callbacks.set_params({
        'batch_size': batch_size,
        'epochs': epochs,
        'steps': steps_per_epoch,
        'verbose': True,
        'do_validation': True,
        'metrics': callback_metrics or [],
    })

    # Callbacks request an early stop by setting stop_training on the model
    # they were given (vae_model_train, see set_model above); honour it below.
    # NOTE(review): TerminateOnNaN inspects logs['loss'], which this loop
    # never sets (keys are 'enc_*' / 'dis_*'), so it presumably cannot fire.
    vae_model_train.stop_training = False
    callbacks.on_train_begin()
    initial_epoch = 0
    num_train_samples = train_input[0].shape[0]
    index_array = np.arange(num_train_samples)

    steps = 0  # global batch counter; drives the discriminator schedule
    epoch = initial_epoch
    while epoch < epochs and not vae_model_train.stop_training:
        epoch_logs = {}
        callbacks.on_epoch_begin(epoch)
        index_array = _batch_shuffle(index_array, batch_size)

        steps_done = 0
        batches = _make_batches(num_train_samples, batch_size)
        for batch_index, (batch_start, batch_end) in enumerate(batches):
            batch_ids = index_array[batch_start:batch_end]
            X = _slice_arrays(train_input, batch_ids)
            y = _slice_arrays(train_output, batch_ids)
            sample_weights = _slice_arrays(train_weights, batch_ids)

            # Use the real batch length: the last batch of an epoch may be
            # smaller than batch_size, and BaseLogger weights each batch's
            # metrics by 'size'.
            batch_logs = {'batch': batch_index, 'size': len(batch_ids)}

            callbacks.on_batch_begin(batch_index, batch_logs)

            # VAE update with the discriminator frozen.
            set_trainability(discriminator, trainable=False)
            enc_outs = vae_model_train.train_on_batch(
                x=X, y=y, sample_weight=sample_weights)

            # Discriminator updates: 25 during the first 25 steps and on
            # every 500th step, otherwise the configured number.
            set_trainability(discriminator, trainable=True)
            if steps < 25 or steps % 500 == 0:
                disc_iterations = 25
            else:
                disc_iterations = discriminator_iterations
            list_disc_loss_real = []
            for _ in range(disc_iterations):
                real_idx = np.random.choice(num_train_samples,
                                            len(batch_ids),
                                            replace=False)

                # train_input[-1] is the discriminator input; the first 8
                # inputs (train_input[:8]) are its targets.
                disX_train = train_input[-1][real_idx]
                disy_train = [inp[real_idx] for inp in train_input[:8]]

                # train on real data
                dis_outs_real = discriminator_model.train_on_batch(
                    disX_train, disy_train)

                list_disc_loss_real.append(dis_outs_real)

            loss_d_real = np.mean(list_disc_loss_real, axis=0)

            outs = np.concatenate((enc_outs, loss_d_real))

            for l, o in zip(out_labels, outs):
                batch_logs[l] = o

            callbacks.on_batch_end(batch_index, batch_logs)
            steps_done += 1
            steps += 1
            if vae_model_train.stop_training:
                break

        # Validate once, after a completed epoch.
        if 0 < steps_per_epoch <= steps_done:
            enc_val_outs = vae_model_train.evaluate(valid_input,
                                                    valid_output,
                                                    verbose=False)
            dis_val_outs = discriminator_model.evaluate(
                valid_input[-1], valid_input[:8], verbose=False)

            # evaluate() returns a scalar for single-output models; normalise
            # each part to a list before concatenating.
            val_outs = []
            for part in (enc_val_outs, dis_val_outs):
                val_outs.extend(part if isinstance(part, list) else [part])

            # Same labels assumed.
            for l, o in zip(out_labels, val_outs):
                epoch_logs['val_' + l] = o

        callbacks.on_epoch_end(epoch, epoch_logs)
        epoch += 1

    callbacks.on_train_end()
Beispiel #8
0
def train_wgan_with_grad_penalty(prior_gen,
                                 generator,
                                 data_gen,
                                 critic,
                                 batch_size,
                                 epochs,
                                 batches_per_epoch=100,
                                 optimizer=None,
                                 grad_pen_coef=10.,
                                 critic_gen_train_ratio=2,
                                 callbacks=None):
    """Train ``generator`` against ``critic`` as a WGAN with gradient penalty.

    Each batch runs ``critic_gen_train_ratio`` critic updates followed by one
    generator update.

    Args:
        prior_gen: Callable; ``prior_gen(n)`` returns ``n`` generator inputs.
        generator: Keras model mapping prior samples to data samples.
        data_gen: Callable; ``data_gen(n)`` returns ``n`` real samples as an
            array (its ``.shape`` is used).
        critic: Keras model scoring data samples; it is set non-trainable
            inside the generator-training model.
        batch_size: Samples per critic/generator update.
        epochs: Number of epochs.
        batches_per_epoch: Generator updates per epoch.
        optimizer: Keras optimizer used to compile both training models.
            When None (the default), a fresh
            ``Adam(lr=1e-4, beta_1=0, beta_2=0.9)`` is built for the critic
            and another one for the generator on every call.
        grad_pen_coef: Loss weight of the gradient-penalty output.
        critic_gen_train_ratio: Critic updates per generator update.
        callbacks: Optional list of callbacks. Their "model" is the dict
            ``{'generator': generator, 'critic': critic}``.
    """
    if optimizer is None:
        # An optimizer instance as a signature default would be created once
        # at definition time and shared by all calls and both models.
        critic_optimizer = Adam(lr=1e-4, beta_1=0, beta_2=0.9)
        generator_optimizer = Adam(lr=1e-4, beta_1=0, beta_2=0.9)
    else:
        # NOTE(review): an explicitly passed optimizer is used for both
        # training models; pass separate instances via a wrapper if their
        # optimizer state should not be shared.
        critic_optimizer = generator_optimizer = optimizer

    # build model to train the critic
    data_shape = critic.input_shape[1:]
    real_critic_input = Input(shape=data_shape, name='real_in')
    fake_critic_input = Input(shape=data_shape, name='fake_in')
    interp_critic_input = Input(shape=data_shape, name='interp_in')

    real_critic_score = critic(real_critic_input)
    fake_critic_score = critic(fake_critic_input)
    interp_critic_score = critic(interp_critic_input)

    critic_loss = subtract([fake_critic_score, real_critic_score])
    gradient_penalty = GradPenLayer()(
        [interp_critic_input, interp_critic_score])

    critic_train_mdl = Model(
        [real_critic_input, fake_critic_input, interp_critic_input],
        [critic_loss, gradient_penalty])

    # The model outputs already are the loss terms, so the loss function
    # passes the prediction through and ignores y_true.
    critic_train_mdl.compile(optimizer=critic_optimizer,
                             loss=lambda y_true, y_pred: y_pred,
                             loss_weights=[1., grad_pen_coef])

    # build model to train generator; the critic is frozen only after the
    # critic-training model above has been compiled.
    prior_input = Input(shape=generator.input_shape[1:], name='prior_in')
    critic.trainable = False
    critic_on_generator_score = critic(generator(prior_input))
    generator_train_mdl = Model(prior_input, critic_on_generator_score)
    # Minimizing -score maximizes the critic's score on generated samples.
    generator_train_mdl.compile(optimizer=generator_optimizer,
                                loss=lambda y_true, y_pred: -y_pred)

    # init callbacks
    callbacks = callbacks or []
    callbacks = CallbackList(callbacks)
    callbacks.set_model({'generator': generator, 'critic': critic})
    callbacks.set_params({
        'batch_size': batch_size,
        'epochs': epochs,
        'steps': batches_per_epoch,
        'samples': batches_per_epoch * batch_size,
        'prior_gen': prior_gen,
        'data_gen': data_gen,
    })

    # Loop invariants. The losses ignore y_true (see the lambdas above);
    # dummy_y only fills train_on_batch's target argument.
    losses_names = ('total_loss', 'critic_loss', 'gradient_pen')
    dummy_y = np.array([None] * batch_size)

    # train
    print('Training on {} samples for {} epochs'.format(
        batches_per_epoch * batch_size, epochs))
    callbacks.on_train_begin()
    for e in range(epochs):
        print('Epoch {}/{}'.format(e + 1, epochs))
        callbacks.on_epoch_begin(e)
        progbar = Progbar(target=batches_per_epoch * batch_size)
        for b in range(batches_per_epoch):
            callbacks.on_batch_begin(b)
            batch_losses = np.zeros(shape=3)
            for _ in range(critic_gen_train_ratio):
                real_batch = data_gen(batch_size)
                fake_batch = generator.predict(prior_gen(batch_size))
                # Random per-sample interpolation between real and fake
                # samples, broadcast over all non-batch axes.
                weights = np.random.uniform(size=batch_size)
                weights = weights.reshape((-1, ) + (1, ) *
                                          (len(real_batch.shape) - 1))
                interp_batch = (weights * real_batch +
                                (1. - weights) * fake_batch)

                x_batch = {
                    'real_in': real_batch,
                    'fake_in': fake_batch,
                    'interp_in': interp_batch
                }
                cur_losses = np.array(
                    critic_train_mdl.train_on_batch(x=x_batch,
                                                    y=[dummy_y, dummy_y]))
                batch_losses += cur_losses
            if critic_gen_train_ratio > 0:
                # Report the mean over this batch's critic updates rather
                # than their sum, which would scale with the ratio.
                batch_losses /= critic_gen_train_ratio

            generator_train_mdl.train_on_batch(x=prior_gen(batch_size),
                                               y=dummy_y)

            progbar.add(batch_size, list(zip(losses_names, batch_losses)))
            callbacks.on_batch_end(b)

        progbar.update(batches_per_epoch * batch_size)
        callbacks.on_epoch_end(e)

    callbacks.on_train_end()
Beispiel #9
0
def main(args):
    """Train the VAE-GAN text model described by a JSON configuration file.

    Command-line options parsed from ``args``:

    * ``-c FILE`` / ``--config=FILE`` -- configuration file, resolved relative
      to the ``configurations/`` directory (required).
    * ``-s`` -- append to ``logging/vae_gan/evolution.log`` instead of
      overwriting it.

    Every training step updates the encoder, the decoder and the
    discriminator in turn (each through its own training model) while the
    other two sub-networks are frozen with ``set_trainability``. At the end
    of every epoch the three training models are evaluated on the validation
    set and the results are passed to the callbacks under ``val_``-prefixed
    names.

    Exits with status 2 when the options cannot be parsed or when no
    configuration file is given.
    """
    try:
        opts, args = getopt.getopt(args, "c:s", ["config="])
    except getopt.GetoptError:
        print('usage: -c config.json')
        sys.exit(2)

    config_fname = None
    start_from_model = False
    for opt, arg in opts:
        if opt in ("-c", "--config"):
            config_fname = os.path.join('configurations', arg)
        elif opt == '-s':
            start_from_model = True

    if config_fname is None:
        # Without -c the configuration name was never bound and the script
        # crashed later with a NameError; fail early with the usage instead.
        print('usage: -c config.json')
        sys.exit(2)

    filemode = 'a' if start_from_model else 'w'

    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s',
                        level=logging.INFO,
                        filename='logging/vae_gan/evolution.log',
                        filemode=filemode)

    # Only parsing needs the configuration file; do not keep it open while
    # the model trains.
    with open(config_fname, 'r') as json_data:
        config_data = json.load(json_data)

    tweets_path = config_data['tweets_path']
    vocab_path = config_data['vocab_path']
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(join(vocab_path, 'vocabulary.pkl'), 'rb') as vocab_file:
        vocab = cPickle.load(vocab_file)

    #== == == == == == =
    # Load all the Data
    #== == == == == == =

    noutputs = 5

    logging.info('Load Training Data')
    train_input, train_output = load_data(
        join(tweets_path, 'en_train.tsv'), config_data, vocab, noutputs)
    logging.info('Load Validation Data')
    valid_input, valid_output = load_data(
        join(tweets_path, 'en_valid15.tsv'), config_data, vocab, noutputs)
    logging.info('Load Validation Data')
    valid_input2, valid_output2 = load_data(
        join(tweets_path, 'en_test16.tsv'), config_data, vocab, noutputs)
    logging.info('Load Test Data')
    test_input, test_output = load_data(join(tweets_path, 'en_test17.tsv'),
                                        config_data, vocab, noutputs)

    # Keras variable handed to the model builder and to NewCallback
    # (presumably a training-step counter -- confirm in vae_gan_model).
    step = K.variable(1.)

    # == == == == == == == == == == =
    # Define and load the CNN model
    # == == == == == == == == == == =
    (full_model, encoding_train_model, decoder_train_model,
     discriminator_train_model, decoder_inference, encoder, decoder,
     discriminator, discriminator_pretrain_model) = vae_gan_model(
         config_data, vocab, step)
    encoding_train_model.summary()
    decoder_train_model.summary()
    discriminator_train_model.summary()
    decoder_inference.summary()
    encoder.summary()
    decoder.summary()
    discriminator.summary()

    batch_size = config_data['batch_size']
    # float() guards against Python 2 integer division, under which ceil()
    # would be a no-op and the final partial step/epoch silently dropped.
    steps_per_epoch = int(
        ceil(float(config_data['samples_per_epoch']) / batch_size))
    epochs = int(
        ceil(config_data['nb_epochs'] *
             (float(config_data['nsamples']) /
              config_data['samples_per_epoch'])))

    initial_epoch = 0
    skip_texts = 0

    generator = generate_data_stream(config_data['training_path'],
                                     config_data,
                                     vocab,
                                     batch_size,
                                     skip_data=skip_texts,
                                     noutputs=noutputs)
    enqueuer = GeneratorEnqueuer(generator,
                                 use_multiprocessing=False,
                                 wait_time=0.01)
    enqueuer.start(workers=1, max_queue_size=10)
    try:
        output_generator = enqueuer.get()

        enc_out_labels = [
            'enc_' + s
            for s in encoding_train_model._get_deduped_metrics_names()
        ]
        dec_out_labels = [
            'dec_' + s
            for s in decoder_train_model._get_deduped_metrics_names()
        ]
        dis_out_labels = [
            'dis_' + s
            for s in discriminator_train_model._get_deduped_metrics_names()
        ]
        out_labels = enc_out_labels + dec_out_labels + dis_out_labels
        callback_metrics = out_labels + ['val_' + n for n in out_labels]

        step_callback = NewCallback(step, steps_per_epoch)
        output_callback = OutputCallback(decoder_inference, valid_input, 15,
                                         vocab, '')
        callbacks = CallbackList([
            BaseLogger(),
            ProgbarLogger(count_mode='steps'), step_callback, output_callback
        ])

        callbacks.set_model(full_model)
        callbacks.set_params({
            'epochs': epochs,
            'steps': steps_per_epoch,
            'verbose': True,
            'do_validation': True,
            'metrics': callback_metrics,
        })

        callbacks.on_train_begin()

        for epoch in range(initial_epoch, epochs):
            callbacks.on_epoch_begin(epoch)

            for batch_index in range(steps_per_epoch):
                batch_logs = {'batch': batch_index, 'size': batch_size}
                X, y = next(output_generator)

                callbacks.on_batch_begin(batch_index, batch_logs)

                # Encoder update: decoder and discriminator frozen.
                set_trainability(encoder, trainable=True)
                set_trainability(decoder, trainable=False)
                set_trainability(discriminator, trainable=False)
                enc_outs = encoding_train_model.train_on_batch(X, y[:3])

                # Decoder update: encoder and discriminator frozen.
                set_trainability(encoder, trainable=False)
                set_trainability(decoder, trainable=True)
                set_trainability(discriminator, trainable=False)
                dec_outs = decoder_train_model.train_on_batch(X, y[:4])

                # Discriminator update: encoder and decoder frozen.
                set_trainability(encoder, trainable=False)
                set_trainability(decoder, trainable=False)
                set_trainability(discriminator, trainable=True)
                dis_outs = discriminator_train_model.train_on_batch(X, y[0])

                outs = enc_outs + dec_outs + [dis_outs]
                batch_logs.update(zip(out_labels, outs))

                callbacks.on_batch_end(batch_index, batch_logs)

            # Validate once per epoch, after the last training step.
            enc_val_outs = encoding_train_model.evaluate(
                valid_input, valid_output[:3], verbose=False)
            dec_val_outs = decoder_train_model.evaluate(
                valid_input, valid_output[:4], verbose=False)
            dis_val_outs = discriminator_train_model.evaluate(
                valid_input, valid_output[0], verbose=False)
            val_outs = enc_val_outs + dec_val_outs + [dis_val_outs]

            # Same labels as for training, with a 'val_' prefix.
            epoch_logs = {'val_' + l: o for l, o in zip(out_labels, val_outs)}
            callbacks.on_epoch_end(epoch, epoch_logs)

        callbacks.on_train_end()
    finally:
        # Always stop the background producer, even if training fails.
        enqueuer.stop()
Beispiel #10
0
def main(args):
    """Train the VAE-GAN text-generation model described by a JSON config.

    Command-line options parsed from ``args``:

    * ``-c FILE`` / ``--config=FILE`` -- configuration file, resolved relative
      to the ``configurations/`` directory (required).
    * ``-s`` -- open the log file in append mode instead of write mode.

    A fresh directory ``logging/vae_nlg_<milliseconds>`` is created for the
    log file, the model summaries (``models.txt``) and the sample output
    written by ``GANOutputCallback``.

    Each training step performs one VAE update (discriminator frozen),
    ``disc_iterations`` discriminator updates (100 during the first 25 steps
    and on every 500th step, ``discriminator_iterations`` otherwise) and one
    decoder update against the frozen discriminator. Validation losses are
    computed once at the end of every epoch.

    Exits with status 2 when the options cannot be parsed or when no
    configuration file is given.
    """
    try:
        opts, args = getopt.getopt(args, "c:s", ["config="])
    except getopt.GetoptError:
        print('usage: -c config.json')
        sys.exit(2)

    config_fname = None
    start_from_model = False
    for opt, arg in opts:
        if opt in ("-c", "--config"):
            config_fname = os.path.join('configurations', arg)
        elif opt == '-s':
            start_from_model = True

    if config_fname is None:
        # Without -c the configuration name was never bound and the script
        # crashed later with a NameError; fail early with the usage instead.
        print('usage: -c config.json')
        sys.exit(2)

    filemode = 'a' if start_from_model else 'w'

    log_path = 'logging/vae_nlg_{}'.format(int(round(time.time() * 1000)))
    # makedirs also creates the parent 'logging' directory when missing.
    os.makedirs(log_path)

    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s',
                        level=logging.INFO,
                        filename='{}/evolution.log'.format(log_path),
                        filemode=filemode)

    # Only parsing needs the configuration file; do not keep it open while
    # the model trains.
    with open(config_fname, 'r') as json_data:
        config_data = json.load(json_data)

    batch_size = config_data['batch_size']
    epochs = config_data['nb_epochs']
    discriminator_iterations = config_data['discriminator_iterations']
    tweets_path = config_data['tweets_path']
    vocab_path = config_data['vocab_path']
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(join(vocab_path, 'vocabulary.pkl'), 'rb') as vocab_file:
        vocab = cPickle.load(vocab_file)

    #== == == == == == =
    # Load all the Data
    #== == == == == == =

    noutputs = 5

    logging.info('Load Training Data')
    train_input, train_output = load_text_pairs(
        join(tweets_path, 'training_set.tsv'), config_data, vocab,
        noutputs)
    logging.info('Load Validation Data')
    valid_input, valid_output = load_text_pairs(
        join(tweets_path, 'vaild_set.tsv'), config_data, vocab, noutputs)
    logging.info('Load Output Validation Data')
    valid_dev_input, valid_dev_output = load_text_pairs(
        join(tweets_path, 'test_set.tsv'), config_data, vocab, noutputs)

    z_size = config_data['z_size']
    num_train_samples = train_input[0].shape[0]

    # All-zero latent codes fed to the decoder during validation.
    noise_valid_input = np.zeros(shape=(valid_input[0].shape[0], z_size))

    step = K.variable(1.)
    # float() avoids Python 2 integer division (ceil would be a no-op) and
    # int() keeps the step count integral on both Python 2 and 3.
    steps_per_epoch = int(ceil(float(train_output[0].shape[0]) / batch_size))

    # == == == == == == == == == == =
    # Define and load the CNN model
    # == == == == == == == == == == =
    (vae_model, vae_model_test, decoder_discr_model, decoder_test_model,
     discriminator_model, discriminator) = get_vae_gan_model(
         config_data, vocab, step)

    # (heading, underline, model) for every summary written to models.txt.
    summary_sections = (
        ('VAE Model\n', '---------\n', vae_model),
        ('VAE Model Test\n', '--------------\n', vae_model_test),
        ('Decoder Discriminator Model\n', '---------------------------\n',
         decoder_discr_model),
        ('Decoder Test Model\n', '---------------------------\n',
         decoder_test_model),
        ('Discriminator Model\n', '-------------------\n',
         discriminator_model),
    )
    with open(os.path.join(log_path, 'models.txt'), 'wt') as fh:
        for heading, underline, model in summary_sections:
            fh.write(heading)
            fh.write(underline)
            model.summary(print_fn=lambda line: fh.write(line + '\n'))

    terminate_on_nan = TerminateOnNaN()
    model_checkpoint = ModelCheckpoint(
        'models/vae_model/weights.{epoch:02d}.hdf5',
        period=10,
        save_weights_only=True)

    enc_out_labels = [
        'enc_' + s for s in vae_model._get_deduped_metrics_names()
    ]
    dec_out_labels = [
        'dec_' + s
        for s in decoder_discr_model._get_deduped_metrics_names()
    ]
    out_labels = enc_out_labels + dec_out_labels + [
        'dis_real', 'dis_gen', 'dis_noise'
    ]
    callback_metrics = out_labels + ['val_' + n for n in out_labels]

    step_callback = StepCallback(step, steps_per_epoch)
    output_callback = GANOutputCallback(
        vae_model_test,
        valid_dev_input[0],
        1,
        vocab,
        '',
        fname='{}/test_output'.format(log_path))
    callbacks = CallbackList([
        BaseLogger(),
        ProgbarLogger(count_mode='steps'), step_callback, output_callback,
        model_checkpoint, terminate_on_nan
    ])

    callbacks.set_model(vae_model_test)
    callbacks.set_params({
        'batch_size': batch_size,
        'epochs': epochs,
        'steps': steps_per_epoch,
        'verbose': True,
        'do_validation': True,
        'metrics': callback_metrics or [],
    })

    callbacks.on_train_begin()
    initial_epoch = 0
    index_array = np.arange(num_train_samples)

    steps = 0  # global step counter across epochs
    epoch = initial_epoch
    while epoch < epochs:
        callbacks.on_epoch_begin(epoch)
        index_array = _batch_shuffle(index_array, batch_size)

        batches = _make_batches(num_train_samples, batch_size)
        for batch_index, (batch_start, batch_end) in enumerate(batches):
            batch_ids = index_array[batch_start:batch_end]
            n_batch = len(batch_ids)
            X = _slice_arrays(train_input, batch_ids)
            y = _slice_arrays(train_output, batch_ids)

            # Report the real size: the last batch of an epoch may be smaller
            # than batch_size, and BaseLogger weights averages by 'size'.
            batch_logs = {'batch': batch_index, 'size': n_batch}
            callbacks.on_batch_begin(batch_index, batch_logs)

            # 1) VAE update with the discriminator frozen.
            set_trainability(discriminator, trainable=False)
            enc_outs = vae_model.train_on_batch(x=X, y=y[:3])

            # 2) Discriminator updates; extra iterations at the start and
            #    periodically every 500 steps.
            set_trainability(discriminator, trainable=True)
            if steps < 25 or steps % 500 == 0:
                disc_iterations = 100
            else:
                disc_iterations = discriminator_iterations
            noise_input = np.zeros(shape=(n_batch, z_size))
            disc_losses = []
            for _ in range(disc_iterations):
                real_idx = np.random.choice(num_train_samples,
                                            n_batch,
                                            replace=False)
                real_batch = [x[real_idx] for x in train_input]

                x_fake = vae_model_test.predict_on_batch(x=real_batch[0])
                x_noise_fake = decoder_test_model.predict_on_batch(
                    x=noise_input)

                # Each input sentence is paired with its real target, a VAE
                # reconstruction and a decoding of the all-zero latent code.
                discr_inputs = np.concatenate(
                    (real_batch[0], real_batch[0], real_batch[0]))
                discr_targets = np.concatenate(
                    (real_batch[1], x_fake, x_noise_fake))
                # NOTE(review): real pairs are labelled +1 and generated pairs
                # -1 here, while validation below evaluates real pairs with
                # target -1 and the decoder is trained towards -1 -- confirm
                # which sign convention the discriminator loss expects.
                labels = np.asarray(n_batch * [1] + 2 * n_batch * [-1])

                order = np.arange(len(labels))
                np.random.shuffle(order)
                disc_losses.append(
                    discriminator_model.train_on_batch(
                        [discr_inputs[order], discr_targets[order]],
                        labels[order]))

            loss_d_real = -np.mean(disc_losses)

            # 3) Decoder update against the frozen discriminator.
            set_trainability(discriminator, trainable=False)
            dec_outs = decoder_discr_model.train_on_batch(
                x=[X[0], noise_input],
                y=-np.ones(shape=(n_batch, 1)))

            outs = enc_outs + [dec_outs] + [loss_d_real]
            batch_logs.update(zip(out_labels, outs))

            callbacks.on_batch_end(batch_index, batch_logs)
            steps += 1

        # Validate once per epoch, after all batches have been processed.
        valid_len = valid_output[0].shape[0]
        enc_val_outs = vae_model.evaluate(valid_input,
                                          valid_output[:3],
                                          verbose=False)
        dec_val_outs = decoder_discr_model.evaluate(
            [valid_input[0], noise_valid_input],
            -np.ones(shape=(valid_len, 1)),
            verbose=False)
        dis_val_outs = discriminator_model.evaluate(
            valid_input,
            -np.ones(shape=(valid_len, 1)),
            verbose=False)
        val_outs = enc_val_outs + [dec_val_outs] + [dis_val_outs]

        # Same labels as for training, with a 'val_' prefix.
        epoch_logs = {'val_' + l: o for l, o in zip(out_labels, val_outs)}
        callbacks.on_epoch_end(epoch, epoch_logs)
        epoch += 1

    callbacks.on_train_end()
Beispiel #11
0
def GanTrain(discriminator,
             generator,
             opt,
             global_batch_size,
             warmup_epochs,
             datapath,
             EventsperFile,
             nEvents,
             WeightsDir,
             mod=0,
             nb_epochs=30,
             batch_size=128,
             latent_size=128,
             gen_weight=6,
             aux_weight=0.2,
             ecal_weight=0.1,
             lr=0.001,
             rho=0.9,
             decay=0.0,
             g_weights='params_generator_epoch_',
             d_weights='params_generator_epoch_',
             xscale=1,
             verbose=True):
    """Train a conditional GAN on ECAL data with Horovod.

    Compiles ``discriminator`` with three losses (binary cross-entropy plus
    two mean-absolute-percentage-error terms, weighted by ``gen_weight``,
    ``aux_weight`` and ``ecal_weight``) and ``generator`` with optimizer
    ``opt``. It then builds a ``combined`` model: the generator followed by
    the frozen discriminator. It loads the "Ele"/"ECAL" files returned by
    ``DivideFiles``. For each of ``nb_epochs`` epochs, every step does:

    * one discriminator update on a real batch,
    * one discriminator update on a generated batch,
    * two generator updates through ``combined``.

    Horovod rank 0 saves both models' weights to ``WeightsDir`` after every
    epoch.

    Args:
        discriminator: Keras model with three outputs (real/fake, energy,
            ECAL sum -- names taken from the training targets below).
        generator: Keras model mapping a latent vector to an image.
        opt: optimizer used to compile all three models.
        global_batch_size: only used to compute ``total_batches``;
            presumably the batch size summed over all Horovod workers --
            confirm.
        warmup_epochs: passed to the Horovod warm-up / schedule callbacks.
        datapath, EventsperFile, nEvents: forwarded to ``DivideFiles``;
            ``EventsperFile`` is also used to check the training-set size.
        WeightsDir: directory the per-epoch weight files are written to.
        mod, xscale: forwarded to ``GetEcalFit``.
        nb_epochs: number of training epochs.
        batch_size: batch size used for every update in this process.
        latent_size: length of the latent noise vector.
        gen_weight, aux_weight, ecal_weight: loss weights (see above).
        lr, rho, decay, verbose: not read anywhere in this body.
        g_weights, d_weights: filename stems of the saved weight files.
            NOTE(review): the default of ``d_weights`` equals that of
            ``g_weights`` ('params_generator_epoch_'), which looks like a
            copy-paste slip; the files do not collide only because of the
            'generator_'/'discriminator_' prefixes.
    """
    start_init = time.time()
    # verbose = False
    if hvd.rank() == 0:
        print('[INFO] Building discriminator')
    #discriminator.summary()
    discriminator.compile(optimizer=opt,
                          loss=[
                              'binary_crossentropy',
                              'mean_absolute_percentage_error',
                              'mean_absolute_percentage_error'
                          ],
                          loss_weights=[gen_weight, aux_weight, ecal_weight])

    # build the generator
    if hvd.rank() == 0:
        print('[INFO] Building generator')
    #generator.summary()
    generator.compile(optimizer=opt, loss='binary_crossentropy')

    # build combined Model
    # The discriminator is frozen only after its own compile() above and
    # before combined.compile() below; this relies on Keras capturing the
    # `trainable` flag at compile time, so that discriminator.train_on_batch
    # still updates it while combined.train_on_batch does not.
    latent = Input(shape=(latent_size, ), name='combined_z')
    fake_image = generator(latent)
    discriminator.trainable = False
    fake, aux, ecal = discriminator(fake_image)
    combined = Model(input=[latent],
                     output=[fake, aux, ecal],
                     name='combined_model')

    # Getting Data
    Trainfiles, Testfiles = DivideFiles(datapath,
                                        nEvents=nEvents,
                                        EventsperFile=EventsperFile,
                                        datasetnames=["ECAL"],
                                        Particles=["Ele"])

    if hvd.rank() == 0:
        print("Train files: {0} \nTest files: {1}".format(
            Trainfiles, Testfiles))

    #Read test data into a single array
    # NOTE(review): the test arrays are loaded but not used by the visible
    # training loop.
    for index, dtest in enumerate(Testfiles):
        if index == 0:
            X_test, Y_test, ecal_test = GetData(dtest)
        else:
            X_temp, Y_temp, ecal_temp = GetData(dtest)
            X_test = np.concatenate((X_test, X_temp))
            Y_test = np.concatenate((Y_test, Y_temp))
            ecal_test = np.concatenate((ecal_test, ecal_temp))

    # Read all training files into single arrays.
    for index, dtrain in enumerate(Trainfiles):
        if index == 0:
            X_train, Y_train, ecal_train = GetData(dtrain)
        else:
            X_temp, Y_temp, ecal_temp = GetData(dtrain)
            X_train = np.concatenate((X_train, X_temp))
            Y_train = np.concatenate((Y_train, Y_temp))
            ecal_train = np.concatenate((ecal_train, ecal_temp))

    nb_test = X_test.shape[0]
    assert X_train.shape[0] == EventsperFile * len(
        Trainfiles), "# Total events in training files"
    nb_train = X_train.shape[0]  # Total events in training files
    # NOTE(review): '/' yields an int only under Python 2. Under Python 3
    # total_batches is a float and range(total_batches) below raises
    # TypeError. The generator .next() calls below are Python 2-only too.
    total_batches = nb_train / global_batch_size
    if hvd.rank() == 0:
        print('Total Training batches = {} with {} events'.format(
            total_batches, nb_train))

    combined.compile(
        #optimizer=Adam(lr=adam_lr, beta_1=adam_beta_1),
        optimizer=opt,
        loss=[
            'binary_crossentropy', 'mean_absolute_percentage_error',
            'mean_absolute_percentage_error'
        ],
        loss_weights=[gen_weight, aux_weight, ecal_weight])

    # One callback list per model: generator (gcb), discriminator (dcb) and
    # combined (ccb).
    # NOTE(review): only on_train_begin() is invoked on these lists below.
    # No epoch/batch hooks are called, so the LR warm-up, LR schedule and
    # ReduceLROnPlateau callbacks presumably never take effect. Whether the
    # broadcast callback works from on_train_begin alone depends on the
    # Horovod version -- confirm.
    gcb = CallbackList( \
        callbacks=[ \
        hvd.callbacks.BroadcastGlobalVariablesCallback(0), \
        hvd.callbacks.MetricAverageCallback(), \
        hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=warmup_epochs, verbose=1, steps_per_epoch=total_batches), \
        hvd.callbacks.LearningRateScheduleCallback(start_epoch=warmup_epochs, end_epoch=nb_epochs, multiplier=1.), \
        keras.callbacks.ReduceLROnPlateau(patience=10, verbose=1) \
        ])

    dcb = CallbackList( \
        callbacks=[ \
        hvd.callbacks.BroadcastGlobalVariablesCallback(0), \
        hvd.callbacks.MetricAverageCallback(), \
        hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=warmup_epochs, verbose=1, steps_per_epoch=total_batches), \
        hvd.callbacks.LearningRateScheduleCallback(start_epoch=warmup_epochs, end_epoch=nb_epochs, multiplier=1.), \
        keras.callbacks.ReduceLROnPlateau(patience=10, verbose=1) \
        ])

    ccb = CallbackList( \
        callbacks=[ \
        hvd.callbacks.BroadcastGlobalVariablesCallback(0), \
        hvd.callbacks.MetricAverageCallback(), \
        hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=warmup_epochs, verbose=1, steps_per_epoch=total_batches), \
        hvd.callbacks.LearningRateScheduleCallback(start_epoch=warmup_epochs, end_epoch=nb_epochs, multiplier=1.), \
        keras.callbacks.ReduceLROnPlateau(patience=10, verbose=1) \
        ])

    gcb.set_model(generator)
    dcb.set_model(discriminator)
    ccb.set_model(combined)

    gcb.on_train_begin()
    dcb.on_train_begin()
    ccb.on_train_begin()

    # Resident memory (RSS) of this process after initialisation.
    print("On hostname {0} - After init using {1} memory".format(
        socket.gethostname(),
        psutil.Process(os.getpid()).memory_info()[0]))

    # NOTE(review): these histories (and epoch_gen_loss / epoch_disc_loss
    # below) are filled or created but never read in the visible code.
    train_history = defaultdict(list)
    test_history = defaultdict(list)

    if hvd.rank() == 0:
        print('Initialization time was {} seconds'.format(time.time() -
                                                          start_init))

    for epoch in range(nb_epochs):
        epoch_start = time.time()
        if hvd.rank() == 0:
            print('Epoch {} of {}'.format(epoch + 1, nb_epochs))

        # Return value is discarded; presumably shuffles the three arrays
        # in place with a shared permutation -- confirm in `randomize`.
        randomize(X_train, Y_train, ecal_train)

        epoch_gen_loss = []
        epoch_disc_loss = []

        image_batches = genbatches(X_train, batch_size)
        energy_batches = genbatches(Y_train, batch_size)
        ecal_batches = genbatches(ecal_train, batch_size)

        for index in range(total_batches):
            start = time.time()
            image_batch = image_batches.next()
            energy_batch = energy_batches.next()
            ecal_batch = ecal_batches.next()

            # Latent input: Gaussian noise scaled by an energy drawn
            # uniformly from [0.1, 5), which also serves as the aux target.
            noise = np.random.normal(0, 1, (batch_size, latent_size))
            sampled_energies = np.random.uniform(0.1, 5, (batch_size, 1))
            generator_ip = np.multiply(sampled_energies, noise)
            # ecal sum from fit
            ecal_ip = GetEcalFit(sampled_energies, mod, xscale)
            generated_images = generator.predict(generator_ip, verbose=0)
            # Discriminator: real batch with target 1, generated batch with
            # target 0, each passed through BitFlip (presumably random label
            # flipping -- confirm).
            real_batch_loss = discriminator.train_on_batch(
                image_batch,
                [BitFlip(np.ones(batch_size)), energy_batch, ecal_batch])
            fake_batch_loss = discriminator.train_on_batch(
                generated_images,
                [BitFlip(np.zeros(batch_size)), sampled_energies, ecal_ip])
            epoch_disc_loss.append([
                (a + b) / 2 for a, b in zip(real_batch_loss, fake_batch_loss)
            ])

            # Generator: two updates through the combined model with target
            # "real" (ones) for the frozen discriminator.
            trick = np.ones(batch_size)
            gen_losses = []
            for _ in range(2):
                noise = np.random.normal(0, 1, (batch_size, latent_size))
                sampled_energies = np.random.uniform(0.1, 5, (batch_size, 1))
                generator_ip = np.multiply(sampled_energies, noise)
                ecal_ip = GetEcalFit(sampled_energies, mod, xscale)
                gen_losses.append(
                    combined.train_on_batch(
                        [generator_ip],
                        [trick,
                         sampled_energies.reshape((-1, 1)), ecal_ip]))
            epoch_gen_loss.append([(a + b) / 2 for a, b in zip(*gen_losses)])

            # `index % 1` is always 0, so rank 0 reports every batch.
            if (index % 1) == 0 and hvd.rank() == 0:
                # progress_bar.update(index)
                print('processed {}/{} batches in {}'.format(
                    index + 1, total_batches,
                    time.time() - start))

        # save weights every epoch
        if hvd.rank() == 0:

            safe_mkdir(WeightsDir)

            print("saving weights of gen")
            generator.save_weights(
                WeightsDir +
                '/generator_{0}{1:03d}.hdf5'.format(g_weights, epoch),
                overwrite=True)

            print("saving weights of disc")
            discriminator.save_weights(
                WeightsDir +
                '/discriminator_{0}{1:03d}.hdf5'.format(d_weights, epoch),
                overwrite=True)

            epoch_time = time.time() - epoch_start
            print("The {} epoch took {} seconds".format(epoch, epoch_time))
    # NOTE(review): everything from here to the end of the function looks
    # like a fragment of a different training routine. It reads names that
    # are not defined in this function (callbacks, model, nb_epoch,
    # nb_train_sample, metrics, shuffle_on_epoch_start, util, y_train,
    # nb_classes, normalize_data, class_weight). Unless those exist at module
    # scope, this raises NameError once the epoch loop above has finished.
    do_validation = True

    callbacks._set_model(model)
    callbacks._set_params({
        'batch_size': batch_size,
        'nb_epoch': nb_epoch,
        'nb_sample': nb_train_sample,
        'verbose': 1,
        'do_validation': do_validation,
        'metrics': metrics,
    })

    ##########################
    # TRAINING
    ##########################
    callbacks.on_train_begin()

    model.stop_training = False
    for epoch in range(nb_epoch):
        callbacks.on_epoch_begin(epoch)

        if shuffle_on_epoch_start:
            X_train, y_train = util.shuffle_data(X_train, y_train)

        # train
        util.train_on_batch(model, X_train, y_train, nb_classes,
                            callbacks=callbacks,
                            normalize=normalize_data,
                            batch_size=batch_size,
                            class_weight=class_weight,
                            shuffle=False)
Beispiel #13
0
    def fit_generator(self,
                      generator,
                      n_steps_per_epoch,
                      n_epochs=1,
                      validation_data=None,
                      n_validation_steps=None,
                      callbacks=None):
        """Train the network on batches of data generated from `generator`

        :param generator: a generator yielding batches indefinitely, where each
         batch is a tuple of (inputs, targets)
        :type generator: generator
        :param n_steps_per_epoch: number of batches to train on in one epoch
        :type n_steps_per_epoch: int
        :param n_epochs: number of epochs to train the model
        :type n_epochs: int
        :param validation_data: generator yielding batches to evaluate the loss
         on at the end of each epoch, where each batch is a tuple of (inputs,
         targets)
        :type validation_data: generator
        :param n_validation_steps: number of batches to evaluate on from
         `validation_data`
        :type n_validation_steps: int
        :param callbacks: callbacks to be used during training; they run after
         the default callbacks and a step-based ``ProgbarLogger``
        :type callbacks: list[object]
        :raises RuntimeError: if only one of `validation_data` and
         `n_validation_steps` are passed in
        :raises ValueError: if `generator` yields something that is not a
         pair of (inputs, targets)
        """

        def _outputs_to_logs(outputs, logs, prefix=''):
            """Copy loss and metric values from `outputs` into `logs`.

            ``outputs[0]`` is the total loss. With several network outputs,
            ``outputs[1:n_outputs + 1]`` are the per-output losses. The
            remaining entries are metric values in ``self.metric_names``
            order. Every key is prefixed with `prefix`.
            """
            logs[prefix + 'loss'] = outputs[0]
            if self.n_outputs > 1:
                for idx_output in range(1, self.n_outputs + 1):
                    key = '{}loss{}'.format(prefix, idx_output)
                    logs[key] = outputs[idx_output]

            idx_metric_values = (1 if self.n_outputs == 1 else
                                 self.n_outputs + 1)
            it = zip(self.metric_names, outputs[idx_metric_values:])
            for metric_name, value in it:
                logs[prefix + metric_name] = value

        default_callbacks = self._load_default_callbacks()
        default_callbacks.append(ProgbarLogger(count_mode='steps'))
        if callbacks:
            default_callbacks.extend(callbacks)
        callbacks = CallbackList(default_callbacks)

        self._assert_compiled()

        invalid_inputs = ((validation_data is not None
                           and not n_validation_steps)
                          or (n_validation_steps and validation_data is None))
        if invalid_inputs:
            msg = ('`validation_data` and `n_validation_steps` must both be '
                   'passed, or neither.')
            raise RuntimeError(msg)

        if self.device:
            self.network.to(self.device)

        metrics = ['loss']
        if self.n_outputs > 1:
            for idx_output in range(1, self.n_outputs + 1):
                metrics.append('loss{}'.format(idx_output))
        if validation_data is not None:
            metrics.append('val_loss')
            if self.n_outputs > 1:
                for idx_output in range(1, self.n_outputs + 1):
                    metrics.append('val_loss{}'.format(idx_output))
        for metric_name in self.metric_names:
            metrics.append(metric_name)
            if validation_data is not None:
                metrics.append('val_{}'.format(metric_name))

        callbacks.set_params({
            'epochs': n_epochs,
            'metrics': metrics,
            'steps': n_steps_per_epoch,
            'verbose': True
        })
        callbacks.set_model(self)

        # Reset the flag so that an early stop requested during a previous
        # call (e.g. by EarlyStopping) does not make this call a no-op.
        self.stop_training = False
        callbacks.on_train_begin()
        for idx_epoch in range(n_epochs):
            if self.stop_training:
                break

            epoch_logs = {}
            callbacks.on_epoch_begin(idx_epoch)

            for idx_batch in range(n_steps_per_epoch):
                batch_logs = {'batch': idx_batch, 'size': 1}
                callbacks.on_batch_begin(idx_batch, batch_logs)

                generator_output = next(generator)
                if len(generator_output) != 2:
                    msg = ('Output of generator should be a tuple of '
                           '(inputs, targets), but instead got a {}: '
                           '{}.').format(type(generator_output),
                                         str(generator_output))
                    # Previously the message was built but never raised.
                    raise ValueError(msg)
                inputs, targets = generator_output
                train_outputs = self.train_on_batch(inputs, targets)

                _outputs_to_logs(train_outputs, batch_logs)
                callbacks.on_batch_end(idx_batch, batch_logs)

                if self.stop_training:
                    break

            if validation_data is not None:
                val_outputs = self.evaluate_generator(validation_data,
                                                      n_validation_steps)
                _outputs_to_logs(val_outputs, epoch_logs, prefix='val_')
            callbacks.on_epoch_end(idx_epoch, epoch_logs)
        callbacks.on_train_end()
Beispiel #14
0
def predict_image(version, image_path, batch_size, overlap, data_format=None):
    """Run a single-input model over an image in overlapping tiles.

    The model at ``version.model_file.name`` (relative to
    ``settings.MEDIA_ROOT``) is loaded.  The image is cut into a grid of tiles
    of the model's input size.  The tiles are predicted in batches of
    ``batch_size``, and the per-tile predictions are stitched back into
    full-image maps, averaging wherever tiles overlap.

    Args:
        version: object whose ``model_file.name`` is the model path relative
            to ``settings.MEDIA_ROOT``.
        image_path: path of the image to predict on.
        batch_size: number of tiles per ``predict_on_batch`` call.
        overlap: requested fractional overlap between neighbouring tiles,
            used to derive the tile stride (must be < 1).
        data_format: ``'channels_first'`` or ``'channels_last'``; defaults to
            ``K.image_data_format()``.

    Returns:
        ``(history, outputs)``.  ``history`` is the ``History`` callback's
        dict.  ``outputs`` has one dict per model output with keys
        ``'name'``, ``'t'`` (``'class'``, ``'multi'`` or ``'img'``) and
        ``'img'`` (the stitched prediction map).

    Raises:
        ValueError: unknown ``data_format``, an image smaller than the model
            input, or a model output that is neither 2-d nor 4-d.
        RuntimeError: the model has more than one input.
    """
    def current_time_millis():
        return int(round(time.time() * 1000))

    def offset(size, diff, overlap):
        # Stride between neighbouring tiles along one axis.  An image exactly
        # as large as the tile (diff == 0) needs a single tile.  Returning the
        # tile size makes the tile count below evaluate to 1 instead of
        # dividing by zero.
        if diff == 0:
            return size
        # Clamp to 1: a very large overlap could otherwise floor to 0 and
        # break the tile-count division.
        return max(1, math.floor(diff / math.ceil(diff / (size * (1 - overlap)))))

    # Tiles are enumerated with flat index ``i * b + j`` (batch i, slot j).
    # The row index varies fastest: row = idx % nb_r, col = idx // nb_r.
    def map_c(i, j, b, l):
        return int(((i * b) + j) / l)

    def map_r(i, j, b, l):
        return ((i * b) + j) % l

    def layer_name(tensor_name):
        # Keep the part of a tensor name before the first '/' or ':'.
        pos = min(
            tensor_name.index('/') if '/' in tensor_name else len(tensor_name),
            tensor_name.index(':') if ':' in tensor_name else len(tensor_name))
        return tensor_name[:pos]

    if data_format is None:
        data_format = K.image_data_format()
    if data_format not in {'channels_first', 'channels_last'}:
        raise ValueError('Unknown data_format:', data_format)
    channels_first = data_format == 'channels_first'

    path = version.model_file.name
    print(_('Loading model "%(path)s".') % {'path': path})
    model = load_model(os.path.join(settings.MEDIA_ROOT, path))

    if len(model.inputs) != 1:
        raise RuntimeError('Models with more than one input are not'
                           ' supported at the moment.')

    # ---- Describe the single input and compute the tile grid. ----
    model_input = model.inputs[0]
    inputs = {'shape': model_input.shape.as_list(),
              'name': layer_name(model_input.name)}
    if channels_first:
        inputs['grayscale'] = inputs['shape'][1] == 1
        inputs['r'] = inputs['shape'][2]
        inputs['c'] = inputs['shape'][3]
    else:
        inputs['r'] = inputs['shape'][1]
        inputs['c'] = inputs['shape'][2]
        inputs['grayscale'] = inputs['shape'][3] == 1

    # Convert with the same data_format used for indexing below, not the
    # backend default, so the channel axis is where this code expects it.
    inputs['img'] = img_to_array(load_img(image_path, inputs['grayscale']),
                                 data_format=data_format)
    inputs['img'] *= 1. / 255
    if channels_first:
        inputs['img_r'] = inputs['img'].shape[1]
        inputs['img_c'] = inputs['img'].shape[2]
    else:
        inputs['img_r'] = inputs['img'].shape[0]
        inputs['img_c'] = inputs['img'].shape[1]

    inputs['diff_r'] = inputs['img_r'] - inputs['r']
    inputs['diff_c'] = inputs['img_c'] - inputs['c']
    if inputs['diff_r'] < 0 or inputs['diff_c'] < 0:
        raise ValueError(
            'Image of size %dx%d is smaller than the model input %dx%d.' %
            (inputs['img_r'], inputs['img_c'], inputs['r'], inputs['c']))
    inputs['offset_r'] = offset(inputs['r'], inputs['diff_r'], overlap)
    inputs['offset_c'] = offset(inputs['c'], inputs['diff_c'], overlap)
    inputs['nb_r'] = math.ceil(inputs['diff_r'] / inputs['offset_r']) + 1
    inputs['nb_c'] = math.ceil(inputs['diff_c'] / inputs['offset_c']) + 1
    N = inputs['nb_r'] * inputs['nb_c']
    steps = math.ceil(N / batch_size)

    def tile_bounds(idx_r, idx_c):
        # Pixel window of tile (idx_r, idx_c).  The last row/column of tiles
        # is clamped so it ends exactly at the image border.
        top = min(idx_r * inputs['offset_r'], inputs['img_r'] - inputs['r'])
        bottom = min(idx_r * inputs['offset_r'] + inputs['r'], inputs['img_r'])
        left = min(idx_c * inputs['offset_c'], inputs['img_c'] - inputs['c'])
        right = min(idx_c * inputs['offset_c'] + inputs['c'], inputs['img_c'])
        return top, bottom, left, right

    # ---- Allocate per-tile prediction buffers for every output. ----
    metrics = []
    # Position of each output's first class label in ``metrics`` (None for
    # outputs that add no labels).  This maps output i's argmax to its own
    # label.
    metric_offsets = []
    outputs = []
    for i, model_output in enumerate(model.outputs):
        tshape = model_output.shape.as_list()
        name = layer_name(model_output.name)
        activation = model.get_layer(name).activation.__name__.lower()
        output = {'name': name, 'shape': tshape}
        outputs.append(output)

        if len(tshape) == 2:
            output['t'] = 'class' if activation == 'softmax' else 'multi'
            nb_classes = tshape[1]
            if nb_classes is None:
                nb_classes = model.get_layer(name).output_shape[1]
            nb_classes = int(nb_classes)
            metric_offsets.append(len(metrics))
            metrics += ['%s:%s' % (name, k) for k in range(nb_classes)]
            if channels_first:
                shape = (nb_classes, inputs['nb_r'], inputs['nb_c'])
            else:
                shape = (inputs['nb_r'], inputs['nb_c'], nb_classes)
        elif len(tshape) == 4:
            output['t'] = 'class' if activation == 'softmax' else 'img'
            metric_offsets.append(None)
            shape = (inputs['nb_r'], inputs['nb_c']) + tuple(tshape[1:])
        else:
            raise ValueError('Unsupported shape %s for output "%s"; only 2-d '
                             'and 4-d outputs are supported.' % (tshape, name))
        output['p'] = np.zeros(shape)

    history = History()
    callbacks = CallbackList([BaseLogger(), history, ProgbarLogger()])
    callbacks.set_model(model)
    callbacks.set_params({
        'batch_size': batch_size,
        'epochs': 1,
        'steps': steps,
        'samples': N,
        'verbose': 1,
        'do_validation': False,
        'metrics': metrics,
    })

    # ---- Predict all tiles batch by batch. ----
    callbacks.on_train_begin()
    callbacks.on_epoch_begin(0)
    start_time = current_time_millis()
    for b in range(steps):
        current_index = (b * batch_size) % N
        current_batch_size = min(batch_size, N - current_index)

        batch_logs = {'batch': b, 'size': current_batch_size}
        for metric in metrics:
            batch_logs[metric] = 0
        callbacks.on_batch_begin(b, batch_logs)

        bX = np.zeros((current_batch_size, ) + tuple(inputs['shape'][1:]))
        for j in range(current_batch_size):
            idx_r = map_r(b, j, batch_size, inputs['nb_r'])
            idx_c = map_c(b, j, batch_size, inputs['nb_r'])
            top, bottom, left, right = tile_bounds(idx_r, idx_c)
            if channels_first:
                bX[j] = inputs['img'][:, top:bottom, left:right]
            else:
                bX[j] = inputs['img'][top:bottom, left:right, :]

        p = model.predict_on_batch(bX)
        if not isinstance(p, list):
            p = [p]
        for j in range(current_batch_size):
            idx_r = map_r(b, j, batch_size, inputs['nb_r'])
            idx_c = map_c(b, j, batch_size, inputs['nb_r'])
            for i, output in enumerate(outputs):
                if output['p'].ndim == 3:
                    if channels_first:
                        output['p'][:, idx_r, idx_c] = p[i][j]
                    else:
                        output['p'][idx_r, idx_c, :] = p[i][j]
                    # Fraction of tiles in this batch predicted as each class.
                    metric = metrics[metric_offsets[i] + p[i][j].argmax()]
                    batch_logs[metric] += 1. / current_batch_size
                else:
                    output['p'][idx_r, idx_c] = p[i][j]
        callbacks.on_batch_end(b, batch_logs)
    runtime = (current_time_millis() - start_time) / 1000
    callbacks.on_epoch_end(0, {'runtime': runtime})
    callbacks.on_train_end()

    # ---- Stitch tiles back into full-image maps, averaging overlaps. ----
    for output in outputs:
        tiles = output['p']
        if len(output['shape']) == 2:
            # 3-d buffer: one class vector per tile.
            if channels_first:
                shape = (tiles.shape[0], inputs['img_r'], inputs['img_c'])
                nb_rows, nb_cols = tiles.shape[1], tiles.shape[2]
            else:
                shape = (inputs['img_r'], inputs['img_c'], tiles.shape[2])
                nb_rows, nb_cols = tiles.shape[0], tiles.shape[1]
        else:
            # 5-d buffer: (nb_r, nb_c) + per-tile image output.
            if channels_first:
                shape = (tiles.shape[2], inputs['img_r'], inputs['img_c'])
            else:
                shape = (inputs['img_r'], inputs['img_c'], tiles.shape[4])
            nb_rows, nb_cols = tiles.shape[0], tiles.shape[1]

        img = np.zeros(shape)
        count = np.zeros(shape)
        for j in range(nb_rows):
            for k in range(nb_cols):
                top, bottom, left, right = tile_bounds(j, k)
                if tiles.ndim == 5:
                    tile = tiles[j, k]
                elif channels_first:
                    # Reshape the class vector to (C, 1, 1) so it broadcasts
                    # over the (C, h, w) window.
                    tile = tiles[:, j, k, np.newaxis, np.newaxis]
                else:
                    tile = tiles[j, k, :]
                if channels_first:
                    img[:, top:bottom, left:right] += tile
                    count[:, top:bottom, left:right] += 1
                else:
                    img[top:bottom, left:right, :] += tile
                    count[top:bottom, left:right, :] += 1
        img /= count
        output['img'] = img
        del output['p']
        del output['shape']
    return history.history, outputs
Beispiel #15
0
    def fit_generator(self,
                      generator,
                      steps_per_epoch=None,
                      epochs=1,
                      verbose=1,
                      callbacks=None,
                      validation_data=None,
                      max_queue_size=10,
                      workers=1,
                      use_multiprocessing=False,
                      shuffle=True):
        """Train the model on ``(x, y)`` batches drawn from a generator.

        Args:
            generator: yields ``(x, y)`` batches.  ``x`` may be an array, a
                list of arrays, a dict of arrays, or None/empty (counted as
                a batch of size 1).
            steps_per_epoch: batches drawn per epoch.  If None,
                ``len(generator)`` is used.
            epochs: number of epochs.
            verbose: 1 shows a ``tqdm`` progress bar; other values are silent.
            callbacks: callbacks forwarded to ``CallbackList``.
            validation_data: optional ``(x_val, y_val)`` passed to
                ``self.evaluate`` after every epoch.
            max_queue_size: queue size of the ``GeneratorEnqueuer``.
            workers: enqueuer workers.  With 0, ``generator`` is consumed
                directly in this thread.
            use_multiprocessing: forwarded to ``GeneratorEnqueuer``.
            shuffle: accepted for API compatibility; not used.

        Returns:
            dict with ``'loss'`` (per-epoch, per-sample mean training loss)
            and ``'val_loss'`` (per-epoch, only when validation_data is given).

        Raises:
            ValueError: ``steps_per_epoch`` is None and ``generator`` has no
                ``len()``.
        """
        if steps_per_epoch is None:
            try:
                steps_per_epoch = len(generator)
            except TypeError:
                raise ValueError('`steps_per_epoch` must be given when the '
                                 'generator does not define `len()`.')

        enqueuer = None
        if workers > 0:
            enqueuer = GeneratorEnqueuer(
                generator, use_multiprocessing=use_multiprocessing)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()
        else:
            output_generator = generator

        hist = {'loss': [], 'val_loss': []}
        try:
            callback_list = CallbackList(callbacks=callbacks)
            callback_list.set_model(self)
            # Clear a flag left over from a previous (early-stopped) run.
            self.stop_training = False
            callback_list.on_train_begin()

            for epoch in range(epochs):
                callback_list.on_epoch_begin(epoch)
                seen = 0
                epoch_logs = {'loss': 0, 'val_loss': 0}
                steps = (trange(steps_per_epoch) if verbose == 1 else
                         range(steps_per_epoch))
                for _ in steps:
                    x, y = next(output_generator)
                    if x is None or len(x) == 0:
                        # Handle data tensors support when no input given
                        # step-size = 1 for data tensors
                        batch_size = 1
                    elif isinstance(x, list):
                        batch_size = x[0].shape[0]
                    elif isinstance(x, dict):
                        batch_size = list(x.values())[0].shape[0]
                    else:
                        batch_size = x.shape[0]
                    batch_loss, _batch_metrics = self.train_on_batch(x, y)
                    epoch_logs['loss'] += batch_loss * batch_size
                    seen += batch_size

                # Turn summed losses into per-sample means.  Skip this when no
                # batch was drawn so that an empty epoch does not divide by
                # zero.
                if seen:
                    for key in epoch_logs:
                        epoch_logs[key] /= seen
                hist['loss'].append(epoch_logs['loss'])

                if validation_data:
                    val_loss_and_metrics = self.evaluate(validation_data[0],
                                                         validation_data[1])
                    hist['val_loss'].append(val_loss_and_metrics[0])
                    epoch_logs.update({'val_loss': val_loss_and_metrics[0]})

                callback_list.on_epoch_end(epoch, epoch_logs)

                if self.stop_training:
                    break
            callback_list.on_train_end()
        finally:
            # Release the enqueuer's workers even when training raises.
            if enqueuer is not None:
                enqueuer.stop()
        return hist
Beispiel #16
0
    def _train_by_batch(self):
        """Train ``self.model`` epoch by epoch on batches from ``load_data``.

        A fresh batch generator is requested with
        ``self.load_data('train', feed_mode='batch')`` at the start of every
        epoch.  When validation data exists, it is evaluated:

        - every 1000 batches: the result is printed and forwarded only to
          ``tensorBoard`` callbacks in ``self.callbacks``;
        - at the end of every epoch.

        After each epoch except the last, the model is saved under
        ``results/trained_models/`` when validation ran and ``self.no_save``
        is falsy.

        Returns:
            The ``History`` callback stored on ``self.model.history``.
        """
        # batch finite generator should be loaded within epoch loop
        logger.info('Start training by batch')
        self.validation_xy = self.load_data('val', feed_mode='all')
        do_validation = bool(self.validation_xy)
        user_callbacks = self.callbacks or []

        # prepare display labels in tensorboard
        out_labels = self.model._get_deduped_metrics_names()
        callback_metrics = out_labels + ['val_' + n for n in out_labels]

        # prepare callbacks
        self.model.history = History()
        callback_objs = [BaseLogger()] + user_callbacks + [self.model.history]
        if self.verbose:
            callback_objs += [ProgbarLogger(count_mode='samples')]
        callbacks = CallbackList(callback_objs)

        # it's possible to callback a different model than this model
        if hasattr(self.model, 'callback_model') and self.model.callback_model:
            callback_model = self.model.callback_model
        else:
            callback_model = self.model
        callbacks.set_model(callback_model)
        callbacks.set_params({
            'epochs': self.epochs,
            'samples': self.data.nb_train,
            'verbose': self.verbose,
            'do_validation': do_validation,
            'metrics': callback_metrics,
        })
        callbacks.on_train_begin()

        for epoch in range(self.epochs):
            start_e = time()
            callbacks.on_epoch_begin(epoch)
            xy_gen = self.load_data('train', feed_mode='batch')
            logger.info('New training epoch')
            for batch_index, (x, y) in enumerate(xy_gen):
                # build batch logs
                batch_logs = {}
                if isinstance(x, list):
                    batch_size = x[0].shape[0]
                elif isinstance(x, dict):
                    batch_size = list(x.values())[0].shape[0]
                else:
                    batch_size = x.shape[0]
                batch_logs['batch'] = batch_index
                batch_logs['size'] = batch_size
                callbacks.on_batch_begin(batch_index, batch_logs)
                outs = self.model.train_on_batch(x, y)

                if not isinstance(outs, list):
                    outs = [outs]
                for l, o in zip(out_labels, outs):
                    batch_logs[l] = o
                callbacks.on_batch_end(batch_index, batch_logs)

                if (batch_index + 1) % 1000 == 0 and do_validation:
                    val_outs = self.model.evaluate(*self.validation_xy,
                                                   batch_size=81920,
                                                   verbose=0)
                    val_logs = {}
                    if not isinstance(val_outs, list):
                        val_outs = [val_outs]
                    for l, o in zip(out_labels, val_outs):
                        val_logs['val_' + l] = o
                    print(' - Eval inside: %.6f' % val_outs[0])
                    # self.callbacks may be None (see `or []` above).
                    for cb in user_callbacks:
                        if cb.__class__ == tensorBoard:
                            cb.on_batch_end(batch_index,
                                            val_logs,
                                            count=False)

            epoch_logs = {}
            if do_validation:
                val_outs = self.model.evaluate(*self.validation_xy,
                                               batch_size=81920,
                                               verbose=0)
                if not isinstance(val_outs, list):
                    val_outs = [val_outs]
                # Same labels assumed.
                for l, o in zip(out_labels, val_outs):
                    epoch_logs['val_' + l] = o

            # Forward the epoch's validation logs to each callback's
            # on_batch_end, but skip BaseLogger.  These logs carry no 'size',
            # so BaseLogger would store every 'val_*' key with weight 0.  Its
            # on_epoch_end would then overwrite the real values with 0.
            for cb in callback_objs:
                if not isinstance(cb, BaseLogger):
                    cb.on_batch_end(epoch, epoch_logs)
            callbacks.on_epoch_end(epoch, epoch_logs)

            elapsed_e = timedelta(seconds=int(time() - start_e))
            self.send_metric('elapsed_per_epoch', elapsed_e)

            if not self.no_save and do_validation and (epoch !=
                                                       self.epochs - 1):
                self.model.save(
                    'results/trained_models/%s_ctr_model_%.4f_epoch_%d.h5' %
                    (self.sess_id, val_outs[0], epoch))

        callbacks.on_train_end()
        return self.model.history
Beispiel #17
0
    def fit_ph(self,
               x,
               y,
               batch_size=None,
               nsteps=None,
               epochs=1,
               verbose=1,
               callbacks=None,
               validation_data=None):
        """Train by feeding ``batchify`` batches through ``self.session``.

        Each step runs ``[self.train_op, self.loss, self.metrics]`` with a
        feed dict built by ``self._make_feed_dict``.  The loss and metrics
        shown are exponential moving averages (decay 0.95) within the epoch.

        Args:
            x: inputs passed to ``batchify``.
            y: targets passed to ``batchify`` (an array or a list of arrays).
            batch_size: batch size for ``batchify``.
            nsteps: steps per epoch; defaults to ``len(y) // batch_size``
                (``len(y[0])`` when ``y`` is a list).  ``batchify`` is
                restarted whenever it runs out.
            epochs: number of epochs (> 0).
            verbose: 1 shows a ``tqdm`` bar, 2 prints every 1000 steps, and
                any truthy value prints an end-of-epoch summary.
            callbacks: forwarded to ``CallbackList``.
            validation_data: optional ``(x_val, y_val)`` passed to
                ``self.evaluate`` after each epoch.

        Returns:
            dict with per-epoch ``'loss'`` (smoothed) and ``'val_loss'``.

        Raises:
            AssertionError: ``epochs <= 0``.
            ValueError: ``nsteps`` is None and ``batch_size`` is not set.
        """
        assert epochs > 0
        if nsteps is None:
            if not batch_size:
                raise ValueError(
                    '`nsteps` must be given when `batch_size` is not set.')
            total_len = len(y[0]) if type(y) is list else len(y)
            nsteps = total_len // batch_size

        hist = {'loss': [], 'val_loss': []}
        callback_list = CallbackList(callbacks=callbacks)
        callback_list.set_model(self)
        # Clear a flag left over from a previous (early-stopped) run.
        self.stop_training = False
        callback_list.on_train_begin()
        g = batchify(x, y, batch_size)
        for epoch in range(epochs):
            callback_list.on_epoch_begin(epoch)
            steps = trange(nsteps) if verbose == 1 else range(nsteps)
            metrics_val = []
            curr_loss = None
            for it in steps:
                try:
                    x_, y_ = next(g)
                except StopIteration:
                    # The batch generator is exhausted; start a new pass.
                    g = batchify(x, y, batch_size)
                    x_, y_ = next(g)
                feed_dict = self._make_feed_dict(x_,
                                                 y_,
                                                 is_training_phase=True)
                _, batch_loss, batch_metrics = self.session.run(
                    [self.train_op, self.loss, self.metrics],
                    feed_dict=feed_dict)
                # Exponential moving averages of the metrics and loss.
                if len(metrics_val):
                    metrics_val = [
                        old * 0.95 + new * 0.05
                        for old, new in zip(metrics_val, batch_metrics)
                    ]
                else:
                    metrics_val = batch_metrics
                curr_loss = (batch_loss if curr_loss is None else
                             curr_loss * 0.95 + batch_loss * 0.05)
                if verbose == 1:
                    steps.set_postfix(loss="%.4f" % curr_loss)
                if verbose == 2:
                    if it % 1000 == 0:
                        print("%s %i/%i, loss=%.5f" %
                              (datetime.datetime.now().strftime("%H:%M:%S"),
                               it, nsteps, curr_loss),
                              flush=True)

            hist['loss'].append(curr_loss)
            logs = {'loss': curr_loss}
            if validation_data:
                val_loss_and_metrics = self.evaluate(validation_data[0],
                                                     validation_data[1])
                hist['val_loss'].append(val_loss_and_metrics[0])
                logs.update({'val_loss': val_loss_and_metrics[0]})

            if verbose:
                if validation_data:
                    print(
                        "Epoch %i, loss=%.3f, metrics=%s; val=%s" %
                        (epoch, curr_loss, metrics_val, val_loss_and_metrics))
                else:
                    print("Epoch %i, loss=%.3f, metrics=%s" %
                          (epoch, curr_loss, metrics_val))

            callback_list.on_epoch_end(epoch=epoch, logs=logs)
            if self.stop_training:
                break
        callback_list.on_train_end()
        return hist
Beispiel #18
0
    def fit_dataset(self,
                    dataset,
                    steps_per_epoch=None,
                    batch_size=32,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    on_sample=None,
                    on_scores=None):
        """Fit the model on `dataset` using the importance sampler.

        Arguments
        ---------
            dataset: `BaseDataset` instance exposing `train_data` and
                     `test_data`.
            steps_per_epoch: int or None, number of updates per epoch.  None
                             means `len(dataset.train_data) // batch_size`.
            batch_size: int, samples drawn from the sampler per update
            epochs: int, number of epochs to run
            verbose: {0, >0}; a `ProgbarLogger` is attached when > 0
            callbacks: list of Keras callbacks, or None
            on_sample: optional callable, invoked after every update with the
                       sampler and the latest sample event's idxs, w and
                       predicted scores
            on_scores: optional callable, invoked after every update with the
                       sampler and the latest scores

        Returns
        -------
            The `History` callback, which is also stored on `self.history`.
        """
        # Refuse batches larger than the training set.  Datasets without a
        # size raise a RuntimeError mentioning "no size" from len(); for
        # those the check is skipped.
        try:
            too_few_samples = len(dataset.train_data) < batch_size
        except RuntimeError as err:
            assert "no size" in str(err)
        else:
            if too_few_samples:
                raise ValueError(("The model cannot be trained with "
                                  "batch_size > training set"))

        if steps_per_epoch is None:
            steps_per_epoch = len(dataset.train_data) // batch_size

        # Callbacks: BaseLogger first, then the user's, then History, and
        # optionally a progress bar at the end.
        self.history = History()
        callback_objs = [BaseLogger()] + (callbacks or []) + [self.history]
        if verbose > 0:
            callback_objs += [ProgbarLogger(count_mode="steps")]
        callback_list = CallbackList(callback_objs)
        #TODO: Should we be making it possible to call back a different model
        #      than self.model.model?
        callback_list.set_model(self.model.model)
        callback_list.set_params(dict(
            epochs=epochs,
            steps=steps_per_epoch,
            verbose=verbose,
            do_validation=len(dataset.test_data) > 0,
            metrics=self._get_metric_names() +
            ["val_" + n for n in self._get_metric_names()],
        ))

        sampler = self.sampler(dataset, batch_size, steps_per_epoch, epochs)

        epoch = 0
        self.model.model.stop_training = False
        callback_list.on_train_begin()
        while epoch < epochs:
            callback_list.on_epoch_begin(epoch)
            for step in range(steps_per_epoch):
                batch_logs = {"batch": step, "size": batch_size}
                callback_list.on_batch_begin(step, batch_logs)

                # Draw an importance-sampled batch, train on it, then feed
                # the resulting scores back to the sampler.
                idxs, (x, y), w = sampler.sample(batch_size)
                loss, metrics, scores = self.model.train_batch(x, y, w)
                sampler.update(idxs, scores)

                batch_means = (value.mean() for value in [loss] + metrics)
                for label, mean in zip(self._get_metric_names(), batch_means):
                    batch_logs[label] = mean
                callback_list.on_batch_end(step, batch_logs)

                if on_scores is not None:
                    on_scores(sampler, self._latest_scores)

                if on_sample is not None:
                    event = self._latest_sample_event
                    on_sample(sampler, event["idxs"], event["w"],
                              event["predicted_scores"])

                if self.model.model.stop_training:
                    break

            # End of epoch: evaluate on the test split when there is one.
            if len(dataset.test_data) > 0:
                val = self.model.evaluate(*dataset.test_data[:],
                                          batch_size=batch_size)
                epoch_logs = {
                    "val_" + label: value
                    for label, value in zip(self._get_metric_names(), val)
                }
            else:
                epoch_logs = {}
            callback_list.on_epoch_end(epoch, epoch_logs)
            if self.model.model.stop_training:
                break
            epoch += 1
        callback_list.on_train_end()

        return self.history
Beispiel #19
0
    def fit(self,
            x,
            y,
            batch_size=None,
            nsteps=None,
            epochs=1,
            verbose=1,
            callbacks=None,
            validation_data=None):
        """Train the model, using a BaseLogger to aggregate batch statistics.

        Args:
            x: input array or list of input arrays.
            y: target array or list of target arrays.
            batch_size: batch size passed to ``batchify``.  When falsy, no
                batches are built and ``train_on_batch(None, None)`` is
                called.  This is presumably for models fed by in-graph
                tensors; confirm against ``train_on_batch``.
            nsteps: steps per epoch; defaults to ``len(y[0]) // batch_size``.
            epochs: number of epochs (> 0).
            verbose: 1 shows a ``tqdm`` bar, 2 prints every 1000 steps, and
                any truthy value prints an end-of-epoch summary.
            callbacks: a single callback, a list of callbacks, or None.
            validation_data: optional ``(x_val, y_val)``.  It is passed to
                ``self.evaluate`` after each epoch, and the returned logs are
                fed to the BaseLogger.

        Returns:
            The ``History.history`` dict.

        Raises:
            AssertionError: the model is not compiled, or ``epochs <= 0``.
            ValueError: ``nsteps`` is None and ``batch_size`` is not set.
        """
        assert self.is_compiled, "Must compile model first"
        assert epochs > 0
        x = x if type(x) is list else [x]
        y = y if type(y) is list else [y]
        if nsteps is None:
            if not batch_size:
                raise ValueError(
                    '`nsteps` must be given when `batch_size` is not set.')
            # y was normalised to a list above.
            nsteps = len(y[0]) // batch_size
        # BaseLogger should always be the first callback since it computes
        # the stats on epoch end
        base_logger = BaseLogger(
            stateful_metrics=["val_%s" % m for m in self.metrics_name] +
            ['val_loss', 'size'])
        base_logger_params = {'metrics': ['loss'] + self.metrics_name}
        if validation_data:
            base_logger_params['metrics'] += [
                'val_%s' % m for m in base_logger_params['metrics']
            ]
        base_logger.set_params(base_logger_params)
        hist = History()
        if callbacks is None:
            user_callbacks = []
        elif type(callbacks) is list:
            user_callbacks = callbacks
        else:
            user_callbacks = [callbacks]
        callback_list = CallbackList(
            callbacks=[base_logger] + user_callbacks + [hist])
        callback_list.set_model(self)
        # Clear a flag left over from a previous (early-stopped) run.
        self.stop_training = False
        callback_list.on_train_begin()
        self.callbacks = callback_list
        for epoch in range(epochs):
            g = batchify(x, y, batch_size) if batch_size else None
            steps = trange(nsteps) if verbose == 1 else range(nsteps)
            callback_list.on_epoch_begin(epoch)
            for it in steps:
                x_, y_ = next(g) if g else (None, None)
                batch_logs = self.train_on_batch(x_, y_)
                callback_list.on_batch_end(it, batch_logs)
                # Running mean of the training loss within this epoch.
                curr_loss = base_logger.totals['loss'] / base_logger.seen
                if verbose == 1:
                    steps.set_postfix(loss="%.4f" % curr_loss)
                if verbose == 2:
                    if it % 1000 == 0:
                        print("%s %i/%i, loss=%.5f" %
                              (datetime.datetime.now().strftime("%H:%M:%S"),
                               it, nsteps, curr_loss),
                              flush=True)

            if validation_data:
                val_logs = self.evaluate(validation_data[0],
                                         validation_data[1])
                # Validation keys are stateful metrics, so BaseLogger keeps
                # their latest values instead of averaging them.
                base_logger.on_batch_end(None, val_logs)

            epoch_logs = {}
            callback_list.on_epoch_end(epoch=epoch, logs=epoch_logs)

            if verbose:
                if validation_data:
                    to_print = ['loss'] + self.metrics_name + ['val_loss'] + [
                        'val_%s' % m for m in self.metrics_name
                    ]
                else:
                    to_print = ['loss'] + self.metrics_name
                prog = ", ".join([
                    "%s=%.4f" % (name, hist.history[name][-1])
                    for name in to_print
                ])
                print("Epoch %i, %s" % (epoch, prog), flush=True)

            if self.stop_training:
                break

        callback_list.on_train_end()
        return hist.history
Beispiel #20
0
    do_validation = True

    callbacks._set_model(model)
    callbacks._set_params({
        'batch_size': batch_size,
        'nb_epoch': nb_epoch,
        'nb_sample': nb_train_sample,
        'verbose': 1,
        'do_validation': do_validation,
        'metrics': metrics,
    })

    ##########################
    # TRAINING
    ##########################
    callbacks.on_train_begin()

    model.stop_training = False
    for epoch in range(nb_epoch):
        callbacks.on_epoch_begin(epoch)

        if shuffle_on_epoch_start:
            X_train, y_train = util.shuffle_data(X_train, y_train)

        # train
        util.train_on_batch(model,
                            X_train,
                            y_train,
                            nb_classes,
                            callbacks=callbacks,
                            normalize=normalize_data,
Beispiel #21
0
def predict(model,
            batch_size,
            num_outputs,
            save_path,
            evaluate=False,
            liver_only=False,
            save_predictions=False,
            initial_epoch=0,
            **kwargs):
    """Run the model over every data split, optionally evaluating and saving.

    The model, callbacks and data generators are obtained from
    ``prepare_model``. Each generator is iterated volume by volume, and each
    volume is fed to the model in chunks of ``batch_size`` entries along its
    first axis.

    Args:
        model: Forwarded to ``prepare_model``, which returns the model that is
            actually used.
        batch_size (int): Number of entries along the first axis of a volume
            passed to the model per call.
        num_outputs (int): Number of model outputs treated as predictions;
            must be 1 or 2.
        save_path (str): Directory in which ``predictions.zarr`` is created
            when ``save_predictions`` is True.
        evaluate (bool): If True, also compute the loss and metrics against
            the targets and report them through the validation callbacks.
        liver_only (bool): Forwarded to ``prepare_model``; also selects which
            dice callbacks are attached.
        save_predictions (bool): If True, write predictions and segmentations
            for every volume into a zarr store. Any existing store at that
            location is deleted first.
        initial_epoch (int): Accepted for interface compatibility; not used
            here.
        **kwargs: Forwarded to ``prepare_model``.

    Raises:
        ValueError: If ``num_outputs`` is not 1 or 2.
    """
    import shutil

    # Fail fast, before any model building or data loading.
    if num_outputs not in (1, 2):
        raise ValueError("num_outputs must be 1 or 2")

    model, callbacks, gen = prepare_model(model=model,
                                          num_outputs=num_outputs,
                                          liver_only=liver_only,
                                          evaluate=evaluate,
                                          **kwargs)

    # Set up prediction file. zarr.open_group on a path string creates a
    # directory store, which os.remove cannot delete -- use rmtree for it.
    if save_predictions:
        save_path = os.path.join(save_path, "predictions.zarr")
        if os.path.isdir(save_path):
            shutil.rmtree(save_path)
        elif os.path.exists(save_path):
            os.remove(save_path)

    # Initialize callbacks
    val_callback_list = [BaseLogger()]
    if not liver_only:
        val_callback_list.extend(
            [callbacks['dice_lesion'], callbacks['dice_lesion_inliver']])
    if len(model.outputs) == 2 or liver_only:
        val_callback_list.append(callbacks['dice_liver'])
    val_callbacks = CallbackList(val_callback_list)
    val_callbacks.set_params({
        'nb_epoch': 0,
        'nb_sample': 0,
        'verbose': False,
        'do_validation': True,
        'metrics': model.metrics_names
    })
    val_callbacks.on_train_begin()
    val_callbacks.on_epoch_begin(0)

    # The learning phase must be fed explicitly only when the backend holds
    # it as a symbolic tensor rather than a fixed int.
    feed_learning_phase = (model.uses_learning_phase
                           and not isinstance(K.learning_phase(), int))

    # Create backend function. Always build a fresh ``inputs`` list so that
    # adding the learning phase never mutates ``model.inputs`` in place.
    if evaluate:
        inputs = model.inputs + model.targets + model.sample_weights
        function_outputs = (model.outputs + [model.total_loss] +
                            model.metrics_tensors)
    else:
        inputs = list(model.inputs)
        function_outputs = list(model.outputs)
    if feed_learning_phase:
        inputs = inputs + [K.learning_phase()]
    predict_function = K.function(inputs,
                                  function_outputs,
                                  updates=model.state_updates)

    # ``model.output_shape`` is a single tuple for single-output models and a
    # list of tuples otherwise; normalise to a list before indexing.
    output_shapes = model.output_shape
    if not isinstance(output_shapes, list):
        output_shapes = [output_shapes]
    num_channels = sum(output_shapes[i][1] for i in range(num_outputs))

    # Holds the most recent batch logs; initialised so the final report works
    # even when no batch was evaluated.
    val_logs = OrderedDict()

    # Predict for all data.
    print(' > Predicting...')
    for key in gen:
        print(' - DATA: {}'.format(key))

        # Duplicate inputs and outputs (and add outputs) as necessary.
        flow = repeat_flow(gen[key].flow(), num_outputs=num_outputs)

        # Set up file.
        if save_predictions:
            zgroup = zarr.open_group(store=save_path, mode='a', path="/")
            zarr_kwargs = {
                'chunks': (1, 512, 512),
                'compressor': zarr.Blosc(cname='lz4', clevel=9, shuffle=1)
            }

        # Predict and write to file.
        batch_num = 0
        for vol_num, volume in enumerate(flow):
            print("Predicting on `{}` - {}/{}"
                  "".format(key, vol_num + 1, len(gen[key])))

            # volume[0] is the input stack and volume[1] the target(s): a
            # single array or a list of per-output arrays (see below).
            vol_inputs = volume[0]
            vol_targets = volume[1]
            num_slices = len(vol_inputs)

            # Begin writing to file.
            subgroup = None
            if save_predictions:
                # presumably the last tuple element is a volume identifier
                # added by the generator -- confirm against repeat_flow.
                vol_idx = volume[-1]
                subgroup = zgroup.create_group(str(vol_idx))
                output_shape = ((num_slices, num_channels) +
                                tuple(output_shapes[0][2:]))
                subgroup.empty("volume",
                               shape=output_shape,
                               dtype=np.float32,
                               **zarr_kwargs)
                segmentation = vol_targets
                if isinstance(segmentation, list):
                    segmentation = segmentation[0]
                subgroup.create_dataset("segmentation",
                                        shape=segmentation.shape,
                                        data=segmentation,
                                        dtype=np.int16,
                                        **zarr_kwargs)

            # Iterate through volume batch-wise; the last batch may be short.
            for idx0 in range(0, num_slices, batch_size):
                idx1 = idx0 + batch_size
                batch_x = vol_inputs[idx0:idx1]

                # Prepare data for joint evaluation and prediction.
                if evaluate:
                    # Slice each target array, not the list that holds them.
                    if isinstance(vol_targets, list):
                        batch_y = [t[idx0:idx1] for t in vol_targets]
                    else:
                        batch_y = vol_targets[idx0:idx1]
                    x, y, sample_weights = model._standardize_user_data(
                        batch_x, batch_y)
                    ins = x + y + sample_weights
                else:
                    ins = _standardize_input_data(batch_x,
                                                  model._feed_input_names,
                                                  model._feed_input_shapes,
                                                  check_batch_axis=False,
                                                  exception_prefix='input')
                if feed_learning_phase:
                    ins = ins + [0.]  # 0 = test phase

                # Jointly evaluate and predict. The first ``num_outputs``
                # results are predictions; any remaining ones are the loss
                # followed by the metrics (only present when evaluating).
                outputs = predict_function(ins)
                predictions = outputs[:num_outputs]
                val_metrics = outputs[num_outputs:]

                # Write predictions (channels of all outputs concatenated).
                if save_predictions:
                    subgroup['volume'][idx0:idx1] = np.concatenate(
                        predictions, axis=1)

                # Update metrics
                if evaluate:
                    val_logs = OrderedDict(
                        zip(model.metrics_names, val_metrics))
                    val_logs.update({
                        'batch': batch_num,
                        'size': len(batch_x)
                    })
                    val_callbacks.on_batch_end(batch_num, val_logs)

                batch_num += 1

    if evaluate:
        # Update metrics
        val_callbacks.on_epoch_end(0, val_logs)

        # Output metrics
        for m in val_logs:
            if m not in ['batch', 'size']:
                print("{}: {}".format(m, val_logs[m]))