예제 #1
0
    def after_step(self, context):
        """Refresh once more than ``refresh_interval`` seconds have elapsed.

        Nothing happens while the time since ``self.last_refreshed`` has not
        exceeded ``self.refresh_interval``.
        """
        elapsed = time.time() - self.last_refreshed
        if not elapsed > self.refresh_interval:
            return

        get_train_logger().info('refreshing visdom runner.')
        self._refresh(context)
        # The timestamp is taken after the refresh has finished.
        self.last_refreshed = time.time()
예제 #2
0
    def after_step(self, context):
        """Call ``self._flush()`` when the step is a multiple of ``flush_interval``."""
        current_step = context.step
        if current_step % self.flush_interval != 0:
            return

        get_train_logger().info(
            'ScalarLogger flushing data... step=%d' % current_step)
        self._flush()
예제 #3
0
    def _validate(self, context):
        """Evaluate the current net on ``self.dataset`` and record the metric.

        An ``Evaluator`` is built around ``context.trainer.net`` and run; the
        resulting metric is logged, appended to the trainer's recorder (when
        there is one) and published through ``context.add_item``.  With
        ``self.save_best`` enabled, a strictly better metric also saves the
        net's ``state_dict`` under ``context.trainer.result_dir`` and removes
        the checkpoint saved previously by this hook.

        Args:
            context: step context; ``step`` and ``trainer`` are read from it
                and ``add_item`` is called on it.
        """
        logger = get_train_logger()

        evaluator = Evaluator(
            net=context.trainer.net,
            test_dataset=self.dataset,
            batch_size=self.batch_size,
            predict_module=self.predict_module,
            metric=self.metric,
            num_workers=self.num_workers,
            is_validating=True,
        )

        if self.cache_dir is not None:
            record_file = os.path.join(self.cache_dir,
                                       'step%07d.csv' % context.step)
            evaluator.set_record_file(record_file)

        metric_eval = evaluator.run()
        try:
            metric_eval = metric_eval.detach().cpu().numpy()
        except AttributeError:
            # metric_eval may be not torch.Tensor
            pass

        logger.info('step=%d | VAL | %s=%.4f' %
                    (context.step, self.name, metric_eval))
        train_recorder = context.trainer.train_recorder
        if train_recorder is not None:
            train_recorder.append(Row(context.step, self.name, metric_eval))

        # Use an explicit strict comparison in the right direction.  The
        # previous ``larger_is_better == (metric > best)`` form counted ties
        # and NaN as an improvement when smaller is better; a NaN would then
        # be stored as the best metric and every later step would "improve".
        if not self.save_best:
            improved = False
        elif self.larger_is_better:
            improved = metric_eval > self.best_metric
        else:
            improved = metric_eval < self.best_metric

        if improved:
            # save net
            step = context.step
            try:
                net_name = context.trainer.net.module._name
            except AttributeError:
                net_name = 'net'
            filename = '%s-%07d.pth' % (net_name, step)

            cache_dir = context.trainer.result_dir

            path_to_save = os.path.join(cache_dir, filename)
            logger.info('Saving net: %s' % path_to_save)
            torch.save(context.trainer.net.state_dict(), path_to_save)

            # Never delete the file that was just written, and do not abort
            # training because the old checkpoint has already disappeared.
            if self.last_saved is not None and self.last_saved != path_to_save:
                logger.info('Deleting %s' % self.last_saved)
                try:
                    os.remove(self.last_saved)
                except FileNotFoundError:
                    logger.warning('%s was already removed.' %
                                   self.last_saved)
            self.last_saved = path_to_save

            self.best_metric = metric_eval

        context.add_item({self.name: metric_eval})
        # TODO: this might be a problem when using "test" mode in train.
        # e.g. batchnorm with batch_size=1
        context.trainer.net.train()
