示例#1
0
def main(args):
    """Prepare the model directory and run ``fit`` from the user's model module.

    In 'background' mode, stdout/stderr are replaced by whatever
    ``setup_logging`` returns for ``<model_path>/model.log``, callbacks log
    through ``logging.info`` and verbosity is 0. Otherwise callbacks use
    ``callable_print`` and verbosity is 1. The JSON model configuration is
    extended with these run-time settings and wrapped in ``ModelConfig``.
    It is saved in 'persistent' mode and then handed to ``model.fit``.
    """
    model_path = build_model_path(args, build_model_id(args))
    setup_model_dir(args, model_path)

    if 'background' not in args.mode:
        callback_logger, verbose = callable_print, 1
    else:
        callback_logger = logging.info
        log_file = os.path.join(model_path, 'model.log')
        sys.stdout, sys.stderr = setup_logging(log_file)
        verbose = 0

    # Run-time settings injected into the JSON configuration (same keys,
    # same insertion order).
    json_cfg = load_model_json(args)
    runtime_settings = (
        ('model_path', model_path),
        ('stdout', sys.stdout),
        ('stderr', sys.stderr),
        ('logger', callback_logger),
        ('verbose', verbose),
    )
    for key, value in runtime_settings:
        json_cfg[key] = value

    config = ModelConfig(**json_cfg)

    if 'persistent' in args.mode:
        save_model_info(config, model_path)

    # The user's model module lives in args.model_dir.
    sys.path.append(args.model_dir)
    import model
    from model import fit

    model.fit(config)
示例#2
0
def main(args):
    """Train ``model.Model`` from ``args.model_dir``, cycling over training files.

    Each loop iteration trains for one pass over a single training file and
    then validates on the validation set. The first iteration uses
    ``args.train_file``; later iterations continue through
    ``args.extra_train_file`` and wrap back to ``args.train_file``.
    ``epoch`` advances after every ``len(train_files)`` iterations.

    Whenever validation accuracy improves, the weights (``<model_path>.npz``)
    and the pickled model (``<model_path>.pkl``) are written. The loop ends
    when ``keep_training`` returns a false value.

    Args:
        args: parsed command-line namespace (train/validation files, data
            and target names, sample caps, model directory, progress flag).
    """
    model_id = build_model_id(args)
    model_path = build_model_path(args, model_id)
    setup_model_dir(args, model_path)
    sys.stdout, sys.stderr = setup_logging(args)

    x_train, y_train = load_model_data(args.train_file,
                                       args.data_name,
                                       args.target_name,
                                       n=args.n_train)

    x_valid, y_valid = load_model_data(args.validation_file,
                                       args.data_name,
                                       args.target_name,
                                       n=args.n_validation)

    # args.train_file is already loaded for the first iteration, so the cycle
    # yields the extra files first and then wraps back to args.train_file.
    train_files = args.extra_train_file + [args.train_file]
    train_files_iter = itertools.cycle(train_files)

    n_classes = max(np.unique(y_train)) + 1
    json_cfg = load_model_json(args, x_train, n_classes)

    sys.path.append(args.model_dir)
    from model import Model
    model_cfg = ModelConfig(**json_cfg)
    model = Model(model_cfg)
    setattr(model, 'stop_training', False)

    best_accuracy = 0.
    best_epoch = 0

    epoch = 1
    iteration = 0

    while keep_training(epoch, best_epoch, model_cfg):
        if iteration > 0:
            # Load the next file only once another iteration is known to run.
            x_train, y_train = load_model_data(next(train_files_iter),
                                               args.data_name,
                                               args.target_name,
                                               n=args.n_train)

        train_loss = train_one_epoch(model,
                                     x_train,
                                     y_train,
                                     args,
                                     model_cfg,
                                     progress=args.progress)

        val_loss, val_accuracy = validate(model,
                                          x_valid,
                                          y_valid,
                                          args,
                                          model_cfg,
                                          progress=args.progress)

        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            best_epoch = epoch
            if model_path is not None:
                model.save_weights(model_path + '.npz')
                # Close the handle deterministically; pickles are binary data.
                with open(model_path + '.pkl', 'wb') as pkl_file:
                    cPickle.dump(model, pkl_file)

        print(
            'epoch={epoch:05d}, iteration={iteration:05d}, loss={loss:.04f}, val_loss={val_loss:.04f}, val_acc={val_acc:.04f} best=[accuracy={best_accuracy:.04f} epoch={best_epoch:05d}]'
            .format(epoch=epoch,
                    iteration=iteration,
                    loss=train_loss,
                    val_loss=val_loss,
                    val_acc=val_accuracy,
                    best_accuracy=best_accuracy,
                    best_epoch=best_epoch))

        iteration += 1
        if iteration % len(train_files) == 0:
            epoch += 1
