예제 #1
0
파일: train.py 프로젝트: wkentaro/apc-od
    def _log_epoch(self, epoch, split, sum_loss, sum_accuracy, n_data):
        """Log and print the mean of each loss id for one epoch.

        :param epoch: 0-based epoch index (used only in the message).
        :param split: 'train' or 'test' (used only in the message).
        :param sum_loss: dict mapping loss_id -> summed loss over the split.
        :param sum_accuracy: summed accuracy over the split; only reported
            when ``self.is_supervised`` is true.
        :param n_data: number of samples in the split (divisor for means).
        """
        for loss_id, sl in sorted(sum_loss.items()):
            mean_loss = sl / n_data
            msg = 'epoch:{:02d}; {} mean loss{}={};'\
                .format(epoch, split, loss_id, mean_loss)
            if self.is_supervised:
                mean_accuracy = sum_accuracy / n_data
                msg += ' accuracy={};'.format(mean_accuracy)
            logging.info(msg)
            print(msg)

    def _save_snapshot(self, epoch, x_batch, y_batch, save_encoded):
        """Save model, optimizers and sample input/output for one epoch.

        :param epoch: 0-based epoch index (embedded in every filename).
        :param x_batch: last test input batch (pickled; also tiled/encoded).
        :param y_batch: last test output batch (pickled as the
            reconstruction when unsupervised).
        :param save_encoded: if True, also dump a tiled image of the
            encoder output for ``x_batch``.
        """
        print('epoch:{:02d}; saving'.format(epoch))
        # save model
        model_path = osp.join(
            self.log_dir,
            '{name}_model_{epoch}.h5'.format(
                name=self.model_name, epoch=epoch))
        serializers.save_hdf5(model_path, self.model)
        # save optimizers (one file per optimizer)
        for i, opt in enumerate(self.optimizers):
            opt_path = osp.join(
                self.log_dir,
                '{name}_optimizer_{epoch}_{i}.h5'.format(
                    name=self.model_name, epoch=epoch, i=i))
            serializers.save_hdf5(opt_path, opt)
        # save x_data
        x_path = osp.join(self.log_dir, 'x_{}.pkl'.format(epoch))
        with open(x_path, 'wb') as f:
            pickle.dump(x_batch, f)  # save x
        if not self.is_supervised:
            # autoencoder mode: y_batch is the reconstruction x_hat;
            # save it and a side-by-side input/output tile image
            x_hat_path = osp.join(self.log_dir,
                                  'x_hat_{}.pkl'.format(epoch))
            with open(x_hat_path, 'wb') as f:
                pickle.dump(y_batch, f)  # save x_hat
            tile_ae_inout(
                x_batch, y_batch,
                osp.join(self.log_dir, 'X_{}.jpg'.format(epoch)))
        if save_encoded:
            # visualize the latent representation of the last test batch
            x = Variable(cuda.to_gpu(x_batch), volatile=True)
            z = self.model.encode(x)
            tile_ae_encoded(
                cuda.to_cpu(z.data),
                osp.join(self.log_dir,
                         'x_encoded_{}.jpg'.format(epoch)))

    def main_loop(self, n_epoch=10, save_interval=None, save_encoded=True):
        """Train and evaluate for ``n_epoch`` epochs, periodically saving.

        Each epoch runs one training pass and one test pass via
        ``self.batch_loop``, logs per-loss-id means (plus accuracy when
        ``self.is_supervised``), and every ``save_interval`` epochs saves
        the model, the optimizers, and sample test batches. Finally a loss
        curve is drawn per optimizer from ``self.log_file``.

        :param n_epoch: number of epochs to run.
        :param save_interval: save every this many epochs; defaults to
            roughly a tenth of ``n_epoch``.
        :param save_encoded: if True, also save a tiled image of the
            encoder output at each snapshot.
        """
        # BUGFIX: for n_epoch < 10 the old default (n_epoch // 10) was 0,
        # which made `epoch % save_interval` below raise ZeroDivisionError.
        save_interval = save_interval or max(1, n_epoch // 10)
        train_data = get_raw(which_set='train')
        test_data = get_raw(which_set='test')
        N_train = len(train_data.filenames)
        N_test = len(test_data.filenames)
        logging.info('converting dataset to x and t data')
        train_x, train_t = self.dataset_to_xt_data(train_data, self.crop_roi)
        test_x, test_t = self.dataset_to_xt_data(test_data, self.crop_roi)
        for epoch in xrange(0, n_epoch):
            # train
            sum_loss, sum_accuracy, _, _ = \
                self.batch_loop(train_x, train_t, train=True)
            self._log_epoch(epoch, 'train', sum_loss, sum_accuracy, N_train)
            # test
            sum_loss, sum_accuracy, x_batch, y_batch = \
                self.batch_loop(test_x, test_t, train=False)
            self._log_epoch(epoch, 'test', sum_loss, sum_accuracy, N_test)
            # save model and input/encoded/decoded
            if epoch % save_interval == (save_interval - 1):
                self._save_snapshot(epoch, x_batch, y_batch, save_encoded)

        for i in xrange(len(self.optimizers)):
            draw_loss_curve(
                loss_id=i,
                logfile=self.log_file,
                outfile=osp.join(self.log_dir, 'loss_curve{}.jpg'.format(i)),
                no_acc=not self.is_supervised,
            )