예제 #4
0
    def before_loop(self, context):
        """Start (on the master) and connect to the local Visdom server.

        In DEBUG/DEV mode the refresh interval is shortened to two seconds.
        The master runs three tmux commands to relaunch ``visdom.server`` in a
        ``visdom_server`` session.  The connection is then retried every two
        seconds until it succeeds or about twenty seconds have passed.

        Raises:
            RuntimeError: if no connection to the Visdom server is established.
        """
        if context.debug_mode in (DebugMode.DEBUG, DebugMode.DEV):
            self.refresh_interval = 2.

        logger = get_train_logger()
        if self.is_master:
            # set up visdom: (shell command, seconds to wait afterwards)
            startup_sequence = (
                ('tmux kill-session -t visdom_server', .1),
                ('tmux new-session -d -s "visdom_server"', .1),
                ('tmux send-keys -t visdom_server ". activate && python -m visdom.server" Enter', 1.),
            )
            for cmd, pause in startup_sequence:
                logger.info(cmd)
                os.system(cmd)
                time.sleep(pause)

        started_at = time.time()
        viz = visdom.Visdom(port=VISDOM_PORT, server="http://localhost")
        connected = viz.check_connection()
        while not connected:
            time.sleep(2.)
            viz.close()
            logger.info('Trying connecting to Visdom server.')
            viz = visdom.Visdom(port=VISDOM_PORT, server="http://localhost")
            connected = viz.check_connection()
            # The deadline is checked only after a fresh attempt was made.
            if time.time() - started_at > 20.:
                break

        if not connected:
            raise RuntimeError('Connecting to Visdom server failed.')
        logger.info('Visdom client connected.')

        self.viz = viz
        self.last_refreshed = time.time()

        # Remember the number of the new figure for later use.
        self.fignumber = plt.figure().number
        logger.info('VisdomRunner: fignumber=%d' % self.fignumber)
예제 #5
0
    def before_loop(self, context):
        """Reset the record file to a header-only CSV and clear the buffers.

        In DEV/DEBUG mode the flush interval is set to 20.  An existing
        record file is deleted (with a warning) before the new one is written.
        """
        if context.debug_mode in (DebugMode.DEV, DebugMode.DEBUG):
            self.flush_interval = 20

        path = self.record_file
        if os.path.exists(path):
            get_train_logger().warning('Deleting %s.' % path)
            os.remove(path)

        # An empty frame is written so the file holds only the column header.
        header_only = pd.DataFrame(columns=['step', 'type', 'value'])
        with open(path, 'w') as f:
            header_only.to_csv(f, header=True)

        self.dfs = []
        self.first_index = 0
예제 #6
0
    def before_loop(self, context):
        """Prepare the hook's state before the training loop starts.

        The interval is set to 10 in DEBUG mode and to 50 in DEV mode.  When
        no metric was given, a ``LossMetric`` built from the trainer's loss
        module class is used.  Batch size and worker count are copied from
        the trainer, the best-metric tracking is reset, and, when a cache
        directory is configured, a sub-directory named after this hook is
        created in it.

        Args:
            context: loop context; ``debug_mode`` and ``trainer`` are read.
        """
        logger = get_train_logger()

        if context.debug_mode == DebugMode.DEBUG:
            self.interval = 10
        elif context.debug_mode == DebugMode.DEV:
            self.interval = 50

        if self.metric is None:
            self.metric = LossMetric(context.trainer.loss_module_class)
        self.larger_is_better = larger_is_better = self.metric.larger_is_better

        self.batch_size = context.trainer.batch_size
        self.num_workers = context.trainer.num_workers

        # Start from the worst possible value for the metric's direction.
        self.best_metric = -np.inf if larger_is_better else np.inf
        self.last_saved = None

        if self.cache_dir is not None:
            self.cache_dir = cache_dir = os.path.join(self.cache_dir,
                                                      self.name)
            logger.info('Creating %s.' % cache_dir)
            # exist_ok: a directory left over from an earlier run must not
            # abort the new run with FileExistsError.
            os.makedirs(cache_dir, exist_ok=True)
예제 #7
0
 def after_loop(self, context):
     """Do one final refresh and record when it finished."""
     get_train_logger().info('refreshing visdom runner.')
     self._refresh(context)
     self.last_refreshed = time.time()