示例#3
0
def main(args):
    """Train a Keras network built by ``build_model`` from the user's model module.

    Steps:
    1. Prepare the model directory and logging.
    2. Load the training and validation data.
    3. Determine ``n_classes`` and, optionally, automatic class weights.
    4. Subset and preprocess the data using hooks looked up on the user's
       ``model`` module.
    5. Build the network and register the validation callbacks (prediction,
       classification report, confusion matrix, early stopping,
       checkpointing, SGD schedule).
    6. Train with ``net.fit`` (single training file), or with a manual loop
       that cycles through ``args.extra_train_file`` plus
       ``args.train_file``.

    Args:
        args: parsed command-line namespace. Attributes read include
            ``train_file``, ``validation_file``, ``data_name``,
            ``target_name``, ``n_train``, ``n_validation``, ``n_classes``,
            ``model_dir``, ``extra_train_file``, ``n_epochs``, ``shuffle``,
            ``seed``, ``log``, ``verbose`` and the callback/checkpoint flags
            used below.
    """
    model_id = build_model_id(args)
    model_path = build_model_path(args, model_id)
    setup_model_dir(args, model_path)
    sys.stdout, sys.stderr = setup_logging(args, model_path)

    x_train, y_train = load_model_data(args.train_file,
            args.data_name, args.target_name)
    x_validation, y_validation = load_model_data(
            args.validation_file,
            args.data_name, args.target_name)

    rng = np.random.RandomState(args.seed)

    if args.n_classes > -1:
        n_classes = args.n_classes
    else:
        n_classes = max(y_train)+1

    n_classes, target_names, class_weight = load_target_data(args, n_classes)

    if class_weight is None and args.class_weight_auto:
        # Inverse class frequency, n_samples / (n_classes * count[c]),
        # optionally raised to a power.
        n_samples = len(y_train)
        weights = float(n_samples) / (n_classes * np.bincount(y_train))
        if args.class_weight_exponent:
            weights = weights**args.class_weight_exponent
        class_weight = dict(zip(range(n_classes), weights))

    if args.verbose:
        logging.debug("n_classes {0} min {1} max {2}".format(
            n_classes, min(y_train), max(y_train)))

    y_train_one_hot = np_utils.to_categorical(y_train, n_classes)
    y_validation_one_hot = np_utils.to_categorical(y_validation, n_classes)

    if args.verbose:
        logging.debug("y_train_one_hot " + str(y_train_one_hot.shape))
        logging.debug("x_train " + str(x_train.shape))

    min_vocab_index = np.min(x_train)
    max_vocab_index = np.max(x_train)

    if args.verbose:
        logging.debug("min vocab index {0} max vocab index {1}".format(
            min_vocab_index, max_vocab_index))

    json_cfg = load_model_json(args, x_train, n_classes)

    if args.verbose:
        logging.debug("loading model")

    sys.path.append(args.model_dir)
    import model
    from model import build_model

    #######################################################################
    # Subsetting
    #######################################################################
    # Subsetting functions and preprocessing classes are looked up on the
    # user's ``model`` module, like ``graph_marshalling_class`` below.
    if args.subsetting_function:
        subsetter = getattr(model, args.subsetting_function)
    else:
        subsetter = None

    def take_subset(subsetter, path, x, y, y_one_hot, n):
        """Return the first ``n`` rows of (x, y, y_one_hot).

        When ``subsetter`` is given, only rows where the mask
        ``subsetter(path)`` is true are considered.
        """
        if subsetter is None:
            return x[0:n], y[0:n], y_one_hot[0:n]
        idx = np.where(subsetter(path))[0][0:n]
        return x[idx], y[idx], y_one_hot[idx]

    x_train, y_train, y_train_one_hot = take_subset(
            subsetter, args.train_file,
            x_train, y_train, y_train_one_hot,
            n=args.n_train)

    x_validation, y_validation, y_validation_one_hot = take_subset(
            subsetter, args.validation_file,
            x_validation, y_validation, y_validation_one_hot,
            n=args.n_validation)

    #######################################################################
    # Preprocessing
    #######################################################################
    if args.preprocessing_class:
        preprocessor = getattr(model, args.preprocessing_class)(seed=args.seed)
    else:
        preprocessor = modeling.preprocess.NullPreprocessor()

    if args.verbose:
        logging.debug("y_train_one_hot " + str(y_train_one_hot.shape))
        logging.debug("x_train " + str(x_train.shape))

    model_cfg = ModelConfig(**json_cfg)
    if args.verbose:
        logging.info("model_cfg " + str(model_cfg))
    net = build_model(model_cfg)
    setattr(net, 'stop_training', False)

    # Graph models are fed through a marshaller built from the model module.
    marshaller = None
    if isinstance(net, keras.models.Graph):
        marshaller = getattr(model, args.graph_marshalling_class)()

    logging.info('model has {n_params} parameters'.format(
        n_params=count_parameters(net)))

    # NOTE(review): with exactly one extra training file this test is false,
    # so training uses net.fit on args.train_file only and the extra file is
    # never read. Confirm whether `> 0` was intended.
    multi_file = len(args.extra_train_file) > 1
    if multi_file:
        # The manual loop below invokes the callbacks itself.
        callbacks = keras.callbacks.CallbackList()
    else:
        callbacks = []

    save_model_info(args, model_path, model_cfg)

    callback_logger = logging.info if args.log else callable_print

    #######################################################################
    # Callbacks that need validation set predictions.
    #######################################################################

    pc = PredictionCallback(x_validation, callback_logger,
            marshaller=marshaller, batch_size=model_cfg.batch_size)
    callbacks.append(pc)

    if args.classification_report:
        cr = ClassificationReport(x_validation, y_validation,
                callback_logger,
                target_names=target_names)
        pc.add(cr)

    if args.confusion_matrix:
        cm = ConfusionMatrix(x_validation, y_validation,
                callback_logger)
        pc.add(cm)

    def get_mode(metric_name):
        """Return 'min' or 'max' for a supported validation metric name."""
        return {
                'val_loss': 'min',
                'val_acc': 'max',
                'val_f1': 'max',
                'val_f2': 'max',
                'val_f0.5': 'max'
                }[metric_name]

    # NOTE(review): if args.early_stopping is true while
    # args.early_stopping_metric is None, get_mode(None) raises KeyError.
    if args.early_stopping or args.early_stopping_metric is not None:
        es = EarlyStopping(monitor=args.early_stopping_metric,
                mode=get_mode(args.early_stopping_metric),
                patience=model_cfg.patience,
                verbose=1)
        cb = DelegatingMetricCallback(
                x_validation, y_validation, callback_logger,
                delegate=es,
                metric_name=args.early_stopping_metric,
                marshaller=marshaller)
        pc.add(cb)

    if not args.no_save:
        if args.save_all_checkpoints:
            filepath = model_path + '/model-{epoch:04d}.h5'
        else:
            filepath = model_path + '/model.h5'
        mc = ModelCheckpoint(
            filepath=filepath,
            mode=get_mode(args.checkpoint_metric),
            verbose=1,
            monitor=args.checkpoint_metric,
            save_best_only=not args.save_every_epoch)
        cb = DelegatingMetricCallback(
                x_validation, y_validation, callback_logger,
                delegate=mc,
                metric_name=args.checkpoint_metric,
                marshaller=marshaller)
        pc.add(cb)

    if model_cfg.optimizer == 'SGD':
        callbacks.append(SingleStepLearningRateSchedule(patience=10))

    if multi_file:
        # Local list: do not mutate the caller's args.extra_train_file.
        train_files = args.extra_train_file + [args.train_file]
        logging.info("Using the following files for training: " +
                ','.join(train_files))

        train_file_iter = itertools.cycle(train_files)
        current_train = args.train_file

        # The preprocessor is refit on every training file, so validation
        # data is re-transformed from these untouched arrays each iteration
        # instead of transforming the previous iteration's output again.
        x_validation_raw = x_validation
        y_validation_one_hot_raw = y_validation_one_hot

        def evaluate_validation(x_val, y_val_one_hot):
            """Return (val_loss, val_acc) of ``net`` on the given arrays."""
            kwargs = {
                    'batch_size': model_cfg.batch_size,
                    'verbose': 0 if args.log else 1
                    }
            if isinstance(net, keras.models.Graph):
                validation_data = marshaller.marshal(x_val, y_val_one_hot)
                val_loss = net.evaluate(validation_data, **kwargs)
                # Accuracy for Graph models is computed from predictions.
                y_hat = net.predict(validation_data,
                        batch_size=model_cfg.batch_size)
                val_acc = accuracy_score(
                        y_validation, np.argmax(y_hat['output'], axis=1))
                return val_loss, val_acc
            val_loss, val_acc = net.evaluate(
                    x_val, y_val_one_hot, show_accuracy=True, **kwargs)
            return val_loss, val_acc

        callbacks._set_model(net)
        callbacks.on_train_begin(logs={})

        epoch = batch = 0

        while True:
            x_train, y_train_one_hot = preprocessor.fit_transform(
                    x_train, y_train_one_hot)
            x_validation, y_validation_one_hot = preprocessor.transform(
                    x_validation_raw, y_validation_one_hot_raw)

            iteration = batch % len(train_files)

            logging.info("epoch {epoch} iteration {iteration} - training with {train_file}".format(
                    epoch=epoch, iteration=iteration, train_file=current_train))
            callbacks.on_epoch_begin(epoch, logs={})

            n_train = x_train.shape[0]

            callbacks.on_batch_begin(batch, logs={'size': n_train})

            index_array = np.arange(n_train)
            if args.shuffle:
                rng.shuffle(index_array)

            batches = keras.models.make_batches(n_train, model_cfg.batch_size)
            logging.info("epoch {epoch} iteration {iteration} - starting {n_batches} batches".format(
                    epoch=epoch, iteration=iteration, n_batches=len(batches)))

            avg_train_loss = avg_train_accuracy = 0.
            for batch_index, (batch_start, batch_end) in enumerate(batches):
                batch_ids = index_array[batch_start:batch_end]

                if isinstance(net, keras.models.Graph):
                    train_data = marshaller.marshal(
                            x_train[batch_ids], y_train_one_hot[batch_ids])
                    train_loss = net.train_on_batch(
                            train_data, class_weight=class_weight)
                    # It looks like train_on_batch returns a different
                    # type for graph than sequential models.
                    train_loss = train_loss[0]
                    train_accuracy = 0.
                else:
                    train_loss, train_accuracy = net.train_on_batch(
                            x_train[batch_ids], y_train_one_hot[batch_ids],
                            accuracy=True, class_weight=class_weight)

                # Running means over the batches of the current file.
                avg_train_loss = (avg_train_loss * batch_index + train_loss)/(batch_index + 1)
                avg_train_accuracy = (avg_train_accuracy * batch_index + train_accuracy)/(batch_index + 1)

                batch_end_logs = {'loss': train_loss, 'accuracy': train_accuracy}
                callbacks.on_batch_end(batch, logs=batch_end_logs)

            logging.info("epoch {epoch} iteration {iteration} - finished {n_batches} batches".format(
                    epoch=epoch, iteration=iteration, n_batches=len(batches)))

            logging.info("epoch {epoch} iteration {iteration} - loss: {loss} - acc: {acc}".format(
                    epoch=epoch, iteration=iteration, loss=avg_train_loss, acc=avg_train_accuracy))

            batch += 1

            # Validation frequency doesn't necessarily coincide with the end
            # of a pass over all training files. No training step runs
            # between the two checks, so one evaluation serves both.
            validate_now = (iteration + 1) % args.validation_freq == 0
            end_of_epoch = batch % len(train_files) == 0
            if validate_now or end_of_epoch:
                val_loss, val_acc = evaluate_validation(
                        x_validation, y_validation_one_hot)
                logging.info("epoch {epoch} iteration {iteration} - val_loss: {val_loss} - val_acc: {val_acc}".format(
                        epoch=epoch, iteration=iteration, val_loss=val_loss, val_acc=val_acc))
                if validate_now:
                    callbacks.on_epoch_end(epoch, {'iteration': iteration,
                            'val_loss': val_loss, 'val_acc': val_acc})
                if end_of_epoch:
                    epoch += 1
                    callbacks.on_epoch_end(epoch, {'iteration': iteration,
                            'val_loss': val_loss, 'val_acc': val_acc})

            if net.stop_training:
                logging.info("epoch {epoch} iteration {iteration} - done training".format(
                    epoch=epoch, iteration=iteration))
                break

            # Check the epoch limit before reading another file.
            if epoch > args.n_epochs:
                break

            current_train = next(train_file_iter)
            x_train, y_train = load_model_data(current_train,
                    args.data_name, args.target_name)
            y_train_one_hot = np_utils.to_categorical(y_train, n_classes)
            # Apply the same subsetting and n_train cap as the first file.
            x_train, y_train, y_train_one_hot = take_subset(
                    subsetter, current_train,
                    x_train, y_train, y_train_one_hot,
                    n=args.n_train)

        callbacks.on_train_end(logs={})
    else:
        x_train, y_train_one_hot = preprocessor.fit_transform(
                x_train, y_train_one_hot)
        x_validation, y_validation_one_hot = preprocessor.transform(
                x_validation, y_validation_one_hot)

        if isinstance(net, keras.models.Graph):
            train_data = marshaller.marshal(
                    x_train, y_train_one_hot)
            validation_data = marshaller.marshal(
                    x_validation, y_validation_one_hot)
            net.fit(train_data,
                shuffle=args.shuffle,
                nb_epoch=args.n_epochs,
                batch_size=model_cfg.batch_size,
                validation_data=validation_data,
                callbacks=callbacks,
                class_weight=class_weight,
                verbose=2 if args.log else 1)
        else:
            net.fit(x_train, y_train_one_hot,
                shuffle=args.shuffle,
                nb_epoch=args.n_epochs,
                batch_size=model_cfg.batch_size,
                show_accuracy=True,
                validation_data=(x_validation, y_validation_one_hot),
                callbacks=callbacks,
                class_weight=class_weight,
                verbose=2 if args.log else 1)
