def my_layer():
    """test one specify layer"""
    a = Input(shape=(3, 3, 2))
    b = WeightedAdd()(a)
    model = Model(inputs=a, outputs=b)
    data = np.ones((1, 3, 3, 2))
    print(model.predict_on_batch(data))
    model.compile(optimizer='Adam', loss='mean_squared_error')
    model.fit(data, data, epochs=1000)
    print(model.predict_on_batch(data))
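
`WeightedAdd` is project-specific and not shown here. For reference only, a minimal sketch of such a custom layer, one trainable scalar scaling its input, could look like this (an assumption, not the project's actual implementation):

from keras.engine.topology import Layer

class WeightedAdd(Layer):
    def build(self, input_shape):
        # One trainable scalar, initialized to 1 so the layer starts as identity.
        self.w = self.add_weight(name='w', shape=(1,),
                                 initializer='ones', trainable=True)
        super(WeightedAdd, self).build(input_shape)

    def call(self, x, mask=None):
        return self.w * x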
Example #2
def test_deeper_conv_block():
    model = get_conv_model()
    layers = deeper_conv_block(model.layers[1], 3)
    output_tensor = layers[0](model.outputs[0])
    output_tensor = layers[1](output_tensor)
    output_tensor = layers[2](output_tensor)
    new_model = Model(inputs=model.inputs, outputs=output_tensor)
    input_data = get_conv_data()
    output1 = model.predict_on_batch(input_data).flatten()
    output2 = new_model.predict_on_batch(input_data).flatten()
    assert np.sum(np.abs(output1 - output2)) < 1e-1
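
`deeper_conv_block` is assumed to implement a Net2Deeper-style transform: the inserted convolution starts as an identity map, which is why the assertion above expects (nearly) unchanged outputs. A sketch of such an identity kernel, under that assumption:

import numpy as np

def identity_conv_kernel(kernel_size, n_channels):
    # Only the center tap is nonzero, and it copies each channel through,
    # so convolving with this kernel approximates the identity function.
    w = np.zeros((kernel_size, kernel_size, n_channels, n_channels))
    center = kernel_size // 2
    for i in range(n_channels):
        w[center, center, i, i] = 1.0
    return w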
Example #3
def test_wider_conv():
    model = get_conv_model()

    layer1 = wider_pre_conv(model.layers[1], 3)
    layer2 = wider_bn(model.layers[2], 3, 3, 3)
    layer3 = wider_next_conv(model.layers[4], 3, 3, 3)

    input_tensor = Input(shape=(5, 5, 3))
    output_tensor = layer1(input_tensor)
    output_tensor = layer2(output_tensor)
    output_tensor = Activation('relu')(output_tensor)
    output_tensor = layer3(output_tensor)
    output_tensor = BatchNormalization()(output_tensor)
    output_tensor = Activation('relu')(output_tensor)
    model2 = Model(inputs=input_tensor, outputs=output_tensor)

    random_input = get_conv_data()
    output1 = model.predict_on_batch(random_input)
    output2 = model2.predict_on_batch(random_input)
    assert np.sum(np.abs(output1.flatten() - output2.flatten())) < 1e-1
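
The `wider_*` helpers are assumed to perform a Net2Wider-style, function-preserving widening: new filters are copies of existing ones, and the next layer's incoming weights are rescaled so each group of copies contributes exactly what the original filter did. A sketch of that weight transform, with hypothetical names and TensorFlow-ordered kernels (kh, kw, in, out):

import numpy as np

def widen_conv_pair(w_pre, b_pre, w_next, n_add):
    # Duplicate n_add randomly chosen output filters of the current layer.
    idx = np.random.randint(w_pre.shape[-1], size=n_add)
    w_pre_new = np.concatenate([w_pre, w_pre[..., idx]], axis=-1)
    b_pre_new = np.concatenate([b_pre, b_pre[idx]])
    # Rescale the next layer's input channels: a filter copied k times in
    # total has each copy weighted by 1/k, preserving the summed activation.
    counts = np.bincount(idx, minlength=w_pre.shape[-1]) + 1
    w_next_scaled = w_next / counts[None, None, :, None]
    w_next_new = np.concatenate(
        [w_next_scaled, w_next_scaled[:, :, idx, :]], axis=2)
    return w_pre_new, b_pre_new, w_next_new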
Example #4
def run(dataset_seed, image_shape, batch_size, device, data_dir, output_dir,
        phases, architecture,
        o_meta, limb_weights, joint_weights, weights, pooling,
        dense_layers, use_gram_matrix, last_base_layer, override,
        embedded_files_max_size, selected_layers):
    os.makedirs(output_dir, exist_ok=True)

    with tf.device(device):
        print('building model...')
        model = build_siamese_model(image_shape, architecture, 0.0, weights,
                                    last_base_layer=last_base_layer,
                                    use_gram_matrix=use_gram_matrix,
                                    dense_layers=dense_layers, pooling=pooling,
                                    include_base_top=False, include_top=True,
                                    trainable_limbs=False,
                                    limb_weights=limb_weights,
                                    predictions_activation=[o['a'] for o in o_meta],
                                    predictions_name=[o['n'] for o in o_meta],
                                    classes=[o['u'] for o in o_meta],
                                    embedding_units=[o['e'] for o in o_meta],
                                    joints=[o['j'] for o in o_meta])
        # Restore best parameters.
        print('loading weights from:', joint_weights)
        model.load_weights(joint_weights)
        model = model.get_layer('model_2')

        available_layers = [l.name for l in model.layers]
        if set(selected_layers) - set(available_layers):
            print('available layers:', available_layers)
            raise ValueError('selection contains unknown layers: %s' % selected_layers)

        style_features = [model.get_layer(l).output for l in selected_layers]

        if use_gram_matrix:
            gram_layer = layers.Lambda(gram_matrix, arguments=dict(norm_by_channels=False))
            style_features = [gram_layer(f) for f in style_features]

        model = Model(inputs=model.inputs, outputs=style_features)

    g = ImageDataGenerator(preprocessing_function=get_preprocess_fn(architecture))

    for phase in phases:
        phase_data_dir = os.path.join(data_dir, phase)
        output_file_name = os.path.join(output_dir, phase + '.%i.pickle')
        already_embedded = os.path.exists(output_file_name % 0)
        phase_exists = os.path.exists(phase_data_dir)

        if (already_embedded and not override) or not phase_exists:
            print('%s transformation skipped' % phase)
            continue

        # Shuffle must always be off in order to keep names consistent.
        data = g.flow_from_directory(phase_data_dir,
                                     target_size=image_shape[:2],
                                     class_mode='sparse',
                                     batch_size=batch_size, shuffle=False,
                                     seed=dataset_seed)
        print('transforming %i %s samples from %s' % (data.n, phase, phase_data_dir))
        part_id = 0
        samples_seen = 0
        displayed_once = False

        while samples_seen < data.n:
            z, y = {n: [] for n in selected_layers}, []
            chunk_size = 0
            chunk_start = samples_seen

            while chunk_size < embedded_files_max_size and samples_seen < data.n:
                _x, _y = next(data)

                outputs = model.predict_on_batch(_x)
                chunk_size += sum(o.nbytes for o in outputs)

                for l, o in zip(selected_layers, outputs):
                    z[l].append(o)

                y.append(_y)
                samples_seen += _x.shape[0]
                chunk_p = int(100 * (samples_seen / data.n))

                if chunk_p % 10 == 0:
                    if not displayed_once:
                        print('\n%i%% (%.2f MB)'
                              % (chunk_p, chunk_size / 1024 ** 2),
                              flush=True, end='')
                        displayed_once = True
                else:
                    displayed_once = False
                    print('.', end='')

            for layer in selected_layers:
                z[layer] = np.concatenate(z[layer])

            with open(output_file_name % part_id, 'wb') as f:
                pickle.dump({'data': z,
                             'target': np.concatenate(y),
                             'names': np.asarray(data.filenames[chunk_start: samples_seen])},
                            f, pickle.HIGHEST_PROTOCOL)
            part_id += 1
    print('done.')
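
For symmetry with the writer above, a small hypothetical reader for the `<phase>.<part>.pickle` files (the key names come directly from the dump call):

import pickle

def load_part(output_file_name, part_id):
    # Returns the dict written above: {'data': ..., 'target': ..., 'names': ...}
    with open(output_file_name % part_id, 'rb') as f:
        part = pickle.load(f)
    return part['data'], part['target'], part['names']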
Example #5
class FergusRModel(object):
    def __init__(self, igor):

        now = datetime.now()
        self.run_name = "fergusr_{}mo_{}day_{}hr_{}min".format(
            now.month, now.day, now.hour, now.minute)
        log_location = join(igor.log_dir, self.run_name + ".log")
        self.logger = igor.logger = make_logger(igor, log_location)
        igor.verify_directories()
        self.igor = igor
        self.model = None  # built later by make()

    @classmethod
    def from_yaml(cls, yamlfile, kwargs=None):
        igor = Igor.from_file(yamlfile)
        model = cls(igor)

        igor.prep()
        model.make(kwargs)
        return model

    @classmethod
    def from_config(cls, config, kwargs=None):
        igor = Igor(config)
        model = cls(igor)

        igor.prep()
        model.make(kwargs)
        return model

    def load_checkpoint_weights(self):
        weight_file = join(self.igor.model_location, self.igor.saving_prefix,
                           self.igor.checkpoint_weights)
        if exists(weight_file):
            self.logger.info("+ Loading checkpoint weights")
            self.model.load_weights(weight_file, by_name=True)
        else:
            self.logger.warning(
                "- Checkpoint weights do not exist; {}".format(weight_file))

    def plot(self, filename=None):
        # Optional filename so callers (e.g. profile) can pick a custom path.
        if filename is None:
            filename = join(self.igor.model_location, self.igor.saving_prefix,
                            'model_visualization.png')
        kplot(self.model, to_file=filename)
        self.logger.debug("+ Model visualized at {}".format(filename))

    def make(self, theano_kwargs=None):
        """Construct the Fergus-Recurrent model
        
        Model: 
            Input at time t: 
                - Soft attention over embedded lexemes of children of node_t
                - Embedded lexeme of node_t
            Compute:
                - Inputs are fed into a recurrent tree s.t. hidden states travel down branches
                - node_t's supertag embeddings are retrieved
                - output of recurrent tree at time t is aligned with each supertag vector
                - a vectorized probability function computes a distribution
            Output:
                - Distribution over supertags for node_t
        """
        if self.igor.embedding_type == "convolutional":
            make_convolutional_embedding(self.igor)
        elif self.igor.embedding_type == "token":
            make_token_embedding(self.igor)
        elif self.igor.embedding_type == "shallowconv":
            make_shallow_convolutional_embedding(self.igor)
        elif self.igor.embedding_type == "minimaltoken":
            make_minimal_token_embedding(self.igor)
        else:
            raise Exception("Incorrect embedding type")

        spine_input_shape = (self.igor.batch_size, self.igor.max_sequence,
                             self.igor.max_num_supertags)

        node_input_shape = (self.igor.batch_size, self.igor.max_sequence)

        dctx_input_shape = (self.igor.batch_size, self.igor.max_sequence,
                            self.igor.max_daughter_size)

        E, V = self.igor.word_embedding_size, self.igor.word_vocab_size  # for word embeddings
        repeat_N = self.igor.max_num_supertags  # for lex
        repeat_D = self.igor.max_daughter_size
        mlp_size = self.igor.mlp_size

        ## dropout parameters
        p_emb = self.igor.p_emb_dropout
        p_W = self.igor.p_W_dropout
        p_U = self.igor.p_U_dropout
        w_decay = self.igor.weight_decay
        p_mlp = self.igor.p_mlp_dropout

        #### make layer inputs
        spineset_in = Input(batch_shape=spine_input_shape,
                            name='parent_spineset_in',
                            dtype='int32')
        phead_in = Input(batch_shape=node_input_shape,
                         name='parent_head_input',
                         dtype='int32')
        dctx_in = Input(batch_shape=dctx_input_shape,
                        name='daughter_context_input',
                        dtype='int32')
        topology_in = Input(batch_shape=node_input_shape,
                            name='node_topology',
                            dtype='int32')

        ##### params
        def predict_params():
            return {
                'output_dim': 1,
                'W_regularizer': l2(w_decay),
                'activation': 'relu',
                'b_regularizer': l2(w_decay)
            }

        ### Layer functions
        ############# Convert the word indices to vectors
        F_embedword = Embedding(input_dim=V,
                                output_dim=E,
                                mask_zero=True,
                                W_regularizer=l2(w_decay),
                                dropout=p_emb,
                                name='embedword')
        if self.igor.saved_embeddings is not None:
            print("Loading saved embeddings....")
            F_embedword.initial_weights = [self.igor.saved_embeddings]

        F_probability = ProbabilityTensor(
            name='predictions', dense_function=Dense(**predict_params()))
        ### composition functions

        F_softdaughters = compose(
            LambdaMask(lambda x, mask: None, name='remove_attention_mask'),
            Distribute(SoftAttention(name='softdaughter'),
                       name='distribute_softdaughter'), F_embedword)

        F_align = compose(Distribute(Dropout(p_mlp)),
                          Distribute(Dense(mlp_size, activation='relu')),
                          concat)

        F_rtn = compose(
            RepeatVector(repeat_N, axis=2, name='repeattree'),
            BranchLSTM(self.igor.rtn_size,
                       name='recurrent_tree1',
                       return_sequences=True))

        F_predict = compose(
            Distribute(F_probability, name='distribute_probability'),
            Distribute(
                Dropout(p_mlp)
            ),  ### need a separate one because the 'concat' is different for the two situations
            LastDimDistribute(Dense(mlp_size, activation='relu')),
            concat)

        ############################ new ###########################

        dctx = F_softdaughters(dctx_in)
        parent = F_embedword(phead_in)
        #node_context = F_align([parent, dctx])
        #import pdb
        #pdb.set_trace()

        ### put into tree
        aligned_node = F_align([parent, dctx])
        node_context = F_rtn([aligned_node, topology_in])

        parent_spines = self.igor.F_embedspine(spineset_in)
        ### get probability
        predictions = F_predict([node_context, parent_spines])

        ##################
        ### make model
        ##################
        self.model = Model(input=[dctx_in, phead_in, topology_in, spineset_in],
                           output=predictions,
                           preloaded_data=self.igor.preloaded_data)

        ##################
        ### compile model
        ##################
        optimizer = Adam(self.igor.LR,
                         clipnorm=self.igor.max_grad_norm,
                         clipvalue=self.igor.grad_clip_threshold)
        theano_kwargs = theano_kwargs or {}
        self.model.compile(loss='categorical_crossentropy',
                           optimizer=optimizer,
                           metrics=['accuracy'],
                           **theano_kwargs)

        if self.igor.from_checkpoint:
            self.load_checkpoint_weights()
        elif not self.igor.in_training:
            raise Exception("No point in running this without trained weights")

    def train(self):
        train_data = self.igor.train_gen(forever=True)
        dev_data = self.igor.dev_gen(forever=True)
        N = self.igor.num_train_samples
        E = self.igor.num_epochs
        # generator, samples per epoch, number of epochs
        callbacks = [ProgbarV2(3, 10)]

        checkpoint_fp = join(self.igor.model_location, self.igor.saving_prefix,
                             self.igor.checkpoint_weights)
        self.logger.info("+ Model Checkpoint: {}".format(checkpoint_fp))

        callbacks += [
            ModelCheckpoint(filepath=checkpoint_fp,
                            verbose=1,
                            save_best_only=True)
        ]
        callbacks += [
            LearningRateScheduler(lambda epoch: self.igor.LR * 0.9**(epoch))
        ]

        csv_location = join(self.igor.log_dir, self.run_name + ".csv")
        callbacks += [CSVLogger(csv_location)]
        self.model.fit_generator(generator=train_data,
                                 samples_per_epoch=N,
                                 nb_epoch=E,
                                 callbacks=callbacks,
                                 verbose=1,
                                 validation_data=dev_data,
                                 nb_val_samples=self.igor.num_dev_samples)

    def debug(self):
        dev_data = self.igor.dev_gen(forever=False)
        X, Y = next(dev_data)
        self.model.predict_on_batch(X)
        #self.model.evaluate_generator(dev_data, self.igor.num_dev_samples)

    def profile(self, num_iterations=1):
        train_data = self.igor.train_gen(forever=True)
        dev_data = self.igor.dev_gen(forever=True)
        # generator, samples per epoch, number of epochs
        callbacks = [ProgbarV2(1, 10)]
        self.logger.debug("+ Beginning the generator")
        self.model.fit_generator(generator=train_data,
                                 samples_per_epoch=self.igor.batch_size * 10,
                                 nb_epoch=num_iterations,
                                 callbacks=callbacks,
                                 verbose=1,
                                 validation_data=dev_data,
                                 nb_val_samples=self.igor.batch_size)
        self.logger.debug(
            "+ Calling theano's pydot print.. this might take a while")
        theano.printing.pydotprint(self.model.train_function.function,
                                   outfile='theano_graph.png',
                                   var_with_name_simple=True,
                                   with_ids=True)
        self.logger.debug("+ Calling keras' print.. this might take a while")
        self.plot("keras_graph.png")
        #self.model.profile.print_summary()

    def __call__(self, data):
        if self.model is None:
            raise Exception("model not instantiated yet; please call make()")
        assert isinstance(data, list)
        B = data[0].shape[0]
        return self.model.predict(data, batch_size=B)
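
A hypothetical driver for this class (the YAML path and its schema are assumptions; the `Igor` object carries all configuration):

if __name__ == '__main__':
    model = FergusRModel.from_yaml('configs/fergus_recurrent.yaml')
    model.train()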
Example #6
class Transformer(Model):
    # Subclasses `Model` so the `inputs`/`outputs` branch of `__init__` below
    # can defer to the functional-API constructor.
    def __init__(self,
                 src_dict,
                 tar_dict=None,
                 length_limit=70,
                 num_layers=6,
                 model_dim=512,
                 num_head=8,
                 head_dim=None,
                 inner_dim=2048,
                 dropout=0.1,
                 use_pos_embedding=True,
                 share_embedding=False,
                 inputs=None,
                 outputs=None,
                 name="Transformer"):
        if inputs is not None and outputs is not None:
            super(Transformer, self).__init__(inputs=inputs,
                                              outputs=outputs,
                                              name=name)
            return

        self.src_dict = src_dict
        self.tar_dict = tar_dict if tar_dict is not None else self.src_dict
        self.src_token_dict = {v: k for k, v in self.src_dict.items()}
        self.tar_token_dict = {v: k for k, v in self.tar_dict.items()}
        self.src_dict_size = len(self.src_dict)
        self.tar_dict_size = (len(self.tar_dict)
                              if tar_dict is not None else self.src_dict_size)

        self.length_limit = length_limit

        self.num_layers = num_layers
        self.num_head = num_head
        self.model_dim = model_dim
        self.head_dim = head_dim if head_dim is not None else int(model_dim /
                                                                  num_head)
        self.inner_dim = inner_dim

        self.dropout = dropout
        self.use_pos_embedding = use_pos_embedding
        self.share_embedding = share_embedding

        self.source_embedding = None
        self.target_embedding = None
        self.position_embedding = None
        self.encoder = None
        self.decoder = None
        self.softmax = None
        self.model = None
        self.output_model = None

        self.decode_build = False
        self.encoder_model = None
        self.decoder_model = None

    def compile(self, optimizer="adam"):
        source_input = Input(shape=(None, ), dtype="int32")
        target_input = Input(shape=(None, ), dtype="int32")

        target_decode_in = Lambda(lambda x: K.slice(
            x,
            start=[0, 0],
            size=[K.shape(target_input)[0],
                  K.shape(target_input)[1] - 1]))(target_input)
        target_decode_out = Lambda(lambda x: K.slice(
            x,
            start=[0, 1],
            size=[K.shape(target_input)[0],
                  K.shape(target_input)[1] - 1]))(target_input)

        src_mask = Lambda(lambda x: get_mask_seq2seq(x, x))(source_input)
        tar_mask = Lambda(lambda x: self.get_self_mask(x))(target_decode_in)
        encode_mask = Lambda(lambda x: get_mask_seq2seq(x[0], x[1]))(
            [target_decode_in, source_input])

        self.source_embedding = Embedding(input_dim=self.src_dict_size,
                                          output_dim=self.model_dim)
        if self.share_embedding:
            self.target_embedding = self.source_embedding
        else:
            self.target_embedding = Embedding(input_dim=self.tar_dict_size,
                                              output_dim=self.model_dim)

        if self.use_pos_embedding:
            self.position_embedding = PositionEmbedding(mode="sum")

        src_x = self.source_embedding(source_input)
        if self.use_pos_embedding:
            src_x = self.position_embedding(src_x)

        src_x = Dropout(self.dropout)(src_x)

        self.encoder = Encode(num_layers=self.num_layers,
                              num_head=self.num_head,
                              head_dim=self.head_dim,
                              model_dim=self.model_dim,
                              inner_dim=self.inner_dim,
                              dropout=self.dropout)
        encoder_output = self.encoder(src_x, masks=src_mask)

        tar_x = self.target_embedding(target_decode_in)
        if self.use_pos_embedding:
            tar_x = self.position_embedding(tar_x)

        self.decoder = Decode(num_layers=self.num_layers,
                              num_head=self.num_head,
                              head_dim=self.head_dim,
                              model_dim=self.model_dim,
                              inner_dim=self.inner_dim,
                              dropout=self.dropout)
        decoder_output = self.decoder([tar_x, encoder_output],
                                      self_mask=tar_mask,
                                      encode_mask=encode_mask)

        self.softmax = TimeDistributed(Dense(self.tar_dict_size))

        output = self.softmax(decoder_output)

        loss = Lambda(lambda x: self._get_loss(*x))(
            [output, target_decode_out])

        self.model = Model([source_input, target_input], loss)
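        # The loss tensor is attached via `add_loss`, so `compile` receives
        # loss=None and `fit` needs no external target arrays.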
        self.model.add_loss([loss])
        self.model.compile(optimizer, None)

        self.model.metrics_names.append("ppl")
        self.model.metrics_tensors.append(Lambda(K.exp)(loss))
        self.model.metrics_names.append("accuracy")
        self.model.metrics_tensors.append(
            Lambda(lambda x: self._get_acc(x[0], x[1]))(
                [output, target_decode_out]))

        self.output_model = Model([source_input, target_input], output)

    @staticmethod
    def get_encode_mask(src_seq):
        return get_mask_seq2seq(src_seq, src_seq)

    @staticmethod
    def get_self_mask(tar_seq):
        self_mask1 = get_mask_seq2seq(tar_seq, tar_seq)
        self_mask2 = get_mask_self(tar_seq)
        return K.minimum(self_mask1, self_mask2)

    @staticmethod
    def _get_loss(y_pred, y_true):
        y_true = tf.cast(y_true, dtype="int32")
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_true,
                                                              logits=y_pred)
        mask = tf.cast(tf.not_equal(y_true, 0), dtype="float32")
        loss = tf.reduce_sum(loss * mask, -1) / tf.reduce_sum(mask, -1)
        return tf.reduce_mean(loss)

    @staticmethod
    def _get_acc(y_pred, y_true):
        mask = tf.cast(tf.not_equal(y_true, 0), dtype="float32")
        corr = K.cast(K.equal(K.cast(y_true, dtype="int32"),
                              K.cast(K.argmax(y_pred, -1), dtype="int32")),
                      dtype="float32")
        acc = K.sum(corr * mask, -1) / K.sum(mask, -1)
        return K.mean(acc)

    def decode_fast(self, seq):
        decode_tokens = []
        target_seq = np.zeros(shape=(1, self.length_limit), dtype=np.int32)
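        # Token ids assumed by this loop: 0 = padding, 2 = start-of-sequence,
        # 3 = end-of-sequence (the break condition below).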
        target_seq[0, 0] = 2
        for i in range(self.length_limit - 1):
            output = self.output_model.predict_on_batch([seq, target_seq])
            max_prob_index = np.argmax(output[0, i, :])
            max_prob_token = self.tar_token_dict[max_prob_index]
            decode_tokens.append(max_prob_token)
            if max_prob_index == 3:
                break
            target_seq[0, i + 1] = max_prob_index
        return " ".join(decode_tokens)

    def _build_encoder(self):
        source_input = Input(shape=(None, ), dtype="int32")

        src_mask = Lambda(lambda x: get_mask_seq2seq(x, x))(source_input)

        src_x = self.source_embedding(source_input)
        if self.use_pos_embedding:
            src_x = self.position_embedding(src_x)

        encoder_output = self.encoder(src_x, masks=src_mask)
        self.encoder_model = Model([source_input], encoder_output)
        self.encoder_model.compile('adam', 'mse')

    def _build_decoder(self):
        source_input = Input(shape=(None, ), dtype="int32")
        target_input = Input(shape=(None, ), dtype="int32")
        encoder_output = Input(shape=(None, self.model_dim))

        tar_mask = Lambda(lambda x: self.get_self_mask(x))(target_input)
        encode_mask = Lambda(lambda x: get_mask_seq2seq(x[0], x[1]))(
            [target_input, source_input])

        tar_x = self.target_embedding(target_input)
        if self.use_pos_embedding:
            tar_x = self.position_embedding(tar_x)

        decoder_output = self.decoder([tar_x, encoder_output],
                                      self_mask=tar_mask,
                                      encode_mask=encode_mask)
        final_output = self.softmax(decoder_output)
        self.decoder_model = Model(
            [source_input, target_input, encoder_output], final_output)
        self.decoder_model.compile('adam', 'mse')

    def _build_decode_model(self):
        self._build_encoder()
        self._build_decoder()
        self.decode_build = True

    def decode(self, seq):
        if not self.decode_build:
            self._build_decode_model()

        decode_tokens = []
        target_seq = np.zeros(shape=(1, self.length_limit), dtype=np.int32)
        target_seq[0, 0] = 2

        encoder_output = self.encoder_model.predict_on_batch([seq])
        for i in range(self.length_limit - 1):
            output = self.decoder_model.predict_on_batch(
                [seq, target_seq, encoder_output])
            max_prob_index = np.argmax(output[0, i, :])
            max_prob_token = self.tar_token_dict[max_prob_index]
            decode_tokens.append(max_prob_token)
            if max_prob_index == 3:
                break
            target_seq[0, i + 1] = max_prob_index
        return " ".join(decode_tokens)

    def beam_search(self, seq, topk=3):
        if not self.decode_build:
            self._build_decode_model()

        seq = np.repeat(seq, topk, axis=0)
        encoder_output = self.encoder_model.predict_on_batch([seq])

        final_results = []
        topk_prob = np.zeros((topk, ), dtype=np.float32)
        decode_tokens = [[] for _ in range(topk)]

        target_seq = np.zeros((topk, self.length_limit), dtype=np.int32)
        target_seq[:, 0] = 2

        last_k = 1

        for i in range(self.length_limit - 1):
            if last_k == 0 or len(final_results) > topk * 3:
                break  # stop conditions

            target_output = self.decoder_model.predict_on_batch(
                [seq, target_seq, encoder_output])
            output = np.exp(target_output[:, i, :])
            output = output / np.sum(output, axis=-1, keepdims=True)
            output = np.log(
                output +
                1e-8)  # work in log space to avoid underflow from tiny probabilities

            candidates = []

            for k, probs in zip(range(last_k), output):
                if target_seq[k, i] == 3:
                    continue

                word_p_sort = sorted(list(enumerate(probs)),
                                     key=lambda x: x[1],
                                     reverse=True)
                for ind, wp in word_p_sort[:topk]:
                    candidates.append((k, ind, topk_prob[k] + wp))

            candidates = sorted(candidates, key=lambda x: x[-1], reverse=True)
            candidates = candidates[:topk]

            target_seq_bk = target_seq.copy()

            for new_k, cand in enumerate(candidates):
                k, ind, seq_p = cand
                target_seq[new_k] = target_seq_bk[k]
                target_seq[new_k, i + 1] = ind
                topk_prob[new_k] = seq_p
                decode_tokens.append(decode_tokens[k] +
                                     [self.tar_token_dict[ind]])
                if ind == 3:
                    final_results.append((decode_tokens[k], seq_p))

            decode_tokens = decode_tokens[topk:]
            last_k = len(decode_tokens)

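        # Length-normalize accumulated log-probabilities so longer hypotheses
        # are not unfairly penalized when ranking.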
        final_results = [(x, y / (len(x) + 1)) for x, y in final_results]
        final_results = sorted(final_results, key=lambda x: x[1], reverse=True)
        return final_results
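
A minimal usage sketch for the class above, assuming `src_dict` and `tar_dict` map tokens to ids with 0/2/3 as pad/start/end (as the decoding loops imply) and that `src_ids`/`tar_ids` are already-padded id matrices:

transformer = Transformer(src_dict, tar_dict, num_layers=2, model_dim=128)
transformer.compile(optimizer="adam")
# The loss lives inside the graph (add_loss), so fit() takes no targets.
transformer.model.fit([src_ids, tar_ids], None, batch_size=32, epochs=5)
print(transformer.decode_fast(src_ids[:1]))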
Example #7
def train_rpn(model_file=None):

    parser = OptionParser()
    parser.add_option("--train_path", dest="train_path", help="Path to training data.",
                      default='/Users/jie/projects/PanelSeg/ExpPython/train.txt')
    parser.add_option("--val_path", dest="val_path", help="Path to validation data.",
                      default='/Users/jie/projects/PanelSeg/ExpPython/eval.txt')
    parser.add_option("--num_rois", type="int", dest="num_rois", help="Number of RoIs to process at once.",
                      default=32)
    parser.add_option("--network", dest="network", help="Base network to use. Supports nn_cnn_3_layer.",
                      default='nn_cnn_3_layer')
    parser.add_option("--num_epochs", type="int", dest="num_epochs", help="Number of epochs.",
                      default=2000)
    parser.add_option("--output_weight_path", dest="output_weight_path", help="Output path for weights.",
                      default='./model_rpn.hdf5')
    parser.add_option("--input_weight_path", dest="input_weight_path",
                      default='/Users/jie/projects/PanelSeg/ExpPython/models/label+bg_rpn_3_layer_color-0.135.hdf5')

    (options, args) = parser.parse_args()

    # set configuration
    c = Config.Config()

    c.model_path = options.output_weight_path
    c.num_rois = int(options.num_rois)

    import nn_cnn_3_layer as nn

    c.base_net_weights = options.input_weight_path

    val_imgs, val_classes_count = get_label_rpn_data(options.val_path)
    train_imgs, train_classes_count = get_label_rpn_data(options.train_path)

    classes_count = {k: train_classes_count.get(k, 0) + val_classes_count.get(k, 0)
                     for k in set(train_classes_count) | set(val_classes_count)}
    class_mapping = LABEL_MAPPING

    if 'bg' not in classes_count:
        classes_count['bg'] = 0
        class_mapping['bg'] = len(class_mapping)

    c.class_mapping = class_mapping

    inv_map = {v: k for k, v in class_mapping.items()}

    print('Training images per class:')
    pprint.pprint(classes_count)
    print('Num classes (including bg) = {}'.format(len(classes_count)))

    config_output_filename = 'config.pickle'

    with open(config_output_filename, 'wb') as config_f:
        pickle.dump(c, config_f)
        print('Config has been written to {}, and can be loaded when testing to ensure correct results'.format(
            config_output_filename))

    random.shuffle(train_imgs)
    random.shuffle(val_imgs)

    num_imgs = len(train_imgs) + len(val_imgs)

    print('Num train samples {}'.format(len(train_imgs)))
    print('Num val samples {}'.format(len(val_imgs)))

    data_gen_train = label_rcnn_data_generators.get_anchor_gt(
        train_imgs, classes_count, c, nn.nn_get_img_output_length, mode='train')
    data_gen_val = label_rcnn_data_generators.get_anchor_gt(
        val_imgs, classes_count, c, nn.nn_get_img_output_length, mode='val')

    input_shape_img = (None, None, 3)
    img_input = Input(shape=input_shape_img)
    # roi_input = Input(shape=(None, 4))

    # define the base network (resnet here, can be VGG, Inception, etc)
    shared_layers = nn.nn_base(img_input, trainable=True)

    # define the RPN, built on the base layers
    num_anchors = len(c.anchor_box_scales) * len(c.anchor_box_ratios)
    rpn = nn.rpn(shared_layers, num_anchors)

    # classifier = nn.classifier(shared_layers, roi_input, c.num_rois, nb_classes=len(classes_count), trainable=True)

    model_rpn = Model(img_input, rpn[:2])
    # model_classifier = Model([img_input, roi_input], classifier)

    # this is a model that holds both the RPN and the classifier, used to load/save weights for the models
    # model_all = Model([img_input, roi_input], rpn[:2] + classifier)

    print('loading weights from {}'.format(c.base_net_weights))
    model_rpn.load_weights(c.base_net_weights, by_name=True)
    # model_classifier.load_weights(c.base_net_weights, by_name=True)
    model_rpn.summary()

    optimizer = Adam(lr=1e-5)
    # optimizer_classifier = Adam(lr=1e-5)
    model_rpn.compile(optimizer=optimizer, loss=[nn.rpn_loss_cls(num_anchors), nn.rpn_loss_regr(num_anchors)])
    # model_classifier.compile(optimizer=optimizer_classifier,
    #                          loss=[nn.class_loss_cls, nn.class_loss_regr(len(classes_count) - 1)],
    #                          metrics={'dense_class_{}'.format(len(classes_count)): 'accuracy'})
    # model_all.compile(optimizer='sgd', loss='mae')

    epoch_length = 1000
    num_epochs = int(options.num_epochs)
    iter_num = 0

    losses = np.zeros((epoch_length, 5))
    rpn_accuracy_rpn_monitor = []
    rpn_accuracy_for_epoch = []
    start_time = time.time()

    best_loss = np.inf

    class_mapping_inv = {v: k for k, v in class_mapping.items()}
    print('Starting training')
    vis = True

    for epoch_num in range(num_epochs):
        progbar = generic_utils.Progbar(epoch_length)
        print('Epoch {}/{}'.format(epoch_num + 1, num_epochs))

        while True:
            try:
                if len(rpn_accuracy_rpn_monitor) == epoch_length and c.verbose:
                    mean_overlapping_bboxes = float(sum(rpn_accuracy_rpn_monitor)) / len(rpn_accuracy_rpn_monitor)
                    rpn_accuracy_rpn_monitor = []
                    print('Average number of overlapping bounding boxes from RPN = {} for {} previous iterations'.format(
                            mean_overlapping_bboxes, epoch_length))
                    if mean_overlapping_bboxes == 0:
                        print('RPN is not producing bounding boxes that overlap the ground truth boxes. Check RPN settings or keep training.')

                X, Y, img_data = next(data_gen_train)

                loss_rpn = model_rpn.train_on_batch(X, Y)

                P_rpn = model_rpn.predict_on_batch(X)

                R = label_rcnn_roi_helpers.rpn_to_roi(P_rpn[0], P_rpn[1], c, K.image_dim_ordering(), use_regr=True,
                                                      overlap_thresh=0.7, max_boxes=300)
                # note: calc_iou converts from (x1,y1,x2,y2) to (x,y,w,h) format
                X2, Y1, Y2, IouS = label_rcnn_roi_helpers.calc_iou(R, img_data, c, class_mapping)

                if X2 is None:
                    rpn_accuracy_rpn_monitor.append(0)
                    rpn_accuracy_for_epoch.append(0)
                    continue

                neg_samples = np.where(Y1[0, :, -1] == 1)
                pos_samples = np.where(Y1[0, :, -1] == 0)

                # np.where returns a tuple of index arrays; test the array
                # itself, falling back to an empty list when it has no entries.
                if len(neg_samples[0]) > 0:
                    neg_samples = neg_samples[0]
                else:
                    neg_samples = []

                if len(pos_samples[0]) > 0:
                    pos_samples = pos_samples[0]
                else:
                    pos_samples = []

                rpn_accuracy_rpn_monitor.append(len(pos_samples))
                rpn_accuracy_for_epoch.append(len(pos_samples))

                if c.num_rois > 1:
                    if len(pos_samples) < c.num_rois // 2:
                        selected_pos_samples = pos_samples.tolist()
                    else:
                        selected_pos_samples = np.random.choice(pos_samples, c.num_rois // 2, replace=False).tolist()
                    try:
                        selected_neg_samples = np.random.choice(neg_samples, c.num_rois - len(selected_pos_samples),
                                                                replace=False).tolist()
                    except ValueError:
                        # Not enough negatives to sample without replacement;
                        # fall back to sampling with replacement.
                        selected_neg_samples = np.random.choice(neg_samples, c.num_rois - len(selected_pos_samples),
                                                                replace=True).tolist()

                    sel_samples = selected_pos_samples + selected_neg_samples
                else:
                    # in the extreme case where num_rois = 1, we pick a random pos or neg sample
                    selected_pos_samples = pos_samples.tolist()
                    selected_neg_samples = neg_samples.tolist()
                    if np.random.randint(0, 2):
                        sel_samples = random.choice(neg_samples)
                    else:
                        sel_samples = random.choice(pos_samples)

                # loss_class = model_classifier.train_on_batch([X, X2[:, sel_samples, :]],
                #                                              [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])

                losses[iter_num, 0] = loss_rpn[1]
                losses[iter_num, 1] = loss_rpn[2]

                # losses[iter_num, 2] = loss_class[1]
                # losses[iter_num, 3] = loss_class[2]
                # losses[iter_num, 4] = loss_class[3]

                iter_num += 1

                progbar.update(iter_num,
                               [('rpn_cls', np.mean(losses[:iter_num, 0])),
                                ('rpn_regr', np.mean(losses[:iter_num, 1]))])
                # progbar.update(iter_num,
                #                [('rpn_cls', np.mean(losses[:iter_num, 0])),
                #                 ('rpn_regr', np.mean(losses[:iter_num, 1])),
                #                 ('detector_cls', np.mean(losses[:iter_num, 2])),
                #                 ('detector_regr', np.mean(losses[:iter_num, 3]))])

                if iter_num == epoch_length:
                    loss_rpn_cls = np.mean(losses[:, 0])
                    loss_rpn_regr = np.mean(losses[:, 1])
                    # loss_class_cls = np.mean(losses[:, 2])
                    # loss_class_regr = np.mean(losses[:, 3])
                    # class_acc = np.mean(losses[:, 4])

                    mean_overlapping_bboxes = float(sum(rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
                    rpn_accuracy_for_epoch = []

                    if c.verbose:
                        print('Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'.format(
                            mean_overlapping_bboxes))
                        # print('Classifier accuracy for bounding boxes from RPN: {}'.format(class_acc))
                        print('Loss RPN classifier: {}'.format(loss_rpn_cls))
                        print('Loss RPN regression: {}'.format(loss_rpn_regr))
                        # print('Loss Detector classifier: {}'.format(loss_class_cls))
                        # print('Loss Detector regression: {}'.format(loss_class_regr))
                        print('Elapsed time: {}'.format(time.time() - start_time))

                    curr_loss = loss_rpn_cls + loss_rpn_regr
                    # curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
                    iter_num = 0
                    start_time = time.time()

                    if curr_loss < best_loss:
                        if c.verbose:
                            print('Total loss decreased from {} to {}, saving weights'.format(best_loss, curr_loss))
                        best_loss = curr_loss
                        model_rpn.save_weights(c.model_path)
                        # model_all.save_weights(c.model_path)

                    break

            except Exception as e:
                print('Exception: {}'.format(e))
                continue

    print('Training complete, exiting.')
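
Since `train_rpn` parses its own command-line options (with the defaults shown above), a training run reduces to:

if __name__ == '__main__':
    train_rpn()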
Example #8
class FergusNModel(object):
    def __init__(self, igor):

        now = datetime.now()
        self.run_name = "fergusn_{}mo_{}day_{}hr_{}min".format(
            now.month, now.day, now.hour, now.minute)
        log_location = join(igor.log_dir, self.run_name + ".log")
        self.logger = igor.logger = make_logger(igor, log_location)
        igor.verify_directories()
        self.igor = igor

    @classmethod
    def from_yaml(cls, yamlfile, kwargs=None):
        igor = Igor.from_file(yamlfile)
        igor.prep()
        model = cls(igor)
        model.make(kwargs)
        return model

    @classmethod
    def from_config(cls, config, kwargs=None):
        igor = Igor(config)
        model = cls(igor)
        igor.prep()
        model.make(kwargs)
        return model

    def load_checkpoint_weights(self):
        weight_file = join(self.igor.model_location, self.igor.saving_prefix,
                           self.igor.checkpoint_weights)
        if exists(weight_file):
            self.logger.info("+ Loading checkpoint weights")
            self.model.load_weights(weight_file, by_name=True)
        else:
            self.logger.warning(
                "- Checkpoint weights do not exist; {}".format(weight_file))

    def plot(self, filename=None):
        # Optional filename so callers (e.g. profile) can pick a custom path.
        if filename is None:
            filename = join(self.igor.model_location, self.igor.saving_prefix,
                            'model_visualization.png')
        kplot(self.model, to_file=filename)
        self.logger.debug("+ Model visualized at {}".format(filename))

    def make(self, theano_kwargs=None):
        '''Make the model and compile it.

        Igor's config options control everything.

        Args:
            theano_kwargs: dict of Theano options, for debugging or custom settings.
        '''

        if self.igor.embedding_type == "convolutional":
            make_convolutional_embedding(self.igor)
        elif self.igor.embedding_type == "token":
            make_token_embedding(self.igor)
        elif self.igor.embedding_type == "shallowconv":
            make_shallow_convolutional_embedding(self.igor)
        elif self.igor.embedding_type == "minimaltoken":
            make_minimal_token_embedding(self.igor)
        else:
            raise Exception("Incorrect embedding type")

        B = self.igor.batch_size
        spine_input_shape = (B, self.igor.max_num_supertags)
        child_input_shape = (B, 1)
        parent_input_shape = (B, 1)

        E, V = self.igor.word_embedding_size, self.igor.word_vocab_size  # for word embeddings

        repeat_N = self.igor.max_num_supertags  # for lex
        mlp_size = self.igor.mlp_size

        ## dropout parameters
        p_emb = self.igor.p_emb_dropout
        p_W = self.igor.p_W_dropout
        p_U = self.igor.p_U_dropout
        w_decay = self.igor.weight_decay
        p_mlp = self.igor.p_mlp_dropout

        def predict_params():
            return {
                'output_dim': 1,
                'W_regularizer': l2(w_decay),
                'activation': 'relu',
                'b_regularizer': l2(w_decay)
            }

        dspineset_in = Input(batch_shape=spine_input_shape,
                             name='daughter_spineset_in',
                             dtype='int32')
        pspineset_in = Input(batch_shape=spine_input_shape,
                             name='parent_spineset_in',
                             dtype='int32')
        dhead_in = Input(batch_shape=child_input_shape,
                         name='daughter_head_input',
                         dtype='int32')
        phead_in = Input(batch_shape=parent_input_shape,
                         name='parent_head_input',
                         dtype='int32')
        dspine_in = Input(batch_shape=child_input_shape,
                          name='daughter_spine_input',
                          dtype='int32')
        inputs = [dspineset_in, pspineset_in, dhead_in, phead_in, dspine_in]

        ### Layer functions
        ############# Convert the word indices to vectors
        F_embedword = Embedding(input_dim=V,
                                output_dim=E,
                                mask_zero=True,
                                W_regularizer=l2(w_decay),
                                dropout=p_emb)

        if self.igor.saved_embeddings is not None:
            self.logger.info("+ Cached embeddings loaded")
            F_embedword.initial_weights = [self.igor.saved_embeddings]

        ###### Prediction Functions
        ## these functions learn a vector which turns a tensor into a matrix of probabilities

        ### P(Parent supertag | Child, Context)
        F_parent_predict = ProbabilityTensor(
            name='parent_predictions',
            dense_function=Dense(**predict_params()))
        ### P(Leaf supertag)
        F_leaf_predict = ProbabilityTensor(
            name='leaf_predictions', dense_function=Dense(**predict_params()))

        ###### Network functions.
        ##### Input word, correct its dimensions (basically squash in a certain way)
        F_singleword = compose(Fix(), F_embedword)
        ##### Input spine, correct dimensions, broadcast across 1st dimension
        F_singlespine = compose(RepeatVector(repeat_N), Fix(),
                                self.igor.F_embedspine)
        ##### Concatenate and map to a single space
        F_alignlex = compose(
            RepeatVector(repeat_N), Dropout(p_mlp),
            Dense(mlp_size, activation='relu', name='dense_align_lex'), concat)

        F_alignall = compose(
            Distribute(Dropout(p_mlp), name='distribute_align_all_dropout'),
            Distribute(Dense(mlp_size,
                             activation='relu',
                             name='align_all_dense'),
                       name='distribute_align_all_dense'), concat)
        F_alignleaf = compose(
            Distribute(
                Dropout(p_mlp * 0.66), name='distribute_leaf_dropout'
            ),  ### need a separate one because the 'concat' is different for the two situations
            Distribute(Dense(mlp_size, activation='relu', name='leaf_dense'),
                       name='distribute_leaf_dense'),
            concat)

        ### embed and form all of the inputs into their components
        ### note: spines == supertags. early word choice, haven't refactored.
        leaf_spines = self.igor.F_embedspine(dspineset_in)
        pspine_context = self.igor.F_embedspine(pspineset_in)
        dspine_single = F_singlespine(dspine_in)

        dhead = F_singleword(dhead_in)
        phead = F_singleword(phead_in)

        ### combine the lexical material
        lexical_context = F_alignlex([dhead, phead])

        #### P(Parent Supertag | Daughter Supertag, Lexical Context)
        ### we know the daughter spine, want to know the parent spine
        ### size is (batch, num_supertags)
        parent_problem = F_alignall(
            [lexical_context, dspine_single, pspine_context])

        ### we don't have the parent, we just have a leaf
        leaf_problem = F_alignleaf([lexical_context, leaf_spines])

        parent_predictions = F_parent_predict(parent_problem)
        leaf_predictions = F_leaf_predict(leaf_problem)
        predictions = [parent_predictions, leaf_predictions]

        theano_kwargs = theano_kwargs or {}
        ## make it quick so I can load in the weights.
        self.model = Model(input=inputs,
                           output=predictions,
                           preloaded_data=self.igor.preloaded_data,
                           **theano_kwargs)

        #mask_cache = traverse_nodes(parent_prediction)
        #desired_masks = ['merge_3.in.mask.0']
        #self.p_tensor = K.function(inputs+[K.learning_phase()], [parent_predictions, F_parent_predict.inbound_nodes[0].input_masks[0]])

        if self.igor.from_checkpoint:
            self.load_checkpoint_weights()
        elif not self.igor.in_training:
            raise Exception("No point in running this without trained weights")

        if not self.igor.in_training:
            expanded_children = RepeatVector(repeat_N, axis=2)(leaf_spines)
            expanded_parent = RepeatVector(repeat_N, axis=1)(pspine_context)
            expanded_lex = RepeatVector(repeat_N, axis=1)(
                lexical_context
            )  # axis here is arbitrary; it's repeating on 1 and 2, but already repeated once
            huge_tensor = concat(
                [expanded_lex, expanded_children, expanded_parent])
            densely_aligned = LastDimDistribute(
                F_alignall.get(1).layer)(huge_tensor)
            output_predictions = Distribute(
                F_parent_predict, force_reshape=True)(densely_aligned)

            primary_inputs = [phead_in, dhead_in, pspineset_in, dspineset_in]
            leaf_inputs = [phead_in, dhead_in, dspineset_in]

            self.logger.info("+ Compiling prediction functions")
            self.inner_func = K.Function(primary_inputs + [K.learning_phase()],
                                         output_predictions)
            self.leaf_func = K.Function(leaf_inputs + [K.learning_phase()],
                                        leaf_predictions)
            try:
                self.get_ptensor = K.function(
                    primary_inputs + [K.learning_phase()], [
                        output_predictions,
                    ])
            except Exception:
                import pdb
                pdb.set_trace()
        else:

            optimizer = Adam(self.igor.LR,
                             clipnorm=self.igor.max_grad_norm,
                             clipvalue=self.igor.grad_clip_threshold)

            theano_kwargs = theano_kwargs or {}
            self.model.compile(loss="categorical_crossentropy",
                               optimizer=optimizer,
                               metrics=['accuracy'],
                               **theano_kwargs)

        #self.model.save("here.h5")

    def likelihood_function(self, inputs):
        if self.igor.in_training:
            raise Exception("Not in testing mode; please fix the config file")
        return self.inner_func(tuple(inputs) + (0., ))

    def leaf_function(self, inputs):
        if self.igor.in_training:
            raise Exception("Not in testing mode; please fix the config file")
        return self.leaf_func(tuple(inputs) + (0., ))

    def train(self):
        replacers = {
            "daughter_predictions": "child",
            "parent_predictions": "parent",
            "leaf_predictions": "leaf"
        }
        train_data = self.igor.train_gen(forever=True)
        dev_data = self.igor.dev_gen(forever=True)
        N = self.igor.num_train_samples
        E = self.igor.num_epochs
        # generator, samples per epoch, number of epochs
        callbacks = [ProgbarV2(3, 10, replacers=replacers)]
        checkpoint_fp = join(self.igor.model_location, self.igor.saving_prefix,
                             self.igor.checkpoint_weights)
        self.logger.info("+ Model Checkpoint: {}".format(checkpoint_fp))
        callbacks += [
            ModelCheckpoint(filepath=checkpoint_fp,
                            verbose=1,
                            save_best_only=True)
        ]
        callbacks += [LearningRateScheduler(lambda epoch: self.igor.LR * 0.9)]
        csv_location = join(self.igor.log_dir, self.run_name + ".csv")
        callbacks += [CSVLogger(csv_location)]
        self.model.fit_generator(generator=train_data,
                                 samples_per_epoch=N,
                                 nb_epoch=E,
                                 callbacks=callbacks,
                                 verbose=1,
                                 validation_data=dev_data,
                                 nb_val_samples=self.igor.num_dev_samples)

    def debug(self):
        dev_data = self.igor.dev_gen(forever=False)
        X, Y = next(dev_data)
        self.model.predict_on_batch(X)
        #self.model.evaluate_generator(dev_data, self.igor.num_dev_samples)

    def profile(self, num_iterations=1):
        train_data = self.igor.train_gen(forever=True)
        dev_data = self.igor.dev_gen(forever=True)
        # generator, samples per epoch, number of epochs
        callbacks = [ProgbarV2(1, 10)]
        self.logger.debug("+ Beginning the generator")
        self.model.fit_generator(generator=train_data,
                                 samples_per_epoch=self.igor.batch_size * 10,
                                 nb_epoch=num_iterations,
                                 callbacks=callbacks,
                                 verbose=1,
                                 validation_data=dev_data,
                                 nb_val_samples=self.igor.batch_size)
        self.logger.debug(
            "+ Calling theano's pydot print.. this might take a while")
        theano.printing.pydotprint(self.model.train_function.function,
                                   outfile='theano_graph.png',
                                   var_with_name_simple=True,
                                   with_ids=True)
        self.logger.debug("+ Calling keras' print.. this might take a while")
        self.plot("keras_graph.png")
Example #9
        save_best_only=False,
        save_weights_only=True,
        mode='auto',
        period=1)
    model_clstm.fit(X_train_2nd,
                    y_train_2nd,
                    batch_size=4,
                    validation_data=(X_test, y_test),
                    shuffle=True,
                    epochs=50,
                    callbacks=[mc])

    # Save the model architecture as YAML (weights are saved by the checkpoint callback)
    with open('/your/path/checkpoint-clstm/model_audio_clstm.yaml', 'w') as f:
        f.write(model_clstm.to_yaml())
    proba_clstm = model_clstm.predict_on_batch(X_test)

    # Case 2: Bidirectional LSTM model
    model_Bilstm = Sequential()
    model_Bilstm.add(LSTM(577, return_sequences=True, input_shape=(1, 577)))
    model_Bilstm.add(Dropout(0.8))
    model_Bilstm.add(Bidirectional(LSTM(577, return_sequences=True)))
    model_Bilstm.add(Dropout(0.8))
    model_Bilstm.add(Bidirectional(LSTM(577)))
    model_Bilstm.add(Dropout(0.8))
    model_Bilstm.add(Dense(7, activation='softmax'))
    model_Bilstm.summary()

    sgd = SGD(lr=1e-3, decay=1e-6, momentum=0.9, nesterov=True)
    model_Bilstm.compile(optimizer=sgd,
                         loss='categorical_crossentropy',
Example #10
def main(base_path, debug, train_mode, test_mode):

    train_path = os.path.join(base_path, 'train_annotation.txt')

    num_rois = 4
    horizontal_flips = True
    vertical_flips = True
    rot_90 = True

    output_weight_path = os.path.join(base_path, 'model/model_frcnn_vgg.hdf5')
    record_path = os.path.join(base_path, 'model/record.csv')
    # Record data (used to save the losses, classification accuracy and mean average precision)
    base_weight_path = os.path.join(
        base_path, 'model/vgg16_weights_tf_dim_ordering_tf_kernels.h5')
    config_output_filename = os.path.join(base_path, 'model_vgg_config.pickle')

    # Create the config
    C = config.Config()
    C.use_horizontal_flips = horizontal_flips
    C.use_vertical_flips = vertical_flips
    C.rot_90 = rot_90
    C.record_path = record_path
    C.model_path = output_weight_path
    C.num_rois = num_rois
    C.base_net_weights = base_weight_path

    # --------------------------------------------------------#
    # Loading the data can take some time                     #
    # --------------------------------------------------------#
    st = time.time()
    train_imgs, classes_count, class_mapping = extract_data.get_data(
        train_path, base_path)
    print()
    print('Spent %0.2f mins loading the data' % ((time.time() - st) / 60))
    # --------------------------------------------------------#

    if 'bg' not in classes_count:
        classes_count['bg'] = 0
        class_mapping['bg'] = len(class_mapping)
    # e.g.
    #    classes_count: {'Car': 2383, 'Mobile phone': 1108, 'Person': 3745, 'bg': 0}
    #    class_mapping: {'Person': 0, 'Car': 1, 'Mobile phone': 2, 'bg': 3}
    C.class_mapping = class_mapping

    # Save the configuration
    with open(config_output_filename, 'wb') as config_f:
        pickle.dump(C, config_f)
        print(
            'Config has been written to {}, and can be loaded when '
            'testing to ensure correct results'.format(config_output_filename))

    # Shuffle the images with seed
    random.seed(1)
    random.shuffle(train_imgs)
    print('Num train samples (images) {}'.format(len(train_imgs)))

    # Get train data generator which generate X, Y, image_data
    data_gen_train = get_anchor_gt(train_imgs,
                                   C,
                                   get_img_output_length,
                                   mode='train')

    if debug:
        X, Y, image_data, debug_img, debug_num_pos = next(data_gen_train)

        print('Original image: height=%d width=%d' %
              (image_data['height'], image_data['width']))
        print('Resized image:  height=%d width=%d C.im_size=%d' %
              (X.shape[1], X.shape[2], C.im_size))
        print('Feature map size: height=%d width=%d C.rpn_stride=%d' %
              (Y[0].shape[1], Y[0].shape[2], C.rpn_stride))
        print(X.shape)
        print('Y contains %d arrays: y_rpn_cls and y_rpn_regr' % len(Y))
        print('Shape of y_rpn_cls {}'.format(Y[0].shape))
        print('Shape of y_rpn_regr {}'.format(Y[1].shape))
        print(image_data)

        print('Number of positive anchors for this image: %d' %
              (debug_num_pos))
        if debug_num_pos == 0:
            # Scale the gt bbox from the original image to the resized image
            # (x coords by the width ratio, y coords by the height ratio)
            gt_x1, gt_x2 = image_data['bboxes'][0]['x1'] * (
                X.shape[2] /
                image_data['width']), image_data['bboxes'][0]['x2'] * (
                    X.shape[2] / image_data['width'])

            gt_y1, gt_y2 = image_data['bboxes'][0]['y1'] * (
                X.shape[1] /
                image_data['height']), image_data['bboxes'][0]['y2'] * (
                    X.shape[1] / image_data['height'])

            gt_x1, gt_y1, gt_x2, gt_y2 = int(gt_x1), int(gt_y1), int(
                gt_x2), int(gt_y2)

            img = debug_img.copy()
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            color = (0, 255, 0)
            cv2.putText(img, 'gt bbox', (gt_x1, gt_y1 - 5),
                        cv2.FONT_HERSHEY_DUPLEX, 0.7, color, 1)
            cv2.rectangle(img, (gt_x1, gt_y1), (gt_x2, gt_y2), color, 2)
            cv2.circle(img, (int((gt_x1 + gt_x2) / 2), int(
                (gt_y1 + gt_y2) / 2)), 3, color, -1)

            plt.grid()
            plt.imshow(img)
            plt.show()
        else:
            cls = Y[0][0]
            pos_cls = np.where(cls == 1)
            print("==> Positive classes", pos_cls)
            regr = Y[1][0]
            pos_regr = np.where(regr == 1)
            print("==> Positive regressions", pos_regr)
            print('--------------------------------------------------------')
            print('y_rpn_cls for possible pos anchor:')
            print(cls[pos_cls[0][0], pos_cls[1][0], :])
            print('--------------------------------------------------------')
            print('y_rpn_regr for positive anchor:')
            print(regr[pos_regr[0][0], pos_regr[1][0], :])
            print('--------------------------------------------------------')

            gt_x1, gt_x2 = image_data['bboxes'][0]['x1'] * (
                X.shape[2] /
                image_data['width']), image_data['bboxes'][0]['x2'] * (
                    X.shape[2] / image_data['width'])

            gt_y1, gt_y2 = image_data['bboxes'][0]['y1'] * (
                X.shape[1] /
                image_data['height']), image_data['bboxes'][0]['y2'] * (
                    X.shape[1] / image_data['height'])

            gt_x1, gt_y1, gt_x2, gt_y2 = int(gt_x1), int(gt_y1), int(
                gt_x2), int(gt_y2)

            img = debug_img.copy()
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            color = (0, 255, 0)
            #   cv2.putText(img, 'gt bbox', (gt_x1, gt_y1-5), cv2.FONT_HERSHEY_DUPLEX, 0.7, color, 1)
            cv2.rectangle(img, (gt_x1, gt_y1), (gt_x2, gt_y2), color, 2)
            cv2.circle(img, (int((gt_x1 + gt_x2) / 2), int(
                (gt_y1 + gt_y2) / 2)), 3, color, -1)

            # Add text
            textLabel = 'gt bbox'
            (retval, baseLine) = cv2.getTextSize(textLabel,
                                                 cv2.FONT_HERSHEY_COMPLEX, 0.5,
                                                 1)
            textOrg = (gt_x1, gt_y1 + 5)
            cv2.rectangle(
                img, (textOrg[0] - 5, textOrg[1] + baseLine - 5),
                (textOrg[0] + retval[0] + 5, textOrg[1] - retval[1] - 5),
                (0, 0, 0), 2)
            cv2.rectangle(
                img, (textOrg[0] - 5, textOrg[1] + baseLine - 5),
                (textOrg[0] + retval[0] + 5, textOrg[1] - retval[1] - 5),
                (255, 255, 255), -1)
            cv2.putText(img, textLabel, textOrg, cv2.FONT_HERSHEY_DUPLEX, 0.5,
                        (0, 0, 0), 1)

            # Draw positive anchors according to the y_rpn_regr
            for i in range(debug_num_pos):
                color = (100 + i * (155 / 4), 0, 100 + i * (155 / 4))

                # Each anchor occupies 4 consecutive regression channels, so
                # channel_index / 4 recovers the anchor (scale, ratio) index
                idx = pos_regr[2][i * 4] / 4

                anchor_size = C.anchor_box_scales[int(idx / 3)]
                anchor_ratio = C.anchor_box_ratios[2 - int((idx + 1) % 3)]

                center = (pos_regr[1][i * 4] * C.rpn_stride,
                          pos_regr[0][i * 4] * C.rpn_stride)
                print('Center position of positive anchor: ', center)
                # With the center, anchor size and ratio in hand, draw the anchor box
                cv2.circle(img, center, 3, color, -1)
                anc_w, anc_h = anchor_size * anchor_ratio[
                    0], anchor_size * anchor_ratio[1]
                cv2.rectangle(
                    img,
                    (center[0] - int(anc_w / 2), center[1] - int(anc_h / 2)),
                    (center[0] + int(anc_w / 2), center[1] + int(anc_h / 2)),
                    color, 2)
                cv2.putText(img, 'anchor ' + str(i + 1),
                            (center[0] - int(anc_w / 2),
                             center[1] - int(anc_h / 2) - 5),
                            cv2.FONT_HERSHEY_DUPLEX, 0.5, color, 1)

        print('The green bbox is the ground truth; the others are positive anchors')
        plt.figure(figsize=(8, 8))
        plt.grid()
        plt.imshow(img)
        plt.show()

    # Build the model
    input_shape_img = (None, None, 3)
    img_input = Input(shape=input_shape_img)
    roi_input = Input(shape=(None, 4))

    # define the base network (VGG here; could be ResNet-50, Inception, etc.)
    shared_layers = base_cnn.nn_base(img_input, trainable=True)

    # define the RPN, built on the base layers
    num_anchors = len(C.anchor_box_scales) * len(C.anchor_box_ratios)
    # rpn is built on top of the base model and returns:
    #   -> per-anchor classifier scores (sigmoid)
    #   -> per-anchor bbox regression (1x1 conv outputs)
    #   -> the base model feature map
    rpn = rpn_layer(shared_layers, num_anchors)
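    # For reference, a sketch of what rpn_layer presumably implements,
    # following the standard keras-frcnn pattern described in the comment
    # above: a shared 3x3 conv, a 1x1 sigmoid conv for per-anchor objectness,
    # and a 1x1 linear conv for per-anchor bbox deltas. This is illustrative
    # only (not necessarily this repository's exact code) and assumes Conv2D
    # is imported from keras.layers.
    def _rpn_layer_sketch(base_layers, n_anchors):
        x = Conv2D(512, (3, 3), padding='same', activation='relu',
                   name='rpn_conv1')(base_layers)
        x_class = Conv2D(n_anchors, (1, 1), activation='sigmoid',
                         name='rpn_out_class')(x)   # objectness score per anchor
        x_regr = Conv2D(n_anchors * 4, (1, 1), activation='linear',
                        name='rpn_out_regress')(x)  # 4 bbox deltas per anchor
        return [x_class, x_regr, base_layers]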

    # Implements ROI pooling under the hood,
    # returns classifier and regression for
    # proposed regions
    classifier = roi_util.classifier_layer(shared_layers,
                                           roi_input,
                                           C.num_rois,
                                           nb_classes=len(classes_count))

    model_rpn = Model(img_input, rpn[:2])
    model_classifier = Model([img_input, roi_input], classifier)

    # this is a model that holds both the RPN and the classifier,
    # used to load/save weights for the models
    model_all = Model([img_input, roi_input], rpn[:2] + classifier)

    if not os.path.isfile(C.model_path):
        # If this is the beginning of training, load a pre-trained
        # base network such as VGG-16
        try:
            print('This is the first run of this training')
            print('loading weights from {}'.format(C.base_net_weights))
            model_rpn.load_weights(C.base_net_weights, by_name=True)
            model_classifier.load_weights(C.base_net_weights, by_name=True)
        except Exception:
            print(
                'Could not load pretrained model weights. '
                'Weights can be found in the keras application folder: '
                'https://github.com/fchollet/keras/tree/master/keras/applications'
            )

        # Create the record.csv file to record losses, acc and mAP
        record_df = pd.DataFrame(columns=[
            'mean_overlapping_bboxes', 'class_acc', 'loss_rpn_cls',
            'loss_rpn_regr', 'loss_class_cls', 'loss_class_regr', 'curr_loss',
            'elapsed_time', 'mAP'
        ])
    else:
        # Resume training: load the previously trained weights and the
        # training records from the earlier run
        print('Continuing training from the previously trained model')
        print('Loading weights from {}'.format(C.model_path))
        model_rpn.load_weights(C.model_path, by_name=True)
        model_classifier.load_weights(C.model_path, by_name=True)

        # Load the records
        record_df = pd.read_csv(record_path)

        r_mean_overlapping_bboxes = record_df['mean_overlapping_bboxes']
        r_class_acc = record_df['class_acc']
        r_loss_rpn_cls = record_df['loss_rpn_cls']
        r_loss_rpn_regr = record_df['loss_rpn_regr']
        r_loss_class_cls = record_df['loss_class_cls']
        r_loss_class_regr = record_df['loss_class_regr']
        r_curr_loss = record_df['curr_loss']
        r_elapsed_time = record_df['elapsed_time']
        r_mAP = record_df['mAP']

        print('Already trained %dK batches' % (len(record_df)))

    optimizer = Adam(lr=1e-5)
    optimizer_classifier = Adam(lr=1e-5)
    model_rpn.compile(
        optimizer=optimizer,
        loss=[rpn_loss_cls(num_anchors),
              rpn_loss_regr(num_anchors)])
    model_classifier.compile(
        optimizer=optimizer_classifier,
        loss=[class_loss_cls,
              class_loss_regr(len(classes_count) - 1)],
        metrics={'dense_class_{}'.format(len(classes_count)): 'accuracy'})
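    # model_all itself is never trained; it is presumably compiled with a
    # placeholder optimizer/loss only so the combined RPN + classifier
    # weights can be saved and restored together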
    model_all.compile(optimizer='sgd', loss='mae')

    # Training setting
    total_epochs = len(record_df)
    r_epochs = len(record_df)

    epoch_length = 1000
    num_epochs = 40
    iter_num = 0

    total_epochs += num_epochs

    losses = np.zeros((epoch_length, 5))
    rpn_accuracy_rpn_monitor = []
    rpn_accuracy_for_epoch = []

    if len(record_df) == 0:
        best_loss = np.Inf
    else:
        best_loss = np.min(r_curr_loss)

    start_time = time.time()
    for epoch_num in range(num_epochs):

        progbar = generic_utils.Progbar(epoch_length)
        print('Epoch {}/{}'.format(r_epochs + 1, total_epochs))

        r_epochs += 1

        while True:
            try:

                if len(rpn_accuracy_rpn_monitor) == epoch_length and C.verbose:
                    mean_overlapping_bboxes = float(
                        sum(rpn_accuracy_rpn_monitor)) / len(
                            rpn_accuracy_rpn_monitor)
                    rpn_accuracy_rpn_monitor = []
                    print(
                        'Average number of overlapping bounding boxes from RPN = {}'
                        ' for {} previous iterations'.format(
                            mean_overlapping_bboxes, epoch_length))
                    if mean_overlapping_bboxes == 0:
                        print(
                            'RPN is not producing bounding boxes that overlap '
                            'the ground truth boxes. Check RPN settings or keep training.')

                # Generate X (x_img) and label Y ([y_rpn_cls, y_rpn_regr])
                X, Y, img_data, debug_img, debug_num_pos = next(data_gen_train)

                # Train rpn model and get loss value [_, loss_rpn_cls, loss_rpn_regr]
                loss_rpn = model_rpn.train_on_batch(X, Y)

                # Get predicted rpn from rpn model [rpn_cls, rpn_regr]
                P_rpn = model_rpn.predict_on_batch(X)

                # R: bboxes (shape=(300,4))
                # Convert rpn layer to roi bboxes
                R = roi_util.rpn_to_roi(P_rpn[0],
                                        P_rpn[1],
                                        C,
                                        K.image_dim_ordering(),
                                        use_regr=True,
                                        overlap_thresh=0.7,
                                        max_boxes=300)

                # note: calc_iou converts from (x1,y1,x2,y2) to (x,y,w,h) format
                # X2: bboxes (out of the 300 NMS boxes) whose IoU with a gt
                #     bbox exceeds C.classifier_min_overlap
                # Y1: one-hot class labels for those bboxes
                # Y2: corresponding class labels and gt bbox regression targets
                X2, Y1, Y2, IouS = calc_iou(R, img_data, C, class_mapping)

                # X2 is None means no RoI overlapped the ground-truth bboxes
                if X2 is None:
                    rpn_accuracy_rpn_monitor.append(0)
                    rpn_accuracy_for_epoch.append(0)
                    continue

                # Split the RoIs into negative (background) and positive
                # samples; the last column of Y1 is the one-hot 'bg' flag
                neg_samples = np.where(Y1[0, :, -1] == 1)
                pos_samples = np.where(Y1[0, :, -1] == 0)

                if len(neg_samples) > 0:
                    neg_samples = neg_samples[0]
                else:
                    neg_samples = []

                if len(pos_samples) > 0:
                    pos_samples = pos_samples[0]
                else:
                    pos_samples = []

                rpn_accuracy_rpn_monitor.append(len(pos_samples))
                rpn_accuracy_for_epoch.append(len(pos_samples))

                if C.num_rois > 1:
                    # Keep at most num_rois//2 positive samples (2 when
                    # num_rois = 4); if fewer are available, keep them all
                    if len(pos_samples) < C.num_rois // 2:
                        selected_pos_samples = pos_samples.tolist()
                    else:
                        selected_pos_samples = np.random.choice(
                            pos_samples, C.num_rois // 2,
                            replace=False).tolist()

                    # Fill the remaining slots with negatives; if there are
                    # too few to sample without replacement, fall back to
                    # sampling with replacement
                    try:
                        selected_neg_samples = np.random.choice(
                            neg_samples,
                            C.num_rois - len(selected_pos_samples),
                            replace=False).tolist()
                    except ValueError:
                        selected_neg_samples = np.random.choice(
                            neg_samples,
                            C.num_rois - len(selected_pos_samples),
                            replace=True).tolist()

                    # Save all the pos and neg samples in sel_samples
                    sel_samples = selected_pos_samples + selected_neg_samples
                else:
                    # In the extreme case where num_rois = 1, pick a single
                    # random positive or negative sample; keep it in a list so
                    # the slicing below preserves the RoI axis
                    selected_pos_samples = pos_samples.tolist()
                    selected_neg_samples = neg_samples.tolist()
                    if np.random.randint(0, 2):
                        sel_samples = [random.choice(selected_neg_samples)]
                    else:
                        sel_samples = [random.choice(selected_pos_samples)]

                # training_data: [X, X2[:, sel_samples, :]]
                # labels: [Y1[:, sel_samples, :], Y2[:, sel_samples, :]]
                #  X                     => img_data resized image
                #  X2[:, sel_samples, :] => the num_rois (4 here) bboxes containing the selected neg and pos samples
                #  Y1[:, sel_samples, :] => one-hot class labels for those num_rois bboxes
                #  Y2[:, sel_samples, :] => class labels and gt bboxes for those num_rois bboxes
                loss_class = model_classifier.train_on_batch(
                    [X, X2[:, sel_samples, :]],
                    [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])

                losses[iter_num, 0] = loss_rpn[1]
                losses[iter_num, 1] = loss_rpn[2]

                losses[iter_num, 2] = loss_class[1]
                losses[iter_num, 3] = loss_class[2]
                losses[iter_num, 4] = loss_class[3]

                iter_num += 1

                progbar.update(iter_num,
                               [('rpn_cls', np.mean(losses[:iter_num, 0])),
                                ('rpn_regr', np.mean(losses[:iter_num, 1])),
                                ('final_cls', np.mean(losses[:iter_num, 2])),
                                ('final_regr', np.mean(losses[:iter_num, 3]))])

                if iter_num == epoch_length:
                    loss_rpn_cls = np.mean(losses[:, 0])
                    loss_rpn_regr = np.mean(losses[:, 1])
                    loss_class_cls = np.mean(losses[:, 2])
                    loss_class_regr = np.mean(losses[:, 3])
                    class_acc = np.mean(losses[:, 4])

                    mean_overlapping_bboxes = float(sum(
                        rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
                    rpn_accuracy_for_epoch = []

                    if C.verbose:
                        print(
                            'Mean number of bounding boxes from RPN overlapping'
                            ' ground truth boxes: {}'.format(
                                mean_overlapping_bboxes))
                        print(
                            'Classifier accuracy for bounding boxes from RPN: {}'
                            .format(class_acc))
                        print('Loss RPN classifier: {}'.format(loss_rpn_cls))
                        print('Loss RPN regression: {}'.format(loss_rpn_regr))
                        print('Loss Detector classifier: {}'.format(
                            loss_class_cls))
                        print('Loss Detector regression: {}'.format(
                            loss_class_regr))
                        print('Total loss: {}'.format(loss_rpn_cls +
                                                      loss_rpn_regr +
                                                      loss_class_cls +
                                                      loss_class_regr))
                        print('Elapsed time: {}'.format(time.time() -
                                                        start_time))

                    # Compute elapsed_time outside the verbose branch so the
                    # record row below is always defined
                    elapsed_time = (time.time() - start_time) / 60
                    curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
                    iter_num = 0
                    start_time = time.time()

                    if curr_loss < best_loss:
                        if C.verbose:
                            print(
                                'Total loss decreased from {} to {}, saving weights'
                                .format(best_loss, curr_loss))
                        best_loss = curr_loss
                        model_all.save_weights(C.model_path)

                    new_row = {
                        'mean_overlapping_bboxes':
                        round(mean_overlapping_bboxes, 3),
                        'class_acc':
                        round(class_acc, 3),
                        'loss_rpn_cls':
                        round(loss_rpn_cls, 3),
                        'loss_rpn_regr':
                        round(loss_rpn_regr, 3),
                        'loss_class_cls':
                        round(loss_class_cls, 3),
                        'loss_class_regr':
                        round(loss_class_regr, 3),
                        'curr_loss':
                        round(curr_loss, 3),
                        'elapsed_time':
                        round(elapsed_time, 3),
                        'mAP':
                        0
                    }

                    record_df = record_df.append(new_row, ignore_index=True)
                    record_df.to_csv(record_path, index=False)

                    break

            except Exception as e:
                print('Exception: {}'.format(e))
                continue

    print('Training complete, exiting.')

    if test_mode:
        plt.figure(figsize=(15, 5))
        plt.subplot(1, 2, 1)
        plt.plot(np.arange(0, r_epochs), record_df['mean_overlapping_bboxes'],
                 'r')
        plt.title('mean_overlapping_bboxes')
        plt.subplot(1, 2, 2)
        plt.plot(np.arange(0, r_epochs), record_df['class_acc'], 'r')
        plt.title('class_acc')

        plt.show()

        plt.figure(figsize=(15, 5))
        plt.subplot(1, 2, 1)
        plt.plot(np.arange(0, r_epochs), record_df['loss_rpn_cls'], 'r')
        plt.title('loss_rpn_cls')
        plt.subplot(1, 2, 2)
        plt.plot(np.arange(0, r_epochs), record_df['loss_rpn_regr'], 'r')
        plt.title('loss_rpn_regr')
        plt.show()

        plt.figure(figsize=(15, 5))
        plt.subplot(1, 2, 1)
        plt.plot(np.arange(0, r_epochs), record_df['loss_class_cls'], 'r')
        plt.title('loss_class_cls')
        plt.subplot(1, 2, 2)
        plt.plot(np.arange(0, r_epochs), record_df['loss_class_regr'], 'r')
        plt.title('loss_class_regr')
        plt.show()

        plt.plot(np.arange(0, r_epochs), record_df['curr_loss'], 'r')
        plt.title('total_loss')
        plt.show()
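# A hypothetical invocation of the training entry point above; the base_path
# is a placeholder and is assumed to contain train_annotation.txt plus a
# model/ sub-folder, matching the path joins at the top of main().
if __name__ == '__main__':
    main(base_path='/path/to/dataset',
         debug=True,
         train_mode=True,
         test_mode=True)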
Example #11
0
def train_rpn(model_file=None):  # model_file is accepted but unused in this snippet

    parser = OptionParser()
    parser.add_option("--train_path",
                      dest="train_path",
                      help="Path to training data.",
                      default='/Users/jie/projects/PanelSeg/ExpRcnn/train.txt')
    parser.add_option("--val_path",
                      dest="val_path",
                      help="Path to validation data.",
                      default='/Users/jie/projects/PanelSeg/ExpRcnn/eval.txt')
    parser.add_option("--num_rois",
                      type="int",
                      dest="num_rois",
                      help="Number of RoIs to process at once.",
                      default=32)
    parser.add_option("--network",
                      dest="network",
                      help="Base network to use. Supports nn_cnn_3_layer.",
                      default='nn_cnn_3_layer')
    parser.add_option("--num_epochs",
                      type="int",
                      dest="num_epochs",
                      help="Number of epochs.",
                      default=100)
    parser.add_option("--output_weight_path",
                      dest="output_weight_path",
                      help="Output path for weights.",
                      default='./model_frcnn.hdf5')
    parser.add_option(
        "--input_weight_path",
        dest="input_weight_path",
        default=
        '/Users/jie/projects/PanelSeg/ExpRcnn/models/model_rpn_3_layer_color-0.0293.hdf5'
    )

    (options, args) = parser.parse_args()

    # set configuration
    c = Config.Config()

    c.model_path = options.output_weight_path
    c.num_rois = int(options.num_rois)

    import nn_cnn_3_layer as nn

    c.base_net_weights = options.input_weight_path

    val_imgs, val_classes_count = get_label_rpn_data(options.val_path)
    train_imgs, train_classes_count = get_label_rpn_data(options.train_path)

    classes_count = {
        k: train_classes_count.get(k, 0) + val_classes_count.get(k, 0)
        for k in set(train_classes_count) | set(val_classes_count)
    }
    class_mapping = LABEL_CLASS_MAPPING

    if 'bg' not in classes_count:
        classes_count['bg'] = 0
        class_mapping['bg'] = len(class_mapping)

    c.class_mapping = class_mapping

    inv_map = {v: k for k, v in class_mapping.items()}

    print('Images per class (train + val):')
    pprint.pprint(classes_count)
    print('Num classes (including bg) = {}'.format(len(classes_count)))

    config_output_filename = 'config.pickle'

    with open(config_output_filename, 'wb') as config_f:
        pickle.dump(c, config_f)
        print(
            'Config has been written to {}, and can be loaded when testing to ensure correct results'
            .format(config_output_filename))

    random.shuffle(train_imgs)
    random.shuffle(val_imgs)

    num_imgs = len(train_imgs) + len(val_imgs)

    print('Num train samples {}'.format(len(train_imgs)))
    print('Num val samples {}'.format(len(val_imgs)))

    data_gen_train = label_rcnn_data_generators.get_anchor_gt(
        train_imgs,
        classes_count,
        c,
        nn.nn_get_img_output_length,
        mode='train')
    data_gen_val = label_rcnn_data_generators.get_anchor_gt(
        val_imgs, classes_count, c, nn.nn_get_img_output_length, mode='val')
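    # NOTE: data_gen_val is prepared here but never consumed in this snippet;
    # presumably it was intended for periodic validation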

    input_shape_img = (None, None, 3)
    img_input = Input(shape=input_shape_img)
    # roi_input = Input(shape=(None, 4))

    # define the base network (a small 3-layer CNN here, via nn_cnn_3_layer;
    # could be VGG, ResNet, etc.)
    shared_layers = nn.nn_base(img_input, trainable=True)

    # define the RPN, built on the base layers
    num_anchors = len(c.anchor_box_scales) * len(c.anchor_box_ratios)
    rpn = nn.rpn(shared_layers, num_anchors)

    # classifier = nn.classifier(shared_layers, roi_input, c.num_rois, nb_classes=len(classes_count), trainable=True)

    model_rpn = Model(img_input, rpn[:2])
    # model_classifier = Model([img_input, roi_input], classifier)

    # this is a model that holds both the RPN and the classifier, used to load/save weights for the models
    # model_all = Model([img_input, roi_input], rpn[:2] + classifier)

    print('loading weights from {}'.format(c.base_net_weights))
    model_rpn.load_weights(c.base_net_weights, by_name=True)
    # model_classifier.load_weights(c.base_net_weights, by_name=True)
    model_rpn.summary()

    optimizer = Adam(lr=1e-5)
    optimizer_classifier = Adam(lr=1e-5)  # unused here; kept for the commented-out classifier
    model_rpn.compile(
        optimizer=optimizer,
        loss=[nn.rpn_loss_cls(num_anchors),
              nn.rpn_loss_regr(num_anchors)])
    # model_classifier.compile(optimizer=optimizer_classifier,
    #                          loss=[nn.class_loss_cls, nn.class_loss_regr(len(classes_count) - 1)],
    #                          metrics={'dense_class_{}'.format(len(classes_count)): 'accuracy'})
    # model_all.compile(optimizer='sgd', loss='mae')

    epoch_length = 500
    num_epochs = int(options.num_epochs)
    iter_num = 0

    losses = np.zeros((epoch_length, 5))
    rpn_accuracy_rpn_monitor = []
    rpn_accuracy_for_epoch = []
    start_time = time.time()

    best_loss = np.Inf

    class_mapping_inv = {v: k for k, v in class_mapping.items()}
    print('Starting training')
    vis = True

    for epoch_num in range(num_epochs):
        progbar = generic_utils.Progbar(epoch_length)
        print('Epoch {}/{}'.format(epoch_num + 1, num_epochs))

        while True:
            try:
                if len(rpn_accuracy_rpn_monitor) == epoch_length and c.verbose:
                    mean_overlapping_bboxes = float(
                        sum(rpn_accuracy_rpn_monitor)) / len(
                            rpn_accuracy_rpn_monitor)
                    rpn_accuracy_rpn_monitor = []
                    print(
                        'Average number of overlapping bounding boxes from RPN = {} for {} previous iterations'
                        .format(mean_overlapping_bboxes, epoch_length))
                    if mean_overlapping_bboxes == 0:
                        print(
                            'RPN is not producing bounding boxes that overlap the ground truth boxes. Check RPN settings or keep training.'
                        )

                X, Y, img_data = next(data_gen_train)

                loss_rpn = model_rpn.train_on_batch(X, Y)

                P_rpn = model_rpn.predict_on_batch(X)

                R = label_rcnn_roi_helpers.rpn_to_roi(P_rpn[0],
                                                      P_rpn[1],
                                                      c,
                                                      K.image_dim_ordering(),
                                                      use_regr=True,
                                                      overlap_thresh=0.7,
                                                      max_boxes=300)
                # note: calc_iou converts from (x1,y1,x2,y2) to (x,y,w,h) format
                X2, Y1, Y2, IouS = label_rcnn_roi_helpers.calc_iou(
                    R, img_data, c, class_mapping)

                if X2 is None:
                    rpn_accuracy_rpn_monitor.append(0)
                    rpn_accuracy_for_epoch.append(0)
                    continue

                neg_samples = np.where(Y1[0, :, -1] == 1)
                pos_samples = np.where(Y1[0, :, -1] == 0)

                if len(neg_samples) > 0:
                    neg_samples = neg_samples[0]
                else:
                    neg_samples = []

                if len(pos_samples) > 0:
                    pos_samples = pos_samples[0]
                else:
                    pos_samples = []

                rpn_accuracy_rpn_monitor.append(len(pos_samples))
                rpn_accuracy_for_epoch.append(len(pos_samples))

                if c.num_rois > 1:
                    if len(pos_samples) < c.num_rois // 2:
                        selected_pos_samples = pos_samples.tolist()
                    else:
                        selected_pos_samples = np.random.choice(
                            pos_samples, c.num_rois // 2,
                            replace=False).tolist()
                    try:
                        selected_neg_samples = np.random.choice(
                            neg_samples,
                            c.num_rois - len(selected_pos_samples),
                            replace=False).tolist()
                    except ValueError:
                        # not enough negatives to sample without replacement
                        selected_neg_samples = np.random.choice(
                            neg_samples,
                            c.num_rois - len(selected_pos_samples),
                            replace=True).tolist()

                    sel_samples = selected_pos_samples + selected_neg_samples
                else:
                    # in the extreme case where num_rois = 1, we pick a random pos or neg sample
                    selected_pos_samples = pos_samples.tolist()
                    selected_neg_samples = neg_samples.tolist()
                    if np.random.randint(0, 2):
                        sel_samples = [random.choice(selected_neg_samples)]
                    else:
                        sel_samples = [random.choice(selected_pos_samples)]

                # loss_class = model_classifier.train_on_batch([X, X2[:, sel_samples, :]],
                #                                              [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])

                losses[iter_num, 0] = loss_rpn[1]
                losses[iter_num, 1] = loss_rpn[2]

                # losses[iter_num, 2] = loss_class[1]
                # losses[iter_num, 3] = loss_class[2]
                # losses[iter_num, 4] = loss_class[3]

                iter_num += 1

                progbar.update(iter_num,
                               [('rpn_cls', np.mean(losses[:iter_num, 0])),
                                ('rpn_regr', np.mean(losses[:iter_num, 1]))])
                # progbar.update(iter_num,
                #                [('rpn_cls', np.mean(losses[:iter_num, 0])),
                #                 ('rpn_regr', np.mean(losses[:iter_num, 1])),
                #                 ('detector_cls', np.mean(losses[:iter_num, 2])),
                #                 ('detector_regr', np.mean(losses[:iter_num, 3]))])

                if iter_num == epoch_length:
                    loss_rpn_cls = np.mean(losses[:, 0])
                    loss_rpn_regr = np.mean(losses[:, 1])
                    # loss_class_cls = np.mean(losses[:, 2])
                    # loss_class_regr = np.mean(losses[:, 3])
                    # class_acc = np.mean(losses[:, 4])

                    mean_overlapping_bboxes = float(sum(
                        rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
                    rpn_accuracy_for_epoch = []

                    if c.verbose:
                        print(
                            'Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'
                            .format(mean_overlapping_bboxes))
                        # print('Classifier accuracy for bounding boxes from RPN: {}'.format(class_acc))
                        print('Loss RPN classifier: {}'.format(loss_rpn_cls))
                        print('Loss RPN regression: {}'.format(loss_rpn_regr))
                        # print('Loss Detector classifier: {}'.format(loss_class_cls))
                        # print('Loss Detector regression: {}'.format(loss_class_regr))
                        print('Elapsed time: {}'.format(time.time() -
                                                        start_time))

                    curr_loss = loss_rpn_cls + loss_rpn_regr
                    # curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
                    iter_num = 0
                    start_time = time.time()

                    if curr_loss < best_loss:
                        if c.verbose:
                            print(
                                'Total loss decreased from {} to {}, saving weights'
                                .format(best_loss, curr_loss))
                        best_loss = curr_loss
                        model_rpn.save_weights(c.model_path)
                        # model_all.save_weights(c.model_path)

                    break

            except Exception as e:
                print('Exception: {}'.format(e))
                continue

    print('Training complete, exiting.')
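# Hypothetical entry point: train_rpn() reads its flags from sys.argv via
# OptionParser, so a bare call picks up command-line options such as
# --train_path, --val_path, --num_rois and --num_epochs.
if __name__ == '__main__':
    train_rpn()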