예제 #8
0
import os
import torch
from experiment_interface.logger import get_train_logger
from experiment_interface.common import DebugMode

# Module-level train logger, obtained once when this module is imported.
logger = get_train_logger()


class Hook:
    """Base class with four no-op callbacks that take a ``context``.

    Each default implementation ignores ``context`` and returns ``None``;
    subclasses override the callbacks they need.
    """

    def before_loop(self, context):
        """Do nothing by default."""
        return None

    def before_step(self, context):
        """Do nothing by default."""
        return None

    def after_step(self, context):
        """Do nothing by default."""
        return None

    def after_loop(self, context):
        """Do nothing by default."""
        return None


class StopAtStep(Hook):
    """Hook holding the step number ``stop_at``.

    The value passed to the constructor is replaced in ``before_loop`` when
    the context is in DEBUG mode (35) or DEV mode (10000).
    """

    def __init__(self, stop_at):
        self.stop_at = stop_at

    def before_loop(self, context):
        """Override ``stop_at`` for the DEBUG and DEV debug modes."""
        overrides = (
            (DebugMode.DEBUG, 35),
            (DebugMode.DEV, 10000),
        )
        for mode, limit in overrides:
            if context.debug_mode == mode:
                self.stop_at = limit
                break
예제 #9
0
def test_cifar10():
    """Run the trainer on CIFAR-10 with ``MyCNN`` in DEBUG mode.

    The datasets are given one temporary directory as their cache location
    and the trainer writes into a second one; neither is removed afterwards.
    An accuracy validation hook and two visualisation hooks are registered
    before the run starts.
    """
    net = MyCNN()

    logger = get_train_logger()

    # Training images: random 28px crop plus colour jitter.
    random_crop = transforms.RandomCrop(28)
    jitter = transforms.ColorJitter(
        brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
    train_trnsfrms = transforms.Compose(
        [random_crop, jitter, transforms.ToTensor()])

    # Validation images: centre crop only.
    val_trnsfrms = transforms.Compose(
        [transforms.CenterCrop(28), transforms.ToTensor()])

    cache_dir = tempfile.mkdtemp()
    logger.info('cache_dir: %s' % cache_dir)
    train_dataset = Cifar10TrainDataset(
        cache_dir, transform=train_trnsfrms, download=True)
    val_dataset = Cifar10ValDataset(
        cache_dir, transform=val_trnsfrms, download=False)

    result_dir = tempfile.mkdtemp()
    logger.info('result_dir: %s' % result_dir)

    adam = torch.optim.Adam(net.parameters(), lr=0.003)
    trainer = Trainer(
        net=net,
        train_dataset=train_dataset,
        batch_size=64,
        loss_module_class=torch.nn.CrossEntropyLoss,
        optimizer=adam,
        result_dir=result_dir,
        log_interval=10,
        num_workers=30,
        max_step=30000,
        val_dataset=val_dataset,
        val_interval=200,
    )

    class_acc_metric = ClassificationAccuracy(category_names=CATEGORY_NAMES)
    trainer.register_val_hook(
        ValidationHook(
            dataset=val_dataset,
            interval=200,
            name='val_acc',
            predict_module=MostProbableClass(),
            metric=class_acc_metric,
            save_best=False,
        ))

    # Each visualisation hook is built and registered in turn, both on the
    # 'val' environment.
    for viz_cls in (ValLossAccViz, ConfMatViz):
        trainer.register_viz_hook(viz_cls(env='val'))

    trainer.run(debug_mode=DebugMode.DEBUG)
예제 #10
0
 def after_loop(self, context):
     """Log the current step and call ``self._flush()`` one last time."""
     message = 'ScalarLogger flushing data... step=%d' % context.step
     get_train_logger().info(message)
     self._flush()
예제 #11
0
    def __init__(
        self,
        net,
        train_dataset,
        batch_size,
        loss_module_class,
        optimizer,
        result_dir,
        log_file='train.log',
        log_interval=1,
        train_record_file='train_record.csv',
        max_step=None,
        num_workers=None,
        hooks=None,
        val_dataset=None,
        val_interval=None,
    ):
        """Set up the net, the bookkeeping attributes and the default hooks.

        Args:
            net: module to train; moved to ``cuda:0`` when CUDA is available
                (CPU otherwise) and wrapped in ``torch.nn.DataParallel``.
            train_dataset: stored as ``self.train_dataset``.
            batch_size: stored as ``self.batch_size``.
            loss_module_class: stored as ``self.loss_module_class``.
            optimizer: stored as ``self.optimizer``.
            result_dir: directory onto which the log file name and the train
                record file name are joined.
            log_file: log file name.
            log_interval: stored as ``self.log_interval``.
            train_record_file: record file name; must end in ``.csv``.
            max_step: when not None, a ``StopAtStep(max_step)`` hook is added.
            num_workers: must be given when CUDA is available; a positive
                value is replaced by 0 (with a warning) when it is not.
            hooks: optional iterable of additional hooks.  It is copied, so
                the caller's object is never modified.
            val_dataset: a ``torch.utils.data.Dataset`` (registered under the
                name ``'val_loss'``) or a ``(name, dataset)`` pair.
            val_interval: must be given whenever ``val_dataset`` is given.

        Raises:
            ValueError: if ``num_workers`` is None while CUDA is available,
                if ``val_dataset`` is given without ``val_interval`` or in an
                unsupported format, or if ``train_record_file`` does not end
                in ``.csv``.
        """

        self.logger = logger = get_train_logger(
            os.path.join(result_dir, log_file))

        # cuda settings
        # - use_cuda
        # - num_gpus
        self.use_cuda = use_cuda = torch.cuda.is_available()
        if use_cuda:
            num_gpus = torch.cuda.device_count()
            assert num_gpus > 0
            logger.info('CUDA device count = %d' % num_gpus)
        else:
            num_gpus = 0
        # Previously the attribute was only set on the CPU path.
        self.num_gpus = num_gpus

        # move the net to a gpu, and setup replicas for using multiple gpus;
        # no change needed for cpu mode.
        device = torch.device('cuda:0') if use_cuda else torch.device('cpu')
        net = net.to(device)
        self.net = torch.nn.DataParallel(net)

        self.train_dataset = train_dataset
        self.batch_size = batch_size

        self.loss_module_class = loss_module_class
        self.optimizer = optimizer  # TODO: implement weight update schedule
        self.result_dir = result_dir
        self.log_interval = log_interval

        if use_cuda and num_workers is None:
            raise ValueError(
                '\'num_workers\' must be int, if cuda is available.')

        if not use_cuda and (num_workers is not None and num_workers > 0):
            logger.warning(
                '\'num_workers=%d\' is ignored and set to zero, because use_cuda is False.'
                % num_workers)
            num_workers = 0

        self.num_workers = num_workers

        # Work on a private copy: a shared default list (or the caller's own
        # list) would otherwise collect a StopAtStep hook per instance.
        self.other_hooks = [] if hooks is None else list(hooks)
        self.validation_hooks = []
        self.viz_hooks = []
        self.valhook_tab = dict()

        # Set up default hooks (given valid arguments).
        #   - ValidationHook
        #   - StopAtStepHook

        if val_dataset is not None:

            if val_interval is None:
                raise ValueError(
                    '\'val_interval\' must be not None, if val_dataset is not None.'
                )

            if isinstance(val_dataset, torch.utils.data.Dataset):
                name = 'val_loss'
                dataset = val_dataset
            elif (len(val_dataset) == 2
                  and isinstance(val_dataset[0], str)
                  and isinstance(val_dataset[1], torch.utils.data.Dataset)):
                name, dataset = val_dataset
            else:
                raise ValueError('Invalid format for \'val_dataset\'.')

            val_hook = ValidationHook(dataset,
                                      val_interval,
                                      name,
                                      save_best=True)
            self.register_val_hook(val_hook)

        if max_step is not None:
            self.other_hooks.append(StopAtStep(max_step))

        if not train_record_file.endswith('.csv'):
            raise ValueError('train_record_file must have .csv extension.')

        train_record_file = os.path.join(result_dir, train_record_file)
        train_recorder = ScalarRecorder(train_record_file)
        self.train_record_file = train_record_file
        self.train_recorder = train_recorder

        trainval_loss_viz = TrainValLossViz(is_master=True)
        self.register_viz_hook(trainval_loss_viz)
예제 #12
0
    def __init__(
        self,
        net,
        test_dataset,
        batch_size,
        predict_module,
        metric,
        num_workers,
        result_dir=None,
        record_file=None,
        log_file=None,
        pretrained_params_file=None,
        is_validating=False,
    ):
        """Set up the logger, the net and the evaluation settings.

        Args:
            net: module to evaluate.  With ``is_validating`` it is used as
                given; otherwise it is moved to ``cuda:0`` (CPU when CUDA is
                unavailable) and wrapped in ``torch.nn.DataParallel``.
            test_dataset: stored as ``self.test_dataset``.
            batch_size: stored as ``self.batch_size``.
            predict_module: stored as ``self.predict_module``.
            metric: stored as ``self.metric``.
            num_workers: must be given when CUDA is available; a positive
                value is replaced by 0 (with a warning) when it is not.
            result_dir: directory onto which ``log_file`` and ``record_file``
                are joined; required when either of them is given.
            record_file: optional record file name.
            log_file: optional log file name; ignored when ``is_validating``.
            pretrained_params_file: optional path whose contents are loaded
                with ``torch.load`` and passed to ``self.net.load_state_dict``.
            is_validating: use the train logger and the net as it is.

        Raises:
            ValueError: if ``log_file`` or ``record_file`` is given without
                ``result_dir``, or if ``num_workers`` is None while CUDA is
                available.
        """

        if is_validating:
            logger = get_train_logger()
        else:
            if log_file is not None:
                if result_dir is None:
                    raise ValueError(
                        '\'result_dir\' must be not None, if \'log_file\' is not None. '
                    )
                logger = get_test_logger(os.path.join(result_dir, log_file))
            else:
                logger = get_test_logger(None)
        self.logger = logger

        self.use_cuda = use_cuda = torch.cuda.is_available()
        if use_cuda:
            num_gpus = torch.cuda.device_count()
            assert num_gpus > 0
            logger.info('CUDA device count = %d' % num_gpus)
        else:
            num_gpus = 0
        # Previously the attribute was only set on the CPU path.
        self.num_gpus = num_gpus

        device = torch.device('cuda:0') if use_cuda else torch.device('cpu')
        if is_validating:
            # use input net as it is.
            self.net = net
        else:
            # move the net to a gpu, and setup replicas for using multiple gpus;
            # no change needed for cpu mode.
            net = net.to(device)
            self.net = torch.nn.DataParallel(net)

        if pretrained_params_file is not None:
            # This used to drop into pdb.  NOTE(review): assumes the file
            # holds a state_dict whose keys match ``self.net`` (e.g. one
            # saved from a DataParallel-wrapped net) - confirm with callers.
            logger.info('Loading pretrained params: %s' %
                        pretrained_params_file)
            state_dict = torch.load(pretrained_params_file,
                                    map_location=device)
            self.net.load_state_dict(state_dict)

        self.test_dataset = test_dataset
        self.batch_size = batch_size

        self.predict_module = predict_module
        self.metric = metric

        if use_cuda and num_workers is None:
            raise ValueError(
                '\'num_workers\' must be int, if cuda is available.')

        if not use_cuda and (num_workers is not None and num_workers > 0):
            logger.warning(
                '\'num_workers=%d\' is ignored and set to zero, because use_cuda is False.'
                % num_workers)
            num_workers = 0

        self.num_workers = num_workers

        if record_file is not None:
            if result_dir is None:
                raise ValueError(
                    '\'result_dir\' must be not None, if \'record_file\' is not None. '
                )
            record_file = os.path.join(result_dir, record_file)
        self.record_file = record_file