示例#4
0
def main(args):
    """Train ``model.Model`` from ``args.model_dir``, cycling over training files.

    Each loop iteration trains for one pass over a single training file and
    then validates on the validation set. The first iteration uses
    ``args.train_file``; later iterations continue through
    ``args.extra_train_file`` and wrap back to ``args.train_file``.
    ``epoch`` advances after every ``len(train_files)`` iterations.

    Whenever validation accuracy improves, the weights (``<model_path>.npz``)
    and the pickled model (``<model_path>.pkl``) are written. The loop ends
    when ``keep_training`` returns a false value.

    Args:
        args: parsed command-line namespace (train/validation files, data
            and target names, sample caps, model directory, progress flag).
    """
    model_id = build_model_id(args)
    model_path = build_model_path(args, model_id)
    setup_model_dir(args, model_path)
    sys.stdout, sys.stderr = setup_logging(args)

    x_train, y_train = load_model_data(args.train_file,
            args.data_name, args.target_name,
            n=args.n_train)

    x_valid, y_valid = load_model_data(
            args.validation_file,
            args.data_name, args.target_name,
            n=args.n_validation)

    # args.train_file is already loaded for the first iteration, so the cycle
    # yields the extra files first and then wraps back to args.train_file.
    train_files = args.extra_train_file + [args.train_file]
    train_files_iter = itertools.cycle(train_files)

    n_classes = max(np.unique(y_train)) + 1
    json_cfg = load_model_json(args, x_train, n_classes)

    sys.path.append(args.model_dir)
    from model import Model
    model_cfg = ModelConfig(**json_cfg)
    model = Model(model_cfg)
    setattr(model, 'stop_training', False)

    best_accuracy = 0.
    best_epoch = 0

    epoch = 1
    iteration = 0

    while keep_training(epoch, best_epoch, model_cfg):
        if iteration > 0:
            # Load the next file only once another iteration is known to run.
            x_train, y_train = load_model_data(
                    next(train_files_iter),
                    args.data_name, args.target_name,
                    n=args.n_train)

        train_loss = train_one_epoch(model, x_train, y_train,
                args, model_cfg, progress=args.progress)

        val_loss, val_accuracy = validate(model, x_valid, y_valid,
                args, model_cfg, progress=args.progress)

        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            best_epoch = epoch
            if model_path is not None:
                model.save_weights(model_path + '.npz')
                # Close the handle deterministically; pickles are binary data.
                with open(model_path + '.pkl', 'wb') as pkl_file:
                    cPickle.dump(model, pkl_file)

        print('epoch={epoch:05d}, iteration={iteration:05d}, loss={loss:.04f}, val_loss={val_loss:.04f}, val_acc={val_acc:.04f} best=[accuracy={best_accuracy:.04f} epoch={best_epoch:05d}]'.format(
            epoch=epoch, iteration=iteration,
            loss=train_loss, val_loss=val_loss, val_acc=val_accuracy,
            best_accuracy=best_accuracy, best_epoch=best_epoch))

        iteration += 1
        if iteration % len(train_files) == 0:
            epoch += 1
