Ejemplo n.º 1
0
 def evaluate(self, n_viz=9):
     """Run validation over ``self.iter_val`` and return averaged logs.

     Parameters
     ----------
     n_viz: int
         Number of examples to visualize.

     Returns
     -------
     log: dict
         Mean of the per-batch model logs, keys prefixed ``validation/``.
     """
     iter_val = copy.copy(self.iter_val)
     self.model.train = False  # old-style Chainer test-mode flag
     logs = []
     vizs = []
     dataset = iter_val.dataset
     # Guard against a zero interval: if len(dataset) < n_viz the floor
     # division yields 0 and the modulo below would raise
     # ZeroDivisionError.  max(1, ...) keeps behavior identical for
     # normal-sized datasets.
     interval = max(1, len(dataset) // n_viz)
     desc = 'eval [iter=%d]' % self.iteration
     for batch in tqdm.tqdm(iter_val, desc=desc, total=len(dataset),
                            ncols=80, leave=False):
         in_vars = utils.batch_to_vars(
             batch, device=self.device, volatile=True)
         self.model(*in_vars)
         logs.append(self.model.log)
         # Sample roughly evenly spaced batches for visualization.
         if iter_val.current_position % interval == 0 and \
                 len(vizs) < n_viz:
             img = dataset.datum_to_img(self.model.data[0])
             viz = utils.visualize_segmentation(
                 self.model.lbl_pred[0], self.model.lbl_true[0], img,
                 n_class=self.model.n_class)
             vizs.append(viz)
     # save visualization
     out_viz = osp.join(self.out, 'viz_eval', 'iter%d.jpg' % self.iteration)
     if not osp.exists(osp.dirname(out_viz)):
         os.makedirs(osp.dirname(out_viz))
     viz = fcn.utils.get_tile_image(vizs)
     skimage.io.imsave(out_viz, viz)
     # generate log: column-wise mean of all per-batch logs
     log = pd.DataFrame(logs).mean(axis=0).to_dict()
     log = {'validation/%s' % k: v for k, v in log.items()}
     # finalize: restore training mode
     self.model.train = True
     return log
Ejemplo n.º 2
0
    def train(self):
        """Run the training loop; evaluate and snapshot on every new epoch."""
        for iteration, batch in tqdm.tqdm(enumerate(self.iter_train),
                                          desc='train',
                                          total=self.max_iter,
                                          ncols=80):
            self.epoch = self.iter_train.epoch
            self.iteration = iteration

            # ---- evaluate + snapshot (at iteration 0 and each new epoch) --
            if self.iteration == 0 or self.iter_train.is_new_epoch:
                row = collections.defaultdict(str)
                row.update(self.evaluate())
                row['epoch'] = self.iter_train.epoch
                row['iteration'] = iteration
                csv_path = osp.join(self.out, 'log.csv')
                with open(csv_path, 'a') as f:
                    f.write(','.join(str(row[h])
                                     for h in self.log_headers))
                    f.write('\n')
                model_dir = osp.join(self.out, 'models')
                if not osp.exists(model_dir):
                    os.makedirs(model_dir)
                snapshot = osp.join(
                    model_dir, '%s_epoch%d.h5' %
                    (self.model.__class__.__name__, self.epoch))
                chainer.serializers.save_hdf5(snapshot, self.model)

            # ---- one optimization step ----------------------------------
            in_vars = utils.batch_to_vars(batch,
                                          device=self.device,
                                          volatile=False)
            self.model.zerograds()
            loss = self.model(*in_vars)

            if loss is not None:
                loss.backward()
                self.optimizer.update()
                row = collections.defaultdict(str)
                row['epoch'] = self.iter_train.epoch
                row['iteration'] = iteration
                for key, value in self.model.log.items():
                    row['train/%s' % key] = value
                csv_path = osp.join(self.out, 'log.csv')
                with open(csv_path, 'a') as f:
                    f.write(','.join(str(row[h])
                                     for h in self.log_headers))
                    f.write('\n')

            if iteration >= self.max_iter:
                break
Ejemplo n.º 3
0
    def train(self, max_iter, interval_eval):
        """Optimize the model for up to ``max_iter`` iterations.

        Every ``interval_eval`` iterations the model is evaluated and a
        snapshot is written under ``<out>/models``.
        """
        for iteration, batch in tqdm.tqdm(enumerate(self.iter_train),
                                          desc='train', total=max_iter,
                                          ncols=80):
            self.epoch = self.iter_train.epoch
            self.iteration = iteration

            # ---- periodic evaluation + model snapshot -------------------
            log_val = {}
            if iteration % interval_eval == 0:
                log_val = self.evaluate()
                model_dir = osp.join(self.out, 'models')
                if not osp.exists(model_dir):
                    os.makedirs(model_dir)
                snapshot = osp.join(
                    model_dir, '%s_iter%d.h5' %
                    (self.model.__class__.__name__, self.iteration))
                chainer.serializers.save_hdf5(snapshot, self.model)

            # ---- one optimization step ----------------------------------
            in_vars = utils.batch_to_vars(
                batch, device=self.device, volatile=False)
            self.model.zerograds()
            loss = self.model(*in_vars)

            if loss is not None:
                loss.backward()
                self.optimizer.update()
                # Annotate the model's own log dict and persist it as JSON.
                log = self.model.log
                log['epoch'] = self.iter_train.epoch
                log['iteration'] = iteration
                log.update(log_val)
                utils.append_log_to_json(self.model.log,
                                         osp.join(self.out, 'log.json'))

            if iteration >= max_iter:
                break
Ejemplo n.º 4
0
    def validate(self, n_viz=9):
        """Evaluate the current model on the validation dataset.

        Parameters
        ----------
        n_viz: int
            Number of examples to visualize.

        Returns
        -------
        log: dict
            Validation metrics (loss, acc, acc_cls, mean_iu, fwavacc).
        """
        iter_valid = copy.copy(self.iter_valid)
        dataset = iter_valid.dataset
        losses = []
        lbl_trues = []
        lbl_preds = []
        vizs = []
        desc = 'valid [iteration=%08d]' % self.iteration
        for batch in tqdm.tqdm(iter_valid,
                               desc=desc,
                               total=len(dataset),
                               ncols=80,
                               leave=False):
            # Forward pass without gradients, in test configuration.
            with chainer.no_backprop_mode(), \
                 chainer.using_config('train', False):
                in_vars = utils.batch_to_vars(batch, device=self.device)
                loss = self.model(*in_vars)
            losses.append(float(loss.data))
            score = self.model.score
            img, lbl_true = zip(*batch)
            lbl_pred = chainer.functions.argmax(score, axis=1)
            lbl_pred = chainer.cuda.to_cpu(lbl_pred.data)
            for im, lt, lp in zip(img, lbl_true, lbl_pred):
                lbl_trues.append(lt)
                lbl_preds.append(lp)
                if len(vizs) >= n_viz:
                    continue
                im, lt = dataset.untransform(im, lt)
                vizs.append(utils.visualize_segmentation(
                    lbl_pred=lp,
                    lbl_true=lt,
                    img=im,
                    n_class=self.model.n_class))
        # Save a tiled image of the first n_viz visualizations.
        out_viz = osp.join(self.out, 'visualizations_valid',
                           'iter%08d.jpg' % self.iteration)
        if not osp.exists(osp.dirname(out_viz)):
            os.makedirs(osp.dirname(out_viz))
        skimage.io.imsave(out_viz, fcn.utils.get_tile_image(vizs))
        # Aggregate metrics over the whole validation set.
        metrics = utils.label_accuracy_score(lbl_trues, lbl_preds,
                                             self.model.n_class)
        return {
            'valid/loss': np.mean(losses),
            'valid/acc': metrics[0],
            'valid/acc_cls': metrics[1],
            'valid/mean_iu': metrics[2],
            'valid/fwavacc': metrics[3],
        }
Ejemplo n.º 5
0
    def train(self):
        """Train the network using the training dataset.

        Runs until ``self.max_iter`` iterations or ``self.max_elapsed_time``
        seconds, whichever comes first.  Every ``self.interval_validate``
        iterations the model is validated and an npz snapshot is written to
        ``<out>/models``.  Train/valid metrics are appended to
        ``<out>/log.csv``.

        Returns
        -------
        None
        """
        stamp_start = time.time()
        for iteration, batch in tqdm.tqdm(enumerate(self.iter_train),
                                          desc='train',
                                          total=self.max_iter,
                                          ncols=80):
            self.epoch = self.iter_train.epoch
            self.iteration = iteration

            ############
            # validate #
            ############

            if self.interval_validate and \
                    self.iteration % self.interval_validate == 0:
                log = collections.defaultdict(str)
                log_valid = self.validate()
                log.update(log_valid)
                log['epoch'] = self.iter_train.epoch
                log['iteration'] = iteration
                log['elapsed_time'] = time.time() - stamp_start
                with open(osp.join(self.out, 'log.csv'), 'a') as f:
                    f.write(','.join(str(log[h])
                                     for h in self.log_headers) + '\n')
                out_model_dir = osp.join(self.out, 'models')
                if not osp.exists(out_model_dir):
                    os.makedirs(out_model_dir)
                out_model = osp.join(
                    out_model_dir, '%s_iter%08d.npz' %
                    (self.model.__class__.__name__, self.iteration))
                chainer.serializers.save_npz(out_model, self.model)

            #########
            # train #
            #########

            in_vars = utils.batch_to_vars(batch, device=self.device)
            self.model.zerograds()
            loss = self.model(*in_vars)

            if loss is not None:
                # Compute training metrics only for batches that produced a
                # loss; computing them unconditionally wasted work (and read
                # a possibly stale self.model.score) when loss was None.
                score = self.model.score
                # BUGFIX: zip() returns a non-subscriptable iterator on
                # Python 3, so zip(*batch)[1] raised TypeError; wrap in
                # list() (identical result on Python 2).
                lbl_true = list(zip(*batch))[1]
                lbl_pred = chainer.functions.argmax(score, axis=1)
                lbl_pred = chainer.cuda.to_cpu(lbl_pred.data)
                acc = utils.label_accuracy_score(lbl_true, lbl_pred,
                                                 self.model.n_class)

                loss.backward()
                self.optimizer.update()
                log = collections.defaultdict(str)
                log_train = {
                    'train/loss': float(loss.data),
                    'train/acc': acc[0],
                    'train/acc_cls': acc[1],
                    'train/mean_iu': acc[2],
                    'train/fwavacc': acc[3],
                }
                log['epoch'] = self.iter_train.epoch
                log['iteration'] = iteration
                log['elapsed_time'] = time.time() - stamp_start
                log.update(log_train)
                with open(osp.join(self.out, 'log.csv'), 'a') as f:
                    f.write(','.join(str(log[h])
                                     for h in self.log_headers) + '\n')

            if iteration >= self.max_iter:
                break
            if (time.time() - stamp_start) > self.max_elapsed_time:
                break
Ejemplo n.º 6
0
    def train(self, max_iter, interval_eval):
        """Data-parallel training across ``len(self.device)`` model replicas.

        ``self.model`` and ``self.device`` are parallel sequences and each
        ``batch`` is a sequence of per-device sub-batches.  Gradients are
        accumulated onto model 0, the optimizer updates it, and its
        parameters are copied back to the replicas.  Every
        ``interval_eval`` iterations (except iteration 0) the model is
        evaluated and model 0 is snapshot to ``<out>/models``.
        """
        for iteration, batch in tqdm.tqdm(enumerate(self.iter_train),
                                          desc='train',
                                          total=max_iter,
                                          ncols=80):
            self.epoch = self.iter_train.epoch
            self.iteration = iteration

            ############
            # evaluate #
            ############

            log_val = {}
            if iteration % interval_eval == 0 and iteration != 0:
                log_val = self.evaluate()
                out_model_dir = osp.join(self.out, 'models')
                if not osp.exists(out_model_dir):
                    os.makedirs(out_model_dir)
                out_model = osp.join(
                    out_model_dir, '%s_iter%d.h5' %
                    (self.model[0].__class__.__name__, self.iteration))
                chainer.serializers.save_hdf5(out_model, self.model[0])

            #########
            # train #
            #########

            in_vars = []
            for i in xrange(len(self.device)):
                # BUGFIX: each sub-batch must go to its own device; the
                # previous code passed the whole device list as `device=`.
                in_vars.append(
                    utils.batch_to_vars(batch[i],
                                        device=self.device[i],
                                        volatile=False))

            for i in xrange(len(self.device)):
                self.model[i].zerograds()

            loss = []
            for i in xrange(len(self.device)):
                # BUGFIX: feed replica i its own variables; `*in_vars`
                # unpacked the list of *all* per-device var-lists as the
                # call arguments.
                loss.append(self.model[i](*in_vars[i]))

            for i in xrange(len(self.device)):
                if loss[i] is not None:
                    loss[i].backward()

            # Accumulate replica gradients onto model 0.
            for i in xrange(1, len(self.device)):
                if loss[i] is not None:
                    self.model[0].addgrads(self.model[i])

            self.optimizer.update()

            # Broadcast the updated parameters back to the replicas.
            for i in xrange(1, len(self.device)):
                self.model[i].copyparams(self.model[0])

            log = self.model[0].log
            log['epoch'] = self.iter_train.epoch
            log['iteration'] = iteration
            log.update(log_val)
            utils.append_log_to_json(self.model[0].log,
                                     osp.join(self.out, 'log.json'))

            if iteration >= max_iter:
                break