Esempio n. 1
0
    def start(self, runName, model, recorder, *args, **kwargs):
        """Start the training and encapsulate it in a safe environment.

        Calls ``self.run(runName, model, recorder, *args, **kwargs)`` and
        returns its result.

        * If ``self.run`` raises ``MSTOP.EndOfTraining``, the stop reason and
          the store are written to ``finished_<runName>_<time>.stopreason.txt``,
          the model is saved under ``finished_<runName>_<time>``, the store is
          pickled to ``<same name>.store.pkl`` and ``None`` is returned.
        * If training is interrupted (SIGTERM, KeyboardInterrupt or any other
          exception) and ``self.saveIfMurdered`` is true, the model, a
          ``.traceback.log`` report and a ``.store.pkl`` pickle are saved
          under ``dx-xb_<runName>_death_by_<cause>_<time>``. Exceptions are
          then re-raised; SIGTERM exits through ``sys.exit``. These
          interruption files are written at most once per call.

        The SIGTERM handler is installed only for the duration of this call;
        the previously installed handler is restored afterwards.

        :param runName: name of the run, used in messages and file names.
        :param model: object exposing ``save(filename)``.
        :param recorder: recorder passed to ``self.run``, or the string
            ``"default"`` to use an ``MREC.GGPlot2`` recorder built with
            ``printRate=1`` and ``writeRate=1``.
        :returns: the return value of ``self.run``, or ``None`` when training
            ends with ``MSTOP.EndOfTraining``.
        """
        import json
        import os
        import signal
        import sys
        import time
        import traceback

        try:
            import cPickle as pickle
        except ImportError:  # Python 3: the C pickler is the default one.
            import pickle

        # Python 2 has no ``nonlocal``: a one-element list lets the nested
        # helpers record that the interruption files were already written.
        # Without it, the SystemExit raised by the SIGTERM handler would reach
        # the catch-all ``except`` below and save everything a second time.
        alreadySaved = [False]

        def _timestamp():
            # Spaces are replaced so the timestamp can be embedded in file names.
            return time.ctime().replace(' ', '_')

        def _storeAsText():
            """Return ``self.store`` as pretty-printed JSON, or its str() if that fails."""
            jsonOptions = dict(sort_keys=True, indent=4, separators=(',', ': '))
            try:
                return json.dumps(self.store, **jsonOptions)
            except Exception:
                # Not directly serializable: try the legacy repr coercion below.
                pass
            # Legacy fallback for values json cannot serialize directly:
            # coerce the Python repr into JSON text and re-parse it.
            sstore = str(self.store).replace("'", '"').replace(
                "True", 'true').replace("False", 'false')
            try:
                return json.dumps(json.loads(sstore), **jsonOptions)
            except Exception as exc:
                print("Warning: Couldn't format the store to json, saving it ugly.")
                print("Reason: %s" % exc)
                return sstore

        def _pickleStore(filename):
            """Pickle ``self.store`` to ``<filename>.store.pkl``."""
            with open(filename + ".store.pkl", "wb") as pkl:
                pickle.dump(self.store, pkl)

        def _dieGracefully(exType, tb=None):
            """Save the model, an interruption report and the pickled store.

            :param exType: the exception class that interrupted training, or a
                string naming the cause (e.g. ``"SIGTERM"``).
            :param tb: traceback object to append to the report, if any.
            """
            if alreadySaved[0]:
                return
            # Marked before saving so a failing save is not retried by the
            # catch-all handler.
            alreadySaved[0] = True

            if isinstance(exType, str):
                exName = exType
            else:
                exName = exType.__name__

            deathTime = _timestamp()
            filename = "dx-xb_" + runName + "_death_by_" + exName + "_" + deathTime
            sys.stderr.write(
                "\n===\nDying gracefully from %s, and saving myself to:\n...%s\n===\n"
                % (exName, filename))
            model.save(filename)
            with open(filename + ".traceback.log", 'w') as log:
                log.write(
                    "Mariana training Interruption\n=============================\n"
                )
                log.write("\nDetails\n-------\n")
                log.write("Name: %s\n" % runName)
                log.write("pid: %s\n" % os.getpid())
                log.write("Killed by: %s\n" % str(exType))
                log.write("Time of death: %s\n" % deathTime)
                log.write("Model saved to: %s\n" % filename)
                log.write("store:\n%s" % _storeAsText())
                if tb is not None:
                    log.write("\nTraceback\n---------\n")
                    log.write("".join(traceback.format_tb(tb)))
            _pickleStore(filename)

        def _saveEndOfTraining(stop):
            """Write the stop report, the model and the pickled store after a normal stop."""
            print(stop.message)
            deathTime = _timestamp()
            filename = "finished_" + runName + "_" + deathTime
            with open(filename + ".stopreason.txt", 'w') as report:
                report.write("Name: %s\n" % runName)
                report.write("pid: %s\n" % os.getpid())
                report.write("Time of death: %s\n" % deathTime)
                report.write("Epoch of death: %s\n" % self.store["runInfos"]["epoch"])
                report.write("Stopped by: %s\n" % stop.stopCriterion.name)
                report.write("Reason: %s\n" % stop.message)
                report.write("store:\n%s" % _storeAsText())
            model.save(filename)
            _pickleStore(filename)

        def _handler_sig_term(sig, frame):
            # SIGTERM is a "murder" too: honour saveIfMurdered like the
            # exception path does. sys.exit raises SystemExit, which the
            # catch-all below re-raises without saving again.
            if self.saveIfMurdered:
                _dieGracefully("SIGTERM", None)
            sys.exit(sig)

        previousHandler = signal.signal(signal.SIGTERM, _handler_sig_term)
        try:
            if MSET.VERBOSE:
                print("\n" + "Training starts.")
            MCAN.friendly("Process id", "The pid of this run is: %d" % os.getpid())

            if recorder == "default":
                params = {"printRate": 1, "writeRate": 1}
                recorder = MREC.GGPlot2(runName, **params)
                MCAN.friendly(
                    "Default recorder",
                    "The trainer will recruit the default 'GGPlot2' recorder with the following arguments:\n\t %s.\nResults will be saved into '%s'."
                    % (params, recorder.filename))

            try:
                return self.run(runName, model, recorder, *args, **kwargs)
            except MSTOP.EndOfTraining as e:
                _saveEndOfTraining(e)
            except BaseException:
                # KeyboardInterrupt, SystemExit (including the one raised by
                # the SIGTERM handler) and every other error end up here.
                if self.saveIfMurdered:
                    exType, _, tb = sys.exc_info()
                    _dieGracefully(exType, tb)
                raise
        finally:
            # signal.signal() returns None when the previous handler was not
            # installed from Python; it cannot be passed back in that case.
            if previousHandler is not None:
                signal.signal(signal.SIGTERM, previousHandler)