示例#5
0
def main(args):
    """Train a Keras network built by ``build_model`` from the user's model module.

    Steps:
    1. Prepare the model directory and logging.
    2. Load the training and validation data.
    3. Determine ``n_classes`` and class weights. Inverse-frequency weights
       are used when ``load_target_data`` supplies none.
    4. Subset and preprocess the data using hooks looked up on the user's
       ``model`` module.
    5. Build the network and register callbacks (checkpointing, early
       stopping, classification report, SGD schedule).
    6. Train with ``net.fit`` (single training file), or with a manual loop
       that cycles through ``args.extra_train_file`` plus
       ``args.train_file``.

    Args:
        args: parsed command-line namespace. Attributes read include
            ``train_file``, ``validation_file``, ``data_name``,
            ``target_name``, ``n_train``, ``n_validation``, ``n_classes``,
            ``model_dir``, ``extra_train_file``, ``n_epochs``, ``shuffle``,
            ``seed``, ``log`` and the checkpoint/report flags used below.
    """
    model_id = build_model_id(args)
    model_path = build_model_path(args, model_id)
    setup_model_dir(args, model_path)
    sys.stdout, sys.stderr = setup_logging(args, model_path)

    x_train, y_train = load_model_data(args.train_file,
            args.data_name, args.target_name)
    x_validation, y_validation = load_model_data(
            args.validation_file,
            args.data_name, args.target_name)

    rng = np.random.RandomState(args.seed)

    if args.n_classes > -1:
        n_classes = args.n_classes
    else:
        n_classes = max(y_train)+1

    n_classes, target_names, class_weight = load_target_data(args, n_classes)

    # `not class_weight` covers both an empty mapping and None (len(None)
    # would raise TypeError).
    if not class_weight:
        n_samples = len(y_train)
        # float() forces true division; under Python 2, int / integer numpy
        # array truncates the inverse-frequency weights.
        weights = float(n_samples) / (n_classes * np.bincount(y_train))
        print('n_samples', n_samples)
        print('classes', range(n_classes))
        print('weights', weights)
        class_weight = dict(zip(range(n_classes), weights))
    print('class_weight', class_weight)

    logging.debug("n_classes {0} min {1} max {2}".format(
        n_classes, min(y_train), max(y_train)))

    y_train_one_hot = np_utils.to_categorical(y_train, n_classes)
    y_validation_one_hot = np_utils.to_categorical(y_validation, n_classes)

    logging.debug("y_train_one_hot " + str(y_train_one_hot.shape))
    logging.debug("x_train " + str(x_train.shape))

    min_vocab_index = np.min(x_train)
    max_vocab_index = np.max(x_train)
    logging.debug("min vocab index {0} max vocab index {1}".format(
        min_vocab_index, max_vocab_index))

    json_cfg = load_model_json(args, x_train, n_classes)

    logging.debug("loading model")

    sys.path.append(args.model_dir)
    import model
    from model import build_model

    #######################################################################
    # Subsetting
    #######################################################################
    if args.subsetting_function:
        subsetter = getattr(model, args.subsetting_function)
    else:
        subsetter = None

    def take_subset(subsetter, path, x, y, y_one_hot, n):
        """Return the first ``n`` rows of (x, y, y_one_hot).

        When ``subsetter`` is given, only rows where the mask
        ``subsetter(path)`` is true are considered.
        """
        if subsetter is None:
            return x[0:n], y[0:n], y_one_hot[0:n]
        idx = np.where(subsetter(path))[0][0:n]
        return x[idx], y[idx], y_one_hot[idx]

    x_train, y_train, y_train_one_hot = take_subset(
            subsetter, args.train_file,
            x_train, y_train, y_train_one_hot,
            n=args.n_train)

    x_validation, y_validation, y_validation_one_hot = take_subset(
            subsetter, args.validation_file,
            x_validation, y_validation, y_validation_one_hot,
            n=args.n_validation)

    #######################################################################
    # Preprocessing
    #######################################################################
    if args.preprocessing_class:
        preprocessor = getattr(model, args.preprocessing_class)(seed=args.seed)
    else:
        preprocessor = modeling.preprocess.NullPreprocessor()

    logging.debug("y_train_one_hot " + str(y_train_one_hot.shape))
    logging.debug("x_train " + str(x_train.shape))

    model_cfg = ModelConfig(**json_cfg)
    logging.info("model_cfg " + str(model_cfg))
    # Bound to `net` so the `model` module name is not shadowed.
    net = build_model(model_cfg)
    setattr(net, 'stop_training', False)

    logging.info('model has {n_params} parameters'.format(
        n_params=count_parameters(net)))

    # NOTE(review): with exactly one extra training file this test is false,
    # so training uses net.fit on args.train_file only and the extra file is
    # never read. Confirm whether `> 0` was intended.
    multi_file = len(args.extra_train_file) > 1
    if multi_file:
        # The manual loop below invokes the callbacks itself.
        callbacks = keras.callbacks.CallbackList()
    else:
        callbacks = []

    save_model_info(args, model_path, model_cfg)

    if not args.no_save:
        if args.save_all_checkpoints:
            filepath = model_path + '/model-{epoch:04d}.h5'
        else:
            filepath = model_path + '/model.h5'
        callbacks.append(ModelCheckpoint(
            filepath=filepath,
            verbose=1,
            save_best_only=not args.save_every_epoch))

    callback_logger = logging.info if args.log else callable_print

    if args.n_epochs >= sys.maxsize:
        # Number of epochs overrides patience. If the number of epochs is
        # specified on the command line (n_epochs < sys.maxsize), the model
        # is trained for exactly that number. Otherwise (n_epochs left at
        # sys.maxsize) the model is trained with early stopping, using the
        # patience from the model configuration.
        callbacks.append(EarlyStopping(
            monitor='val_loss', patience=model_cfg.patience, verbose=1))

    if args.classification_report:
        cr = ClassificationReport(x_validation, y_validation,
                callback_logger,
                target_names=target_names)
        callbacks.append(cr)

    if model_cfg.optimizer == 'SGD':
        callbacks.append(SingleStepLearningRateSchedule(patience=10))

    if multi_file:
        # Local list: do not mutate the caller's args.extra_train_file.
        train_files = args.extra_train_file + [args.train_file]
        logging.info("Using the following files for training: " +
                ','.join(train_files))

        train_file_iter = itertools.cycle(train_files)
        current_train = args.train_file

        # The preprocessor is refit on every training file, so validation
        # data is re-transformed from these untouched arrays each iteration
        # instead of transforming the previous iteration's output again.
        x_validation_raw = x_validation
        y_validation_one_hot_raw = y_validation_one_hot

        def evaluate_validation(x_val, y_val_one_hot):
            """Return (val_loss, val_acc) of ``net`` on the given arrays."""
            kwargs = { 'verbose': 0 if args.log else 1 }
            if isinstance(net, keras.models.Graph):
                validation_data = {
                        'input': x_val,
                        'output': y_val_one_hot
                        }
                val_loss = net.evaluate(validation_data, **kwargs)
                # Accuracy for Graph models is computed from predictions.
                y_hat = net.predict(validation_data)
                val_acc = accuracy_score(
                        y_validation, np.argmax(y_hat['output'], axis=1))
                return val_loss, val_acc
            val_loss, val_acc = net.evaluate(
                    x_val, y_val_one_hot, show_accuracy=True, **kwargs)
            return val_loss, val_acc

        callbacks._set_model(net)
        callbacks.on_train_begin(logs={})

        epoch = batch = 0

        while True:
            x_train, y_train_one_hot = preprocessor.fit_transform(
                    x_train, y_train_one_hot)
            x_validation, y_validation_one_hot = preprocessor.transform(
                    x_validation_raw, y_validation_one_hot_raw)

            iteration = batch % len(train_files)

            logging.info("epoch {epoch} iteration {iteration} - training with {train_file}".format(
                    epoch=epoch, iteration=iteration, train_file=current_train))
            callbacks.on_epoch_begin(epoch, logs={})

            n_train = x_train.shape[0]

            callbacks.on_batch_begin(batch, logs={'size': n_train})

            index_array = np.arange(n_train)
            if args.shuffle:
                rng.shuffle(index_array)

            batches = keras.models.make_batches(n_train, model_cfg.batch_size)
            logging.info("epoch {epoch} iteration {iteration} - starting {n_batches} batches".format(
                    epoch=epoch, iteration=iteration, n_batches=len(batches)))

            avg_train_loss = avg_train_accuracy = 0.
            for batch_index, (batch_start, batch_end) in enumerate(batches):
                batch_ids = index_array[batch_start:batch_end]

                if isinstance(net, keras.models.Graph):
                    data = {
                            'input': x_train[batch_ids],
                            'output': y_train_one_hot[batch_ids]
                            }
                    train_loss = net.train_on_batch(data, class_weight=class_weight)
                    # train_on_batch may return a sequence for Graph models.
                    # The running average below needs the scalar loss.
                    if isinstance(train_loss, (list, tuple)):
                        train_loss = train_loss[0]
                    train_accuracy = 0.
                else:
                    train_loss, train_accuracy = net.train_on_batch(
                            x_train[batch_ids], y_train_one_hot[batch_ids],
                            accuracy=True, class_weight=class_weight)

                # Running means over the batches of the current file.
                avg_train_loss = (avg_train_loss * batch_index + train_loss)/(batch_index + 1)
                avg_train_accuracy = (avg_train_accuracy * batch_index + train_accuracy)/(batch_index + 1)

                batch_end_logs = {'loss': train_loss, 'accuracy': train_accuracy}
                callbacks.on_batch_end(batch, logs=batch_end_logs)

            logging.info("epoch {epoch} iteration {iteration} - finished {n_batches} batches".format(
                    epoch=epoch, iteration=iteration, n_batches=len(batches)))

            logging.info("epoch {epoch} iteration {iteration} - loss: {loss} - acc: {acc}".format(
                    epoch=epoch, iteration=iteration, loss=avg_train_loss, acc=avg_train_accuracy))

            batch += 1

            # Validation frequency doesn't necessarily coincide with the end
            # of a pass over all training files. No training step runs
            # between the two checks, so one evaluation serves both.
            validate_now = (iteration + 1) % args.validation_freq == 0
            end_of_epoch = batch % len(train_files) == 0
            if validate_now or end_of_epoch:
                val_loss, val_acc = evaluate_validation(
                        x_validation, y_validation_one_hot)
                logging.info("epoch {epoch} iteration {iteration} - val_loss: {val_loss} - val_acc: {val_acc}".format(
                        epoch=epoch, iteration=iteration, val_loss=val_loss, val_acc=val_acc))
                if validate_now:
                    callbacks.on_epoch_end(epoch, {'iteration': iteration,
                            'val_loss': val_loss, 'val_acc': val_acc})
                if end_of_epoch:
                    epoch += 1
                    callbacks.on_epoch_end(epoch, {'iteration': iteration,
                            'val_loss': val_loss, 'val_acc': val_acc})

            if net.stop_training:
                logging.info("epoch {epoch} iteration {iteration} - done training".format(
                    epoch=epoch, iteration=iteration))
                break

            # Check the epoch limit before reading another file.
            if epoch > args.n_epochs:
                break

            current_train = next(train_file_iter)
            x_train, y_train = load_model_data(current_train,
                    args.data_name, args.target_name)
            y_train_one_hot = np_utils.to_categorical(y_train, n_classes)
            # Apply the same subsetting and n_train cap as the first file.
            x_train, y_train, y_train_one_hot = take_subset(
                    subsetter, current_train,
                    x_train, y_train, y_train_one_hot,
                    n=args.n_train)

        callbacks.on_train_end(logs={})
    else:
        x_train, y_train_one_hot = preprocessor.fit_transform(
                x_train, y_train_one_hot)
        x_validation, y_validation_one_hot = preprocessor.transform(
                x_validation, y_validation_one_hot)
        if isinstance(net, keras.models.Graph):
            data = {
                    'input': x_train,
                    'output': y_train_one_hot
                    }
            validation_data = {
                    'input': x_validation,
                    'output': y_validation_one_hot
                    }
            net.fit(data,
                shuffle=args.shuffle,
                nb_epoch=args.n_epochs,
                batch_size=model_cfg.batch_size,
                validation_data=validation_data,
                callbacks=callbacks,
                class_weight=class_weight,
                verbose=2 if args.log else 1)
            y_hat = net.predict(validation_data)
            print('val_acc %.04f' %
                    accuracy_score(y_validation, np.argmax(y_hat['output'], axis=1)))
        else:
            net.fit(x_train, y_train_one_hot,
                shuffle=args.shuffle,
                nb_epoch=args.n_epochs,
                batch_size=model_cfg.batch_size,
                show_accuracy=True,
                validation_data=(x_validation, y_validation_one_hot),
                callbacks=callbacks,
                class_weight=class_weight,
                verbose=2 if args.log else 1)
