Code example #1
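The first listing is the training loop of an excitatory/inhibitory (E/I) RNN trainer written against TensorFlow 2's eager API: it restores the latest checkpoint, runs epochs of gradient-tape updates with global-norm clipping, logs losses, accuracies, and weight-matrix images to TensorBoard, and early-stops once validation accuracy stays above a threshold.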
    def train(self):
        # Resume from the latest checkpoint if one exists; restore(None) is a
        # no-op, so this is safe on a fresh run.
        self.ckpt.restore(self.ckpt_manager.latest_checkpoint)
        if self.ckpt_manager.latest_checkpoint:
            print("Restored from {}".format(
                self.ckpt_manager.latest_checkpoint))
        else:
            print("Initializing from scratch.")

        dg = DataGenerator(task_version=self.task_version, action='train')
        # Validate on one-tenth as many batches as we train on each epoch.
        validation_batch_num = self.batch_num // 10

        # Validation accuracy per epoch; drives the early-stopping check below.
        overall_performance = []

        print('Start to train')
        print('#' * 20)

        for epoch_i in range(self.epoch_num):
            print('Epoch ' + str(epoch_i))

            # Train
            #
            train_loss_all = 0
            for batch_i in range(self.batch_num):
                decs, masks, inputs, outputs = next(dg)
                with tf.GradientTape() as tape:
                    logits, _ = self.ei_rnn(inputs, [self.init_state],
                                            training=True)
                    # [batch, time, out] -> [batch, out, time] to match targets.
                    logits = tf.transpose(logits, perm=[0, 2, 1])
                    train_loss = self.loss_fun(outputs, logits, masks)
                    train_loss += sum(self.ei_rnn.losses)  # regularization terms
                train_loss_all += train_loss.numpy()

                # Update the RNN weights and the learned initial state together.
                all_weights = self.ei_rnn.trainable_weights + [self.init_state]
                grads = tape.gradient(train_loss, all_weights)
                grads, _ = tf.clip_by_global_norm(grads, self.grad_clip)
                # Pair gradients with `all_weights`; zipping against
                # `trainable_weights` alone would silently drop the gradient
                # computed for `self.init_state`.
                self.optimizer.apply_gradients(zip(grads, all_weights))

            train_loss_all = train_loss_all / self.batch_num
            print('train loss:', train_loss_all)

            with self.train_summary_writer.as_default():
                tf.summary.scalar('loss', train_loss_all, step=epoch_i)

            # Validation
            #
            validation_loss_all = 0
            validation_acc_all = 0

            for v_batch_i in range(validation_batch_num):
                _, v_masks, v_inputs, v_outputs = dg.get_valid_test_datasets()
                v_logits, _ = self.ei_rnn(v_inputs, [self.init_state])

                acc_v_logits = v_logits.numpy()
                acc_v_outputs = tf.transpose(v_outputs, perm=[0, 2, 1]).numpy()
                validation_acc = self.get_accuracy(acc_v_logits, acc_v_outputs,
                                                   v_masks)
                validation_acc_all += validation_acc

                v_logits = tf.transpose(v_logits, perm=[0, 2, 1])
                validation_loss = self.loss_fun(v_outputs, v_logits, v_masks)
                validation_loss += sum(self.ei_rnn.losses)
                validation_loss_all += validation_loss.numpy()

            validation_loss_all = validation_loss_all / validation_batch_num
            validation_acc_all = validation_acc_all / validation_batch_num
            overall_performance.append(validation_acc_all)

            print('validation loss:', validation_loss_all)
            print('validation acc:', validation_acc_all)

            with self.validation_summary_writer.as_default():
                tf.summary.scalar('loss', validation_loss_all, step=epoch_i)
                tf.summary.scalar('acc', validation_acc_all, step=epoch_i)

                cm_image = plot.plot_confusion_matrix(self.get_w_rec_m())
                tf.summary.image('M_rec', cm_image, step=epoch_i)

                # The first int(UNITS_SIZE * EI_RATIO) columns correspond to
                # the excitatory population; plot only those weights.
                n_exc = int(UNITS_SIZE * EI_RATIO)
                win_image = plot.plot_confusion_matrix(
                    funs.rectify(self.rnn_cell.W_in.numpy()[:, :n_exc]), False)
                tf.summary.image('M_in', win_image, step=epoch_i)

                wout_image = plot.plot_confusion_matrix(
                    self.get_w_out_m()[:, :n_exc], False)
                tf.summary.image('M_out', wout_image, step=epoch_i)

                print('spectral radius:',
                      funs.spectral_radius(self.get_w_rec_m().T))

            # Early stopping: once the mean validation accuracy over the last
            # PERFORMANCE_CHECK_REGION epochs exceeds PERFORMANCE_LEVEL, stop.
            if epoch_i > PERFORMANCE_CHECK_REGION and \
                    np.mean(overall_performance[-PERFORMANCE_CHECK_REGION:]) > PERFORMANCE_LEVEL:
                print('Overall performance level is satisfied, '
                      'training is terminated\n')
                break

            # Save Model
            self.ckpt.step.assign_add(1)
            self.ckpt_manager.save()

            # self.reset_all_weights() # todo: may uncomment
            # print('Remove all weights below ' + str(SGD_p['mini_w_threshold']))
            # print('\n')

        print('Training is done')
        print('#' * 20)
        print('\n')
        # Test
        #
        self.test()
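Two helpers that the listing calls but does not define are `self.loss_fun` and `funs.spectral_radius`. For reference, a minimal sketch of what they could look like follows; the signatures are taken from the call sites, but the error term and mask convention are assumptions, and the project's actual implementations may differ.

import numpy as np
import tensorflow as tf

def masked_mse_loss(targets, logits, masks):
    # Hypothetical stand-in for `self.loss_fun(outputs, logits, masks)`.
    # `targets` and `logits` are [batch, output_dim, time]; `masks` weights
    # the time steps that should contribute to the error.
    masks = tf.cast(masks, logits.dtype)
    squared_error = tf.square(targets - logits) * masks
    # Normalize by the mask mass so the loss scale is batch-size independent.
    return tf.reduce_sum(squared_error) / tf.maximum(tf.reduce_sum(masks), 1.0)

def spectral_radius(m):
    # Largest absolute eigenvalue of the recurrent weight matrix (the usual
    # definition); printed each epoch to monitor the stability of the dynamics.
    return np.max(np.abs(np.linalg.eigvals(m)))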
Code example #2
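The second listing is the companion test routine: it restores the latest checkpoint, evaluates the network on held-out batches, pools psychometric data (stimulus coherence vs. choice percentage), and writes the same loss/accuracy scalars and weight-matrix images to a test summary writer.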
    def test(self, test_batch_num=50):
        print('Start to test')
        print('#' * 20)
        self.ckpt.restore(self.ckpt_manager.latest_checkpoint)
        if self.ckpt_manager.latest_checkpoint:
            print("Restored from {}".format(
                self.ckpt_manager.latest_checkpoint))
        else:
            print("Initializing from scratch.")

        dg = DataGenerator(task_version=self.task_version, action='test')
        # Psychometric data pooled across batches: coherence vs. percent choice.
        psycollection = {'coh': [], 'perc': []}

        # Use the `test_batch_num` parameter rather than a hard-coded count.
        for batch_index in range(test_batch_num):
            descs, test_masks, test_inputs, test_outputs = \
                dg.get_valid_test_datasets()

            test_logits, _ = self.ei_rnn(test_inputs, [self.init_state],
                                         training=False)

            acc_test_logits = test_logits.numpy()
            acc_test_outputs = tf.transpose(test_outputs,
                                            perm=[0, 2, 1]).numpy()
            test_acc = self.get_accuracy(acc_test_logits, acc_test_outputs,
                                         test_masks)

            test_logits = tf.transpose(test_logits, perm=[0, 2, 1])
            test_loss = self.loss_fun(test_outputs, test_logits, test_masks)
            test_loss += sum(self.ei_rnn.losses)

            print('test loss:', test_loss.numpy())
            print('test acc:', test_acc)

            tmp_data = self.get_psychometric_data(descs, test_logits.numpy())

            psycollection['coh'] += tmp_data['coh']
            psycollection['perc'] += tmp_data['perc']

            with self.test_summary_writer.as_default():
                tf.summary.scalar('loss', test_loss, step=batch_index)
                tf.summary.scalar('acc', test_acc, step=batch_index)

                curve_image = plot.plot_dots(psycollection['coh'],
                                             psycollection['perc'])
                tf.summary.image('psycollection',
                                 curve_image,
                                 step=batch_index)

                cm_image = plot.plot_confusion_matrix(self.get_w_rec_m())
                tf.summary.image('M_rec', cm_image, step=batch_index)

                n_exc = int(UNITS_SIZE * EI_RATIO)  # excitatory columns
                win_image = plot.plot_confusion_matrix(
                    funs.rectify(self.rnn_cell.W_in.numpy()[:, :n_exc]), False)
                tf.summary.image('M_in', win_image, step=batch_index)

                wout_image = plot.plot_confusion_matrix(
                    self.get_w_out_m()[:, :n_exc], False)
                tf.summary.image('M_out', wout_image, step=batch_index)
                tf.summary.image('M_out', wout_image, step=batch_index)
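Both methods assume that `self.ckpt`, `self.ckpt_manager`, and `self.optimizer` were built elsewhere in the class. A minimal sketch of that setup using the standard `tf.train` API is below; the optimizer, directory, and `max_to_keep` values are assumed defaults, but the attribute names match the listings (note the `step` variable, which `train()` increments via `self.ckpt.step.assign_add(1)` before each save).

import tensorflow as tf

def build_checkpointing(self, ckpt_dir='./checkpoints'):
    # Hypothetical helper; only the attribute names are taken from the
    # listings above, everything else is an assumption.
    self.optimizer = tf.keras.optimizers.Adam(1e-3)
    self.ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                                    optimizer=self.optimizer,
                                    net=self.ei_rnn)
    self.ckpt_manager = tf.train.CheckpointManager(self.ckpt, ckpt_dir,
                                                   max_to_keep=3)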