Ejemplo n.º 1
0
    def setUp(self):
        """Create the fixtures shared by the Container tests.

        ``C`` is an empty root container, ``metric_a`` and ``config`` are
        standalone objects, ``CC`` holds an Average metric under ``b``, and
        ``CCC`` holds a Timer under ``c`` and a Maximum under ``d``.
        """
        metrics = mlogger.metric

        self.C = mlogger.Container()
        self.metric_a = metrics.Simple()
        self.config = mlogger.Config(entry="running_a_test")

        # Pre-populated nested containers attached by individual tests.
        self.CC = mlogger.Container(b=metrics.Average())
        self.CCC = mlogger.Container(
            c=metrics.Timer(),
            d=metrics.Maximum(),
        )
Ejemplo n.º 2
0
def setup_xp(args, model, optimizer):
    """Build the mlogger experiment container for a training run.

    Args:
        args: parsed argument namespace. This function reads ``xp_name``,
            ``visdom``, ``server``, ``port``, ``dataset`` and ``log``; every
            attribute of ``args`` is also recorded in ``xp.config``.
        model: model handed to ``save_state`` by the checkpointing hooks.
        optimizer: optimizer handed to ``save_state`` by the checkpointing
            hooks.

    Returns:
        An ``mlogger.Container`` with ``config``, ``epoch``, ``train``,
        ``val``, ``max_val`` and ``test`` entries. Top-5 accuracy metrics
        (``acc5``) are added to train/val/test when
        ``args.dataset == "imagenet"``.
    """
    # The Visdom environment is the last path component of the experiment
    # name. Strip trailing slashes first: "runs/cifar/".split('/')[-1] would
    # otherwise be the empty string.
    env_name = args.xp_name.rstrip('/').split('/')[-1]
    if args.visdom:
        plotter = mlogger.VisdomPlotter({'env': env_name, 'server': args.server, 'port': args.port})
    else:
        plotter = None

    xp = mlogger.Container()

    xp.config = mlogger.Config(plotter=plotter, **vars(args))

    xp.epoch = mlogger.metric.Simple()

    xp.train = mlogger.Container()
    xp.train.acc = mlogger.metric.Average(plotter=plotter, plot_title="Accuracy", plot_legend="training")
    xp.train.loss = mlogger.metric.Average(plotter=plotter, plot_title="Objective", plot_legend="loss")
    xp.train.obj = mlogger.metric.Simple(plotter=plotter, plot_title="Objective", plot_legend="objective")
    xp.train.reg = mlogger.metric.Simple(plotter=plotter, plot_title="Objective", plot_legend="regularization")
    xp.train.weight_norm = mlogger.metric.Simple(plotter=plotter, plot_title="Weight-Norm")
    xp.train.step_size = mlogger.metric.Average(plotter=plotter, plot_title="Step-Size", plot_legend="clipped")
    xp.train.step_size_u = mlogger.metric.Average(plotter=plotter, plot_title="Step-Size", plot_legend="unclipped")
    xp.train.timer = mlogger.metric.Timer(plotter=plotter, plot_title="Time", plot_legend='training')

    xp.val = mlogger.Container()
    xp.val.acc = mlogger.metric.Average(plotter=plotter, plot_title="Accuracy", plot_legend="validation")
    xp.val.timer = mlogger.metric.Timer(plotter=plotter, plot_title="Time", plot_legend='validation')

    xp.max_val = mlogger.metric.Maximum(plotter=plotter, plot_title="Accuracy", plot_legend='best-validation')

    xp.test = mlogger.Container()
    xp.test.acc = mlogger.metric.Average(plotter=plotter, plot_title="Accuracy", plot_legend="test")
    xp.test.timer = mlogger.metric.Timer(plotter=plotter, plot_title="Time", plot_legend='test')

    if args.dataset == "imagenet":
        xp.train.acc5 = mlogger.metric.Average(plotter=plotter, plot_title="Accuracy@5", plot_legend="training")
        xp.val.acc5 = mlogger.metric.Average(plotter=plotter, plot_title="Accuracy@5", plot_legend="validation")
        xp.test.acc5 = mlogger.metric.Average(plotter=plotter, plot_title="Accuracy@5", plot_legend="test")

    # Guard on the object actually dereferenced below rather than re-reading
    # the flag that (indirectly) decided whether it exists.
    if plotter is not None:
        plotter.set_win_opts("Step-Size", {'ytype': 'log'})
        plotter.set_win_opts("Objective", {'ytype': 'log'})

    if args.log:
        # log at each epoch
        xp.epoch.hook_on_update(lambda: xp.save_to('{}/results.json'.format(args.xp_name)))
        xp.epoch.hook_on_update(lambda: save_state(model, optimizer, '{}/model.pkl'.format(args.xp_name)))

        # log after final evaluation on test set
        xp.test.acc.hook_on_update(lambda: xp.save_to('{}/results.json'.format(args.xp_name)))
        xp.test.acc.hook_on_update(lambda: save_state(model, optimizer, '{}/model.pkl'.format(args.xp_name)))

        # save model for best validation performance
        xp.max_val.hook_on_new_max(lambda: save_state(model, optimizer, '{}/best_model.pkl'.format(args.xp_name)))

    return xp