示例#6
0
def main(args):
    """Train a Chainer model with per-epoch validation and patience stopping.

    Builds ``Model`` from the ``model`` module found in ``args.model_dir``,
    trains it batch by batch (on GPU ``args.gpu`` when it is >= 0),
    evaluates on the validation set after every epoch and pickles
    ``{'args': args, 'model': model}`` to ``model_path + '.store'`` whenever
    validation accuracy improves.

    Training stops after ``model_cfg.n_epochs`` epochs (when that is not
    None) or once validation accuracy has not improved for more than
    ``model_cfg.patience`` epochs.

    Args:
        args: parsed command-line namespace (reads ``gpu``, ``seed``,
            ``shuffle``, ``train_file``, ``validation_file``, ``data_name``,
            ``target_name``, ``n_train``, ``n_validation``, ``model_dir``).
    """
    if args.gpu >= 0:
        cuda.check_cuda_available()
    xp = cuda.cupy if args.gpu >= 0 else np

    model_id = build_model_id(args)
    model_path = build_model_path(args, model_id)
    setup_model_dir(args, model_path)
    sys.stdout, sys.stderr = setup_logging(args)

    x_train, y_train = load_model_data(args.train_file,
            args.data_name, args.target_name,
            n=args.n_train)
    x_validation, y_validation = load_model_data(
            args.validation_file,
            args.data_name, args.target_name,
            n=args.n_validation)

    # Seeded generator used for shuffling so runs are reproducible.
    rng = np.random.RandomState(args.seed)

    N = len(x_train)
    N_validation = len(x_validation)

    n_classes = max(np.unique(y_train)) + 1
    json_cfg = load_model_json(args, x_train, n_classes)

    print('args.model_dir', args.model_dir)
    sys.path.append(args.model_dir)
    from model import Model
    model_cfg = ModelConfig(**json_cfg)
    model = Model(model_cfg)
    setattr(model, 'stop_training', False)

    if args.gpu >= 0:
        cuda.get_device(args.gpu).use()
        model.to_gpu()

    batch_size = model_cfg.batch_size
    best_accuracy = 0.
    best_epoch = 0

    def keep_training(epoch, best_epoch):
        # n_epochs=None means "no epoch limit"; only patience applies then.
        if model_cfg.n_epochs is not None and epoch > model_cfg.n_epochs:
            return False
        if epoch > 1 and epoch - best_epoch > model_cfg.patience:
            return False
        return True

    def start_progress_bar(maxval):
        # Progress is measured in examples, so maxval is a dataset size.
        return progressbar.ProgressBar(term_width=40,
            widgets=[' ', progressbar.Percentage(),
            ' ', progressbar.ETA()],
            maxval=maxval).start()

    def save_store():
        # Move parameters to the CPU before pickling, then back again.
        if args.gpu >= 0:
            model.to_cpu()
        store = {
                'args': args,
                'model': model,
            }
        # Binary mode and a context manager so the handle is always closed.
        with open(model_path + '.store', 'wb') as f:
            cPickle.dump(store, f)
        if args.gpu >= 0:
            model.to_gpu()

    epoch = 1

    while keep_training(epoch, best_epoch):
        if args.shuffle:
            perm = rng.permutation(N)
        else:
            perm = np.arange(N)

        sum_accuracy = 0
        sum_loss = 0

        pbar = start_progress_bar(N)
        for i in six.moves.range(0, N, batch_size):
            batch_idx = perm[i:i + batch_size]
            x_batch = xp.asarray(x_train[batch_idx].flatten())
            y_batch = xp.asarray(y_train[batch_idx])
            pred, loss, acc = model.fit(x_batch, y_batch)
            sum_loss += float(loss.data) * len(y_batch)
            sum_accuracy += float(acc.data) * len(y_batch)
            # Number of examples processed so far (never exceeds maxval).
            pbar.update(min(i + batch_size, N))
        pbar.finish()
        print('train epoch={}, mean loss={}, accuracy={}'.format(
            epoch, sum_loss / N, sum_accuracy / N))

        # Validation set evaluation
        sum_accuracy = 0
        sum_loss = 0

        pbar = start_progress_bar(N_validation)
        for i in six.moves.range(0, N_validation, batch_size):
            x_batch = xp.asarray(x_validation[i:i + batch_size].flatten())
            y_batch = xp.asarray(y_validation[i:i + batch_size])
            pred, loss, acc = model.predict(x_batch, target=y_batch)
            sum_loss += float(loss.data) * len(y_batch)
            sum_accuracy += float(acc.data) * len(y_batch)
            pbar.update(min(i + batch_size, N_validation))
        pbar.finish()
        validation_accuracy = sum_accuracy / N_validation
        validation_loss = sum_loss / N_validation

        if validation_accuracy > best_accuracy:
            best_accuracy = validation_accuracy
            best_epoch = epoch
            if model_path is not None:
                save_store()

        print('validation epoch={}, mean loss={}, accuracy={} best=[accuracy={} epoch={}]'.format(
            epoch, validation_loss, validation_accuracy,
            best_accuracy,
            best_epoch))

        epoch += 1