Esempio n. 2
0
    testMaps.mapOutput(o, testData.numbers)

    # Validation dataset: the images/labels pair from ``validation_set`` is
    # wrapped in a Series and mapped onto the input layer ``i`` and the output
    # layer ``o`` (both defined earlier, outside this excerpt).
    validationData = MDM.Series(images=validation_set[0],
                                numbers=validation_set[1])
    validationMaps = MDM.DatasetMapper()
    validationMaps.mapInput(i, validationData.images)
    validationMaps.mapOutput(o, validationData.numbers)

    # Stop criteria. Early stopping monitors the "score" output of layer ``o``
    # on ``testMaps``. NOTE(review): the exact patience semantics live in
    # MSTOP -- presumably a patience of 100 epochs, grown by a factor of 1.1
    # whenever the score improves by more than 1e-5; confirm.
    earlyStop = MSTOP.GeometricEarlyStopping(testMaps,
                                             patience=100,
                                             patienceIncreaseFactor=1.1,
                                             significantImprovement=0.00001,
                                             outputFunction="score",
                                             outputLayer=o)
    # NOTE(review): presumably a hard cap of 1000 epochs -- confirm in MSTOP.
    epochWall = MSTOP.EpochWall(1000)

    # Trainer: mini-batches of 20; test and validation both use the
    # "testAndAccuracy" function. saveIfMurdered=False disables the
    # save-on-exception path of Trainer.start().
    trainer = MT.DefaultTrainer(trainMaps=trainMaps,
                                testMaps=testMaps,
                                validationMaps=validationMaps,
                                stopCriteria=[earlyStop, epochWall],
                                testFunctionName="testAndAccuracy",
                                validationFunctionName="testAndAccuracy",
                                trainMiniBatchSize=20,
                                saveIfMurdered=False)

    # Recorder for the "MLP" run with printRate/writeRate of 1 (presumably
    # every epoch). NOTE(review): SaveMin presumably saves the model whenever
    # the test "score" of layer ``o`` reaches a new minimum; confirm.
    recorder = MREC.GGPlot2("MLP",
                            whenToSave=[MREC.SaveMin("test", o.name, "score")],
                            printRate=1,
                            writeRate=1)
    # Launch training of ``mlp`` under the run name "MLP".
    trainer.start("MLP", mlp, recorder=recorder)