Ejemplo n.º 3
0
def get_xp(args, optimizer):
    """Build an mlogger experiment container with error/objective/time metrics.

    Args:
        args: parsed argument namespace. ``visdom``, ``xp_name`` and ``port``
            are read directly; every attribute is recorded in ``xp.config``.
        optimizer: accepted for interface compatibility; not used in this
            function body.

    Returns:
        An ``mlogger.Container`` with ``config``, ``epoch``, ``train``,
        ``val`` and ``test`` entries.
    """
    plotter = None
    if args.visdom:
        visdom_opts = {
            'env': args.xp_name,
            'server': 'http://localhost',
            'port': args.port
        }
        plotter = mlogger.VisdomPlotter(visdom_opts)

    def average(title, legend):
        # Every Average metric here shares the same plotter.
        return mlogger.metric.Average(plotter=plotter,
                                      plot_title=title,
                                      plot_legend=legend)

    def timer(legend):
        # All timers are drawn in the common "Time" window.
        return mlogger.metric.Timer(plotter=plotter,
                                    plot_title="Time",
                                    plot_legend=legend)

    xp = mlogger.Container()
    xp.config = mlogger.Config(plotter=plotter, **vars(args))
    xp.epoch = mlogger.metric.Simple()

    xp.train = mlogger.Container()
    xp.train.error = average("Error", "train")
    xp.train.obj = average("Objective", "objective")
    xp.train.timer = timer('training')

    xp.val = mlogger.Container()
    xp.val.error = average("Error", "validation")
    xp.val.timer = timer('validation')

    xp.test = mlogger.Container()
    xp.test.error = average("Error", "test")
    xp.test.timer = timer('test')

    return xp
Ejemplo n.º 4
0
    def test_load_state_dict(self):
        """Round-trip a nested Container through state_dict/load_state_dict.

        Attaches the fixtures into a nested hierarchy, records some values,
        loads the resulting state into a fresh Container, and checks that the
        reloaded container reproduces the state with new child objects of
        matching types.
        """
        # Hierarchy: C -> {a, conf, CC -> {b, CCC -> {c, d}}}.
        self.C.a = self.metric_a
        self.C.conf = self.config
        self.C.CC = self.CC
        self.C.CC.CCC = self.CCC

        # Record values so the state being round-tripped is not just the
        # freshly-constructed one.
        self.metric_a.update(10)
        self.CC.b.update(12)
        self.CC.b.update(15)
        self.CCC.c.update()
        self.CCC.d.update(12)
        self.CCC.d.update(15)

        state = self.C.state_dict()
        new_C = mlogger.Container()
        new_C.load_state_dict(state)

        self.assertDictEqual(self.C.state_dict(), new_C.state_dict())

        old_children = list(self.C.children())
        new_children = list(new_C.children())
        # zip() stops at the shorter input; without this check, children
        # missing from the reloaded container would go unnoticed.
        self.assertEqual(len(old_children), len(new_children))

        for old, new in zip(old_children, new_children):
            # Reloaded children must be fresh objects, not shared references.
            self.assertIsNot(old, new)
            self.assertIsInstance(new, type(old))
            self.assertDictEqual(old.state_dict(), new.state_dict())
Ejemplo n.º 5
0
# Experiment settings.
use_visdom = True
lr = 0.01
n_epochs = 10

# ----------------------------------------------------------
# Prepare logging
# ----------------------------------------------------------

# Visdom plotting is optional; when disabled every metric below receives
# plotter=None.
if not use_visdom:
    plotter = None
else:
    plotter = mlogger.VisdomPlotter(
        {'env': 'my_experiment', 'server': 'http://localhost', 'port': 8097},
        manual_update=True,
    )

# Root container; the hyper-parameters are recorded in its config entry.
xp = mlogger.Container()
xp.config = mlogger.Config(plotter=plotter)
xp.config.update(lr=lr, n_epochs=n_epochs)

xp.epoch = mlogger.metric.Simple()

# Training metrics.
xp.train = mlogger.Container()
xp.train.acc1 = mlogger.metric.Average(
    plotter=plotter, plot_title="Accuracy@1", plot_legend="training")
xp.train.acck = mlogger.metric.Average(
    plotter=plotter, plot_title="Accuracy@k", plot_legend="training")
xp.train.loss = mlogger.metric.Average(
    plotter=plotter, plot_title="Objective")
xp.train.timer = mlogger.metric.Timer(
    plotter=plotter, plot_title="Time", plot_legend="training")

# Validation metrics.
xp.val = mlogger.Container()
xp.val.acc1 = mlogger.metric.Average(
    plotter=plotter, plot_title="Accuracy@1", plot_legend="validation")
xp.val.acck = mlogger.metric.Average(
    plotter=plotter, plot_title="Accuracy@k", plot_legend="validation")