# NOTE(review): scraped-example separator (the CJK word means "Example").
# Both lines below are bare expression statements; the CJK identifier is not
# defined anywhere in view, so executing this module would presumably fail
# here (NameError on Python 3, SyntaxError on Python 2) — confirm and remove.
示例#7
0
def main(args):
    """Train a Keras model, optionally cycling over several training files.

    Loads training/validation data, optionally subsets and preprocesses it,
    builds the network with ``build_model`` from the ``model`` module in
    ``args.model_dir`` and wires up validation-driven callbacks (prediction
    reports, early stopping, checkpointing, SGD learning-rate schedule).

    When more than one ``args.extra_train_file`` is given, a manual loop
    trains on one file per iteration, cycling through the extra files plus
    ``args.train_file``; otherwise training is delegated to ``net.fit``.

    Args:
        args: parsed command-line namespace with data, model and training
            options.
    """
    model_id = build_model_id(args)
    model_path = build_model_path(args, model_id)
    setup_model_dir(args, model_path)
    sys.stdout, sys.stderr = setup_logging(args, model_path)

    x_train, y_train = load_model_data(args.train_file, args.data_name,
                                       args.target_name)
    x_validation, y_validation = load_model_data(args.validation_file,
                                                 args.data_name,
                                                 args.target_name)

    rng = np.random.RandomState(args.seed)

    if args.n_classes > -1:
        n_classes = args.n_classes
    else:
        n_classes = max(y_train) + 1

    n_classes, target_names, class_weight = load_target_data(args, n_classes)

    if class_weight is None and args.class_weight_auto:
        # Inverse-frequency weights: n_samples / (n_classes * count[c]),
        # optionally sharpened/flattened by an exponent.
        n_samples = len(y_train)
        weights = float(n_samples) / (n_classes * np.bincount(y_train))
        if args.class_weight_exponent:
            weights = weights**args.class_weight_exponent
        class_weight = dict(zip(range(n_classes), weights))

    if args.verbose:
        logging.debug("n_classes {0} min {1} max {2}".format(
            n_classes, min(y_train), max(y_train)))

    y_train_one_hot = np_utils.to_categorical(y_train, n_classes)
    y_validation_one_hot = np_utils.to_categorical(y_validation, n_classes)

    if args.verbose:
        logging.debug("y_train_one_hot " + str(y_train_one_hot.shape))
        logging.debug("x_train " + str(x_train.shape))

    min_vocab_index = np.min(x_train)
    max_vocab_index = np.max(x_train)

    if args.verbose:
        logging.debug("min vocab index {0} max vocab index {1}".format(
            min_vocab_index, max_vocab_index))

    json_cfg = load_model_json(args, x_train, n_classes)

    if args.verbose:
        logging.debug("loading model")

    # The model-specific module is imported from args.model_dir; the
    # subsetting function, preprocessing class and graph marshaller named on
    # the command line are all looked up on it.
    sys.path.append(args.model_dir)
    import model
    from model import build_model

    #######################################################################
    # Subsetting
    #######################################################################
    if args.subsetting_function:
        subsetter = getattr(model, args.subsetting_function)
    else:
        subsetter = None

    def take_subset(subsetter, path, x, y, y_one_hot, n):
        # Without a subsetter keep the first n rows; otherwise keep the
        # first n rows selected by the boolean mask subsetter(path).
        if subsetter is None:
            return x[0:n], y[0:n], y_one_hot[0:n]
        idx = np.where(subsetter(path))[0][0:n]
        return x[idx], y[idx], y_one_hot[idx]

    x_train, y_train, y_train_one_hot = take_subset(subsetter,
                                                    args.train_file,
                                                    x_train,
                                                    y_train,
                                                    y_train_one_hot,
                                                    n=args.n_train)

    x_validation, y_validation, y_validation_one_hot = take_subset(
        subsetter,
        args.validation_file,
        x_validation,
        y_validation,
        y_validation_one_hot,
        n=args.n_validation)

    #######################################################################
    # Preprocessing
    #######################################################################
    if args.preprocessing_class:
        preprocessor = getattr(model, args.preprocessing_class)(seed=args.seed)
    else:
        preprocessor = modeling.preprocess.NullPreprocessor()

    if args.verbose:
        logging.debug("y_train_one_hot " + str(y_train_one_hot.shape))
        logging.debug("x_train " + str(x_train.shape))

    model_cfg = ModelConfig(**json_cfg)
    if args.verbose:
        logging.info("model_cfg " + str(model_cfg))
    net = build_model(model_cfg)
    setattr(net, 'stop_training', False)

    marshaller = None
    if isinstance(net, keras.models.Graph):
        marshaller = getattr(model, args.graph_marshalling_class)()

    logging.info('model has {n_params} parameters'.format(
        n_params=count_parameters(net)))

    if len(args.extra_train_file) > 1:
        callbacks = keras.callbacks.CallbackList()
    else:
        callbacks = []

    save_model_info(args, model_path, model_cfg)

    callback_logger = logging.info if args.log else callable_print

    #######################################################################
    # Callbacks that need validation set predictions.
    #######################################################################

    pc = PredictionCallback(x_validation,
                            callback_logger,
                            marshaller=marshaller,
                            batch_size=model_cfg.batch_size)
    callbacks.append(pc)

    if args.classification_report:
        cr = ClassificationReport(x_validation,
                                  y_validation,
                                  callback_logger,
                                  target_names=target_names)
        pc.add(cr)

    if args.confusion_matrix:
        cm = ConfusionMatrix(x_validation, y_validation, callback_logger)
        pc.add(cm)

    def get_mode(metric_name):
        # Whether a lower ('min') or higher ('max') metric value is better.
        return {
            'val_loss': 'min',
            'val_acc': 'max',
            'val_f1': 'max',
            'val_f2': 'max',
            'val_f0.5': 'max'
        }[metric_name]

    if args.early_stopping or args.early_stopping_metric is not None:
        es = EarlyStopping(monitor=args.early_stopping_metric,
                           mode=get_mode(args.early_stopping_metric),
                           patience=model_cfg.patience,
                           verbose=1)
        cb = DelegatingMetricCallback(x_validation,
                                      y_validation,
                                      callback_logger,
                                      delegate=es,
                                      metric_name=args.early_stopping_metric,
                                      marshaller=marshaller)
        pc.add(cb)

    if not args.no_save:
        if args.save_all_checkpoints:
            filepath = model_path + '/model-{epoch:04d}.h5'
        else:
            filepath = model_path + '/model.h5'
        mc = ModelCheckpoint(filepath=filepath,
                             mode=get_mode(args.checkpoint_metric),
                             verbose=1,
                             monitor=args.checkpoint_metric,
                             save_best_only=not args.save_every_epoch)
        cb = DelegatingMetricCallback(x_validation,
                                      y_validation,
                                      callback_logger,
                                      delegate=mc,
                                      metric_name=args.checkpoint_metric,
                                      marshaller=marshaller)
        pc.add(cb)

    if model_cfg.optimizer == 'SGD':
        callbacks.append(SingleStepLearningRateSchedule(patience=10))

    if len(args.extra_train_file) > 1:
        args.extra_train_file.append(args.train_file)
        logging.info("Using the following files for training: " +
                     ','.join(args.extra_train_file))

        train_file_iter = itertools.cycle(args.extra_train_file)
        current_train = args.train_file

        callbacks._set_model(net)
        callbacks.on_train_begin(logs={})

        # Keep the untransformed validation set. The preprocessor is re-fit
        # on every training file, and the validation arrays must be
        # transformed from the raw data each time rather than re-transforming
        # the previous iteration's output (which would compound).
        x_validation_raw = x_validation
        y_validation_one_hot_raw = y_validation_one_hot

        def validate(epoch, iteration):
            """Evaluate ``net`` on the current validation arrays and log it.

            Returns the logs dict passed to ``callbacks.on_epoch_end``.
            """
            kwargs = {
                'batch_size': model_cfg.batch_size,
                'verbose': 0 if args.log else 1
            }
            if isinstance(net, keras.models.Graph):
                validation_data = marshaller.marshal(x_validation,
                                                     y_validation_one_hot)
                # For Graph models evaluate() is used for the loss only;
                # accuracy is computed from the predictions.
                val_loss = net.evaluate(validation_data, **kwargs)
                y_hat = net.predict(validation_data,
                                    batch_size=model_cfg.batch_size)
                val_acc = accuracy_score(
                    y_validation, np.argmax(y_hat['output'], axis=1))
            else:
                val_loss, val_acc = net.evaluate(x_validation,
                                                 y_validation_one_hot,
                                                 show_accuracy=True,
                                                 **kwargs)
            logging.info(
                "epoch {epoch} iteration {iteration} - val_loss: {val_loss} - val_acc: {val_acc}"
                .format(epoch=epoch,
                        iteration=iteration,
                        val_loss=val_loss,
                        val_acc=val_acc))
            return {
                'iteration': iteration,
                'val_loss': val_loss,
                'val_acc': val_acc
            }

        epoch = batch = 0

        while True:
            x_train, y_train_one_hot = preprocessor.fit_transform(
                x_train, y_train_one_hot)
            x_validation, y_validation_one_hot = preprocessor.transform(
                x_validation_raw, y_validation_one_hot_raw)

            iteration = batch % len(args.extra_train_file)

            logging.info(
                "epoch {epoch} iteration {iteration} - training with {train_file}"
                .format(epoch=epoch,
                        iteration=iteration,
                        train_file=current_train))
            callbacks.on_epoch_begin(epoch, logs={})

            n_train = x_train.shape[0]

            callbacks.on_batch_begin(batch, logs={'size': n_train})

            index_array = np.arange(n_train)
            if args.shuffle:
                rng.shuffle(index_array)

            batches = keras.models.make_batches(n_train, model_cfg.batch_size)
            logging.info(
                "epoch {epoch} iteration {iteration} - starting {n_batches} batches"
                .format(epoch=epoch,
                        iteration=iteration,
                        n_batches=len(batches)))

            avg_train_loss = avg_train_accuracy = 0.
            for batch_index, (batch_start, batch_end) in enumerate(batches):
                batch_ids = index_array[batch_start:batch_end]

                if isinstance(net, keras.models.Graph):
                    train_data = marshaller.marshal(x_train[batch_ids],
                                                    y_train_one_hot[batch_ids])
                    train_loss = net.train_on_batch(train_data,
                                                    class_weight=class_weight)
                    # It looks like train_on_batch returns a different
                    # type for graph than sequential models.
                    train_loss = train_loss[0]
                    train_accuracy = 0.
                else:
                    train_loss, train_accuracy = net.train_on_batch(
                        x_train[batch_ids],
                        y_train_one_hot[batch_ids],
                        accuracy=True,
                        class_weight=class_weight)

                # Incremental (running) mean over the batches seen so far.
                avg_train_loss = (avg_train_loss * batch_index +
                                  train_loss) / (batch_index + 1)
                avg_train_accuracy = (avg_train_accuracy * batch_index +
                                      train_accuracy) / (batch_index + 1)

                batch_end_logs = {
                    'loss': train_loss,
                    'accuracy': train_accuracy
                }
                callbacks.on_batch_end(batch, logs=batch_end_logs)

            logging.info(
                "epoch {epoch} iteration {iteration} - finished {n_batches} batches"
                .format(epoch=epoch,
                        iteration=iteration,
                        n_batches=len(batches)))

            logging.info(
                "epoch {epoch} iteration {iteration} - loss: {loss} - acc: {acc}"
                .format(epoch=epoch,
                        iteration=iteration,
                        loss=avg_train_loss,
                        acc=avg_train_accuracy))

            batch += 1

            # Validation frequency (this if-block) doesn't necessarily
            # occur in the same iteration as beginning of an epoch
            # (next if-block), so validation may run twice here.
            if (iteration + 1) % args.validation_freq == 0:
                epoch_end_logs = validate(epoch, iteration)
                callbacks.on_epoch_end(epoch, epoch_end_logs)

            if batch % len(args.extra_train_file) == 0:
                epoch_end_logs = validate(epoch, iteration)
                epoch += 1
                callbacks.on_epoch_end(epoch, epoch_end_logs)

            if net.stop_training:
                logging.info(
                    "epoch {epoch} iteration {iteration} - done training".
                    format(epoch=epoch, iteration=iteration))
                break

            current_train = next(train_file_iter)
            x_train, y_train = load_model_data(current_train, args.data_name,
                                               args.target_name)
            y_train_one_hot = np_utils.to_categorical(y_train, n_classes)

            if epoch > args.n_epochs:
                break

        callbacks.on_train_end(logs={})
    else:
        x_train, y_train_one_hot = preprocessor.fit_transform(
            x_train, y_train_one_hot)
        x_validation, y_validation_one_hot = preprocessor.transform(
            x_validation, y_validation_one_hot)

        if isinstance(net, keras.models.Graph):
            train_data = marshaller.marshal(x_train, y_train_one_hot)
            validation_data = marshaller.marshal(x_validation,
                                                 y_validation_one_hot)
            net.fit(train_data,
                    shuffle=args.shuffle,
                    nb_epoch=args.n_epochs,
                    batch_size=model_cfg.batch_size,
                    validation_data=validation_data,
                    callbacks=callbacks,
                    class_weight=class_weight,
                    verbose=2 if args.log else 1)
        else:
            net.fit(x_train,
                    y_train_one_hot,
                    shuffle=args.shuffle,
                    nb_epoch=args.n_epochs,
                    batch_size=model_cfg.batch_size,
                    show_accuracy=True,
                    validation_data=(x_validation, y_validation_one_hot),
                    callbacks=callbacks,
                    class_weight=class_weight,
                    verbose=2 if args.log else 1)
# NOTE(review): scraped-example separator (the CJK word means "Example").
# Both lines below are bare expression statements; the CJK identifier is not
# defined anywhere in view, so executing this module would presumably fail
# here (NameError on Python 3, SyntaxError on Python 2) — confirm and remove.
示例#8
0
def main(args):
    """Train a Keras model, optionally cycling over several training files.

    Loads training/validation data, optionally subsets it with a function
    from the ``model`` module in ``args.model_dir``, builds the network with
    ``build_model`` and trains it with checkpointing and (optionally)
    early stopping on ``val_loss``.

    Early stopping (with ``model_cfg.patience``) is enabled only when
    ``args.n_epochs`` is not below ``sys.maxsize``, which is presumably the
    "not specified" default; an explicit epoch count trains for exactly that
    many epochs.

    When more than one ``args.extra_train_file`` is given, a manual loop
    trains on one file per iteration, cycling through the extra files plus
    ``args.train_file``; otherwise training is delegated to ``net.fit``.

    Args:
        args: parsed command-line namespace with data, model and training
            options.
    """
    model_id = build_model_id(args)
    model_path = build_model_path(args, model_id)
    setup_model_dir(args, model_path)
    sys.stdout, sys.stderr = setup_logging(args, model_path)

    x_train, y_train = load_model_data(args.train_file, args.data_name,
                                       args.target_name)
    x_validation, y_validation = load_model_data(args.validation_file,
                                                 args.data_name,
                                                 args.target_name)

    rng = np.random.RandomState(args.seed)

    if args.n_classes > -1:
        n_classes = args.n_classes
    else:
        n_classes = max(y_train) + 1

    n_classes, target_names, class_weight = load_target_data(args, n_classes)

    logging.debug("n_classes {0} min {1} max {2}".format(
        n_classes, min(y_train), max(y_train)))

    y_train_one_hot = np_utils.to_categorical(y_train, n_classes)
    y_validation_one_hot = np_utils.to_categorical(y_validation, n_classes)

    logging.debug("y_train_one_hot " + str(y_train_one_hot.shape))
    logging.debug("x_train " + str(x_train.shape))

    min_vocab_index = np.min(x_train)
    max_vocab_index = np.max(x_train)
    logging.debug("min vocab index {0} max vocab index {1}".format(
        min_vocab_index, max_vocab_index))

    json_cfg = load_model_json(args, x_train, n_classes)

    logging.debug("loading model")

    sys.path.append(args.model_dir)
    import model
    from model import build_model

    if args.subsetting_function:
        subsetter = getattr(model, args.subsetting_function)
    else:
        subsetter = None

    def take_subset(subsetter, path, x, y, y_one_hot, n):
        # Without a subsetter keep the first n rows; otherwise keep the
        # first n rows selected by the boolean mask subsetter(path).
        if subsetter is None:
            return x[0:n], y[0:n], y_one_hot[0:n]
        idx = np.where(subsetter(path))[0][0:n]
        return x[idx], y[idx], y_one_hot[idx]

    x_train, y_train, y_train_one_hot = take_subset(subsetter,
                                                    args.train_file,
                                                    x_train,
                                                    y_train,
                                                    y_train_one_hot,
                                                    n=args.n_train)

    x_validation, y_validation, y_validation_one_hot = take_subset(
        subsetter,
        args.validation_file,
        x_validation,
        y_validation,
        y_validation_one_hot,
        n=args.n_validation)

    logging.debug("y_train_one_hot " + str(y_train_one_hot.shape))
    logging.debug("x_train " + str(x_train.shape))

    model_cfg = ModelConfig(**json_cfg)
    logging.info("model_cfg " + str(model_cfg))
    # Named ``net`` so the ``model`` module imported above is not shadowed.
    net = build_model(model_cfg)
    setattr(net, 'stop_training', False)

    logging.info('model has {n_params} parameters'.format(
        n_params=count_parameters(net)))

    def as_graph_data(x, y_one_hot):
        # Graph models take their data as a dict keyed 'input'/'output'.
        return {'input': x, 'output': y_one_hot}

    if len(args.extra_train_file) > 1:
        callbacks = keras.callbacks.CallbackList()
    else:
        callbacks = []

    save_model_info(args, model_path, model_cfg)

    if not args.no_save:
        callbacks.append(
            ModelCheckpoint(filepath=model_path + '/model-{epoch:04d}.h5',
                            verbose=1,
                            save_best_only=True))

    callback_logger = logging.info if args.log else callable_print

    # Number of epochs overrides patience.  If the number of epochs is
    # specified on the command line (args.n_epochs < sys.maxsize), the model
    # is trained for exactly that number; otherwise, the model is trained
    # with early stopping using the patience specified in the model
    # configuration.
    if args.n_epochs >= sys.maxsize:
        callbacks.append(
            EarlyStopping(monitor='val_loss',
                          patience=model_cfg.patience,
                          verbose=1))

    if args.classification_report:
        cr = ClassificationReport(x_validation,
                                  y_validation,
                                  callback_logger,
                                  target_names=target_names)
        callbacks.append(cr)

    if model_cfg.optimizer == 'SGD':
        callbacks.append(SingleStepLearningRateSchedule(patience=10))

    if len(args.extra_train_file) > 1:
        args.extra_train_file.append(args.train_file)
        logging.info("Using the following files for training: " +
                     ','.join(args.extra_train_file))

        train_file_iter = itertools.cycle(args.extra_train_file)
        current_train = args.train_file

        callbacks._set_model(net)
        callbacks.on_train_begin(logs={})

        def validate(epoch, iteration):
            """Evaluate ``net`` on the validation set and log the result.

            Returns the logs dict passed to ``callbacks.on_epoch_end``.
            """
            verbose = 0 if args.log else 1
            if isinstance(net, keras.models.Graph):
                validation_data = as_graph_data(x_validation,
                                                y_validation_one_hot)
                # For Graph models evaluate() is used for the loss only;
                # accuracy is computed from the predictions.
                val_loss = net.evaluate(validation_data,
                                        batch_size=model_cfg.batch_size,
                                        verbose=verbose)
                y_hat = net.predict(validation_data,
                                    batch_size=model_cfg.batch_size)
                val_acc = accuracy_score(
                    y_validation, np.argmax(y_hat['output'], axis=1))
            else:
                val_loss, val_acc = net.evaluate(x_validation,
                                                 y_validation_one_hot,
                                                 show_accuracy=True,
                                                 verbose=verbose)
            logging.info(
                "epoch {epoch} iteration {iteration} - val_loss: {val_loss} - val_acc: {val_acc}"
                .format(epoch=epoch,
                        iteration=iteration,
                        val_loss=val_loss,
                        val_acc=val_acc))
            return {
                'iteration': iteration,
                'val_loss': val_loss,
                'val_acc': val_acc
            }

        epoch = batch = 0

        while True:
            iteration = batch % len(args.extra_train_file)

            logging.info(
                "epoch {epoch} iteration {iteration} - training with {train_file}"
                .format(epoch=epoch,
                        iteration=iteration,
                        train_file=current_train))
            callbacks.on_epoch_begin(epoch, logs={})

            n_train = x_train.shape[0]

            callbacks.on_batch_begin(batch, logs={'size': n_train})

            index_array = np.arange(n_train)
            if args.shuffle:
                rng.shuffle(index_array)

            batches = keras.models.make_batches(n_train, model_cfg.batch_size)
            logging.info(
                "epoch {epoch} iteration {iteration} - starting {n_batches} batches"
                .format(epoch=epoch,
                        iteration=iteration,
                        n_batches=len(batches)))

            avg_train_loss = avg_train_accuracy = 0.
            for batch_index, (batch_start, batch_end) in enumerate(batches):
                batch_ids = index_array[batch_start:batch_end]

                if isinstance(net, keras.models.Graph):
                    train_loss = net.train_on_batch(
                        as_graph_data(x_train[batch_ids],
                                      y_train_one_hot[batch_ids]),
                        class_weight=class_weight)
                    # Graph.train_on_batch may return a one-element list
                    # rather than a scalar; unwrap it so the running means
                    # below stay numeric.
                    if isinstance(train_loss, (list, tuple)):
                        train_loss = train_loss[0]
                    train_accuracy = 0.
                else:
                    train_loss, train_accuracy = net.train_on_batch(
                        x_train[batch_ids],
                        y_train_one_hot[batch_ids],
                        accuracy=True,
                        class_weight=class_weight)

                # Incremental (running) mean over the batches seen so far.
                avg_train_loss = (avg_train_loss * batch_index +
                                  train_loss) / (batch_index + 1)
                avg_train_accuracy = (avg_train_accuracy * batch_index +
                                      train_accuracy) / (batch_index + 1)

                batch_end_logs = {
                    'loss': train_loss,
                    'accuracy': train_accuracy
                }
                callbacks.on_batch_end(batch, logs=batch_end_logs)

            logging.info(
                "epoch {epoch} iteration {iteration} - finished {n_batches} batches"
                .format(epoch=epoch,
                        iteration=iteration,
                        n_batches=len(batches)))

            logging.info(
                "epoch {epoch} iteration {iteration} - loss: {loss} - acc: {acc}"
                .format(epoch=epoch,
                        iteration=iteration,
                        loss=avg_train_loss,
                        acc=avg_train_accuracy))

            batch += 1

            # Validation frequency (this if-block) doesn't necessarily
            # occur in the same iteration as beginning of an epoch
            # (next if-block), so validation may run twice here.
            if (iteration + 1) % args.validation_freq == 0:
                epoch_end_logs = validate(epoch, iteration)
                callbacks.on_epoch_end(epoch, epoch_end_logs)

            if batch % len(args.extra_train_file) == 0:
                epoch_end_logs = validate(epoch, iteration)
                epoch += 1
                callbacks.on_epoch_end(epoch, epoch_end_logs)

            if net.stop_training:
                logging.info(
                    "epoch {epoch} iteration {iteration} - done training".
                    format(epoch=epoch, iteration=iteration))
                break

            current_train = next(train_file_iter)
            x_train, y_train = load_model_data(current_train, args.data_name,
                                               args.target_name)
            y_train_one_hot = np_utils.to_categorical(y_train, n_classes)

            if epoch > args.n_epochs:
                break

        callbacks.on_train_end(logs={})
    else:
        print('args.n_epochs', args.n_epochs)
        if isinstance(net, keras.models.Graph):
            data = as_graph_data(x_train, y_train_one_hot)
            validation_data = as_graph_data(x_validation,
                                            y_validation_one_hot)
            net.fit(
                data,
                shuffle=args.shuffle,
                nb_epoch=args.n_epochs,
                batch_size=model_cfg.batch_size,
                validation_data=validation_data,
                callbacks=callbacks,
                class_weight=class_weight,
                verbose=2 if args.log else 1)
            # Report accuracy on the validation set, not the training data.
            y_hat = net.predict(validation_data,
                                batch_size=model_cfg.batch_size)
            print('val_acc %.04f' % accuracy_score(
                y_validation, np.argmax(y_hat['output'], axis=1)))
        else:
            net.fit(x_train,
                    y_train_one_hot,
                    shuffle=args.shuffle,
                    nb_epoch=args.n_epochs,
                    batch_size=model_cfg.batch_size,
                    show_accuracy=True,
                    validation_data=(x_validation, y_validation_one_hot),
                    callbacks=callbacks,
                    class_weight=class_weight,
                    verbose=2 if args.log else 1)