Beispiel #1
0
    def act(self, state, action):
        """Apply one pruning action and return the resulting transition.

        ``action`` indexes ``state.environment``; the selected entry is used to
        build new pruning masks, the model is validated with those masks and a
        reward is computed from the validation loss and accuracy.

        Returns:
            tuple: ``(new_state, reward, is_done)``. Once the new state's
            ``remaining_weights`` is at or below ``self.remaining_weights``,
            ``self.is_done`` is set to True and 1000 is added to the returned
            reward (the State itself is built with the reward before the bonus).
        """
        # Decode the (numpy float) action index into its environment entry.
        chosen = state.environment[action]
        candidate_masks = self.__new_pruned_mask(state, chosen)

        # Install the candidate masks and evaluate the model with them.
        self.model.set_new_masks(candidate_masks)
        if not self.model.masks[0].weight.is_cuda:
            self.model.masks = self.model.masks.to(self.device)
        loss, acc = validation(self.model, self.valid_loader, self.criterion)

        reward = self.__compute_reward(state, loss, acc)

        # Move the masks back to the CPU. NOTE(review): relies on each mask's
        # .to() acting in place (true for torch Modules; masks[0] has a
        # .weight, so presumably they are Modules) -- confirm.
        for m in self.model.masks:
            m.to('cpu')

        # Successor environment: the same candidates minus the one just used.
        remaining_env = deepcopy(state.environment)
        remaining_env.remove(chosen)

        successor = State(candidate_masks, remaining_env, reward, state.n_prunes)
        successor.prune()

        # The episode is won once enough weights have been pruned.
        if successor.remaining_weights <= self.remaining_weights:
            self.is_done = True
            reward += 1000

        return successor, reward, self.is_done
Beispiel #2
0
class TestState(unittest.TestCase):
    """Tests for Mode: defaults, set/get, type validation, transitions, callbacks.

    Each test gets a Mode backed by a State stored in its own temporary file.
    """

    def setUp(self):
        # mkstemp() creates the file AND returns an open OS-level descriptor;
        # only the path is needed, so close the descriptor instead of leaking
        # one per test.
        fd, self.file_location = tempfile.mkstemp()
        os.close(fd)
        self.state = State(self.file_location)
        self.mode = Mode(self.state)

    def tearDown(self):
        os.remove(self.file_location)

    def test_get_set_mode(self):
        self.mode.set_mode(MODE.NORMAL)
        assert_that(self.mode.get_mode(), equal_to(MODE.NORMAL))

    def test_get_default_mode(self):
        # A freshly created Mode reports NORMAL.
        assert_that(self.mode.get_mode(), equal_to(MODE.NORMAL))

    @raises(AssertionError)
    def test_set_invalid_type(self):
        # set_mode() rejects a plain string instead of a MODE value.
        self.mode.set_mode("NORMAL")

    @raises(ValueError)
    def test_get_invalid_type(self):
        # Mode.set_mode() checks the type of its parameter, but State.set()
        # does not, so the stored value could be corrupted. Mode.get_mode()
        # must then raise ValueError.
        self.state.set(Mode.MODE_KEY, "123")
        self.mode.get_mode()

    @raises(ModeTransitionError)
    def test_invalid_transition(self):
        # The second argument lists the allowed source modes; NORMAL is not
        # among them, and the error must report the mode we started from.
        try:
            self.mode.set_mode(MODE.MAINTENANCE, [MODE.ENTERING_MAINTENANCE])
        except ModeTransitionError as e:
            assert_that(e.from_mode, equal_to(MODE.NORMAL))
            raise

    def test_transition_to_self(self):
        # Changing from a mode to itself (e.g. NORMAL to NORMAL) is allowed
        # even when that mode is not in the list of allowed source modes.
        assert_that(self.mode.get_mode(), equal_to(MODE.NORMAL))
        self.mode.set_mode(MODE.NORMAL, [MODE.MAINTENANCE])
        assert_that(self.mode.get_mode(), equal_to(MODE.NORMAL))

    def test_callbacks(self):
        # Enter/exit callbacks fire only for their own mode; the change
        # callback fires on every transition.
        enter_maintenance = MagicMock()
        exit_normal = MagicMock()
        change = MagicMock()
        self.mode.on_enter_mode(MODE.MAINTENANCE, enter_maintenance)
        self.mode.on_exit_mode(MODE.NORMAL, exit_normal)
        self.mode.on_change(change)

        self.mode.set_mode(MODE.ENTERING_MAINTENANCE)
        assert_that(enter_maintenance.called, equal_to(False))
        assert_that(exit_normal.called, equal_to(True))
        assert_that(change.call_count, equal_to(1))

        self.mode.set_mode(MODE.MAINTENANCE)
        assert_that(enter_maintenance.called, equal_to(True))
        assert_that(change.call_count, equal_to(2))
    def loop(self):
        """
        Run training for the remaining epochs, then test and save once more.

        Every epoch the model is tested first (so the untrained model's error is
        recorded too), then trained and, when ``args.early_stopping`` is set,
        validated. The previous epoch's checkpoint ``<state_file>.<epoch-1>`` is
        removed and ``<state_file>.<epoch>`` is written. Configured statistics
        files are rewritten, and plots are drawn when a display is available.
        After the loop the model is tested again and saved to
        ``args.state_file`` itself. The statistics are then collected in
        ``self.results`` and optionally pickled to ``args.results_file``.
        """

        while self.epoch < self.args.epochs:
            log('[Training] %s' % self.scheduler.report())

            # Test before training to also capture the untrained model's error.
            test_seconds = elapsed(functools.partial(self.test))
            train_seconds = elapsed(functools.partial(self.train))
            log('[Training] %gs training, %gs testing' % (train_seconds, test_seconds))

            if self.args.early_stopping:
                validate_seconds = elapsed(functools.partial(self.validate))
                log('[Training] %gs validation' % validate_seconds)

            # Rotate the per-epoch checkpoint: drop the previous, write the current.
            previous_checkpoint = self.args.state_file + '.%d' % (self.epoch - 1)
            current_checkpoint = self.args.state_file + '.%d' % self.epoch
            utils.remove(previous_checkpoint)
            State.checkpoint(self.model, self.scheduler.optimizer, self.epoch,
                             current_checkpoint)
            log('[Training] %d: checkpoint' % self.epoch)
            torch.cuda.empty_cache()  # NOTE(review): unclear whether this is needed.

            # Rewrite whichever statistics files are configured.
            training_file = self.args.training_file
            if training_file:
                utils.write_hdf5(training_file, self.train_statistics)
                log('[Training] %d: wrote %s' % (self.epoch, training_file))
            testing_file = self.args.testing_file
            if testing_file:
                utils.write_hdf5(testing_file, self.test_statistics)
                log('[Training] %d: wrote %s' % (self.epoch, testing_file))

            if utils.display():
                self.plot()
            self.epoch += 1

        # Final evaluation after the last epoch.
        final_test_seconds = elapsed(functools.partial(self.test))
        log('[Training] %gs testing' % final_test_seconds)

        # Replace the last per-epoch checkpoint with the final one at state_file.
        utils.remove(self.args.state_file + '.%d' % (self.epoch - 1))
        State.checkpoint(self.model, self.scheduler.optimizer, self.epoch,
                         self.args.state_file)
        log('[Training] %d: checkpoint' % self.epoch)

        self.results = dict(
            training_statistics=self.train_statistics,
            testing_statistics=self.test_statistics,
        )
        results_file = self.args.results_file
        if results_file:
            utils.write_pickle(results_file, self.results)
            log('[Training] wrote %s' % results_file)
 def setup(self):
     """Register the services these tests depend on.

     Registers:
     * a 16-worker ThreadPoolExecutor;
     * MagicMock stand-ins for the host handler and the agent config;
     * MODE and DATASTORE_TAGS services backed by one State on a temporary path.

     Both services share a single State object so they read and write through
     the same instance. Two independent State objects on one file could
     presumably overwrite each other's keys.
     """
     self._threadpool = ThreadPoolExecutor(16)
     # NOTE(review): tempfile.mktemp() is deprecated because it is race-prone.
     self.state_file = tempfile.mktemp()
     common.services.register(ThreadPoolExecutor, self._threadpool)
     self._host_handler = MagicMock()
     common.services.register(Host.Iface, self._host_handler)
     state = State(self.state_file)
     common.services.register(ServiceName.MODE, Mode(state))
     common.services.register(ServiceName.AGENT_CONFIG, MagicMock())
     common.services.register(ServiceName.DATASTORE_TAGS, DatastoreTags(state))
Beispiel #5
0
def get_state():
    """Request the current state from the atmega and return it as a State.

    Sends ``c`` and expects a reply line ``s <v1> <v2> ...``. If the reply does
    not start with ``s``, or one of its fields is not a number, the problem is
    logged and ``State(0, 0, 0, 0, 0, 0)`` is returned instead.
    """
    atmega.write("c\n")
    # Split on runs of whitespace: doubled separators would otherwise yield
    # empty fields that float() rejects.
    fields = atmega.readline().strip().split()
    if not fields or fields[0] != "s":
        logger.error("Expected reply starting with 's' from control when asked"
                     " for state, got {}".format(fields[0] if fields else ""))
        return State(*([0] * 6))
    try:
        # Everything after the leading 's' is a numeric field.
        values = [float(field) for field in fields[1:]]
    except ValueError:
        logger.error("Malformed state reply from control: {!r}".format(fields))
        return State(*([0] * 6))
    return State(*values)
    def compute_new_state(self, state, acceleration, delta_time):
        """Advance ``state`` by one step of size ``delta_time`` (explicit Euler).

        The new velocity is ``v + a * dt``. The new position is ``x + v * dt``
        using the velocity from *before* this step; a semi-implicit variant
        would use the updated velocity instead. Returns a new State with the
        same ``ndim``.
        """
        advanced = State(state.ndim)
        v_n = state.velocity

        # v_{n+1} = v_n + a_n * dt
        advanced.velocity = v_n + acceleration * delta_time
        # x_{n+1} = x_n + v_n * dt
        advanced.position = state.position + v_n * delta_time

        return advanced
Beispiel #7
0
    def compute_new_state(self, state, acceleration, delta_time):
        """Return the state one time step later, via forward (explicit) Euler.

        Both updates use quantities at time n: velocity ``v + a*dt`` and
        position ``x + v*dt`` with the old velocity ``v``. The result is a new
        State of the same dimensionality (``state.ndim``).
        """
        stepped = State(state.ndim)
        old_velocity = state.velocity
        stepped.velocity = old_velocity + acceleration * delta_time
        stepped.position = state.position + old_velocity * delta_time
        return stepped
 def setup(self):
     """Register a mocked host handler, a file-backed MODE and a mock agent config."""
     register = common.services.register
     # NOTE(review): tempfile.mktemp() is deprecated (race-prone); it only
     # returns an unused path, no file is created.
     self.state_file = tempfile.mktemp()
     self._host_handler = MagicMock()
     register(Host.Iface, self._host_handler)
     register(ServiceName.MODE, Mode(State(self.state_file)))
     register(ServiceName.AGENT_CONFIG, MagicMock())
Beispiel #9
0
    def load_models(self):
        """
        Build the classifier and restore its weights from ``args.classifier_file``.

        The class count is the largest test label plus one. The input
        resolution is taken from the channels-last ``test_images`` (channels
        from axis 3, spatial size from axes 1 and 2). The model is moved to the
        GPU when ``args.use_gpu`` is set and finally switched to eval mode.
        """

        args = self.args
        images = self.test_images
        self.N_class = numpy.max(self.test_codes) + 1
        units = [int(u) for u in args.network_units.split(',')]
        channels = images.shape[3]
        log('[Testing] using %d input channels' % channels)
        self.model = models.Classifier(
            self.N_class,
            resolution=(channels, images.shape[1], images.shape[2]),
            architecture=args.network_architecture,
            activation=args.network_activation,
            batch_normalization=not args.network_no_batch_normalization,
            start_channels=args.network_channels,
            dropout=args.network_dropout,
            units=units)

        checkpoint_path = args.classifier_file
        assert os.path.exists(checkpoint_path), 'state file %s not found' % checkpoint_path
        state = State.load(checkpoint_path)
        log('[Testing] read %s' % checkpoint_path)

        self.model.load_state_dict(state.model)
        if args.use_gpu and not cuda.is_cuda(self.model):
            log('[Testing] classifier is not CUDA')
            self.model = self.model.cuda()
        log('[Testing] loaded classifier')

        # Inference mode: fixes dropout / batch-norm behaviour for testing.
        self.model.eval()
        log('[Testing] set classifier to eval')
Beispiel #10
0
    def _register_services(self):
        """Create the agent's shared services and register them in common.services.

        Registers:
        * the agent config;
        * a LOCKED_VMS ExclusiveSet;
        * a RequestIdExecutor-wrapped thread pool;
        * the chairman registrant;
        * MODE and DATASTORE_TAGS objects that share one State stored under the
          config path.

        Chairman-list config changes update the registrant. Mode and datastore
        tag changes trigger a chairman update.
        """
        register = common.services.register
        config = self._config

        register(ServiceName.AGENT_CONFIG, config)
        register(ServiceName.LOCKED_VMS, ExclusiveSet())

        pool = RequestIdExecutor(ThreadPoolExecutor(config.workers))
        register(ThreadPoolExecutor, pool)

        registrant = ChairmanRegistrant(config.chairman_list)
        self._registrant = registrant
        config.on_config_change(config.CHAIRMAN, registrant.update_chairman_list)
        register(ServiceName.REGISTRANT, registrant)

        # MODE and DATASTORE_TAGS persist through one shared State.
        shared_state = State(os.path.join(config.options.config_path,
                                          config.DEFAULT_STATE_FILE))

        mode = Mode(shared_state)
        mode.on_change(registrant.trigger_chairman_update)
        register(ServiceName.MODE, mode)

        tags = DatastoreTags(shared_state)
        tags.on_change(registrant.trigger_chairman_update)
        register(ServiceName.DATASTORE_TAGS, tags)
    def __init__(self, id, networks, datastores, cpu, mem, disk, overcommit):
        """Create a simulated host and register its MODE and AGENT_CONFIG services.

        Args:
            id: host identifier. It is used for the HOST constraint value
                ("host-<id>") and passed as ``--host-id`` to AgentConfig.
            networks, datastores: constraint objects added to
                ``self.constraints``. The first entry of each one's ``values``
                is passed to the hypervisor. Both are iterated twice, so they
                must be re-iterable (e.g. lists, not generators).
            cpu, mem, disk: stored on the instance and forwarded to
                ``_get_hypervisor_instance``.
            overcommit: forwarded to ``_get_hypervisor_instance``.
        """
        self.id = id
        self.cpu = cpu
        self.mem = mem
        self.disk = disk
        self.parent = ""
        self.constraints = set()
        host_constraint = ResourceConstraint(ResourceConstraintType.HOST,
                                             ["host-" + str(id)])
        self.constraints.add(host_constraint)
        # set.update() instead of list comprehensions built only for their
        # side effects (they allocated and discarded lists of None).
        self.constraints.update(networks)
        self.constraints.update(datastores)
        self.address = ""
        self.port = ""
        # NOTE(review): mkdtemp here is presumably a project helper; the stdlib
        # tempfile.mkdtemp() has no `delete` argument.
        conf_dir = mkdtemp(delete=True)
        state = State(os.path.join(conf_dir, "state.json"))
        common.services.register(ServiceName.MODE, Mode(state))
        self.hv = self._get_hypervisor_instance(
            id, cpu, mem, disk, [ds.values[0] for ds in datastores],
            [network.values[0] for network in networks], overcommit)

        # need agent_config for create/delete vm.
        agent_config = AgentConfig([
            "--config-path", conf_dir, "--hostname", "localhost", "--port",
            "1234", "--host-id", id
        ])
        common.services.register(ServiceName.AGENT_CONFIG, agent_config)
        super(Host, self).__init__(self.hv)
 def setUp(self):
     """Register MODE and DATASTORE_TAGS over one State in a fresh config dir."""
     conf_dir = mkdtemp(delete=True)
     self.agent_conf_dir = conf_dir
     shared_state = State(os.path.join(conf_dir, "state.json"))
     register = common.services.register
     register(ServiceName.MODE, Mode(shared_state))
     register(ServiceName.DATASTORE_TAGS, DatastoreTags(shared_state))
     self.agent = AgentConfig(["--config-path", conf_dir])
Beispiel #13
0
 def update_state(config, value, state, update_values):
     """Return a new or updated state for ``value``.

     * No current state (falsy ``state``): build a new State from
       ``update_values(value, [])`` with
       ``expired_time=config.property('expired_time')``.
     * Current state expired: return None.
     * Otherwise: return ``state.update(value, update_values)``.
     """
     if state:
         return None if state.is_expired() else state.update(value, update_values)
     return State(update_values(value, []),
                  expired_time=config.property('expired_time'))
Beispiel #14
0
def choose_current_state():
    """Collapse the particle cloud into one estimate in ``modeling.current_state``.

    Steps:
    1. Raise probabilities below 1e-4 to 1e-4 in place, so no particle has a
       (near-)zero weight.
    2. Dump the stacked x / y / probability arrays with ``np.save``.
    3. Set the current state from the probability-weighted mean x and y, the
       particles' depth and yaw, and two zeros.
    """
    weights = Particles.probability
    weights[weights < .0001] = .0001

    # NOTE(review): np.save writes binary .npy and appends ".npy" to names
    # without that suffix, so this creates "data.csv.npy", not a CSV file.
    np.save("data.csv", np.stack((Particles.x, Particles.y, weights), 0))

    est_x = np.average(Particles.x, weights=weights)
    est_y = np.average(Particles.y, weights=weights)
    modeling.current_state = State(est_x, est_y, Particles.depth, Particles.yaw, 0, 0)
    def setUp(self):
        """Register MODE/DATASTORE_TAGS over a shared State and build a MultiAgent."""
        register = common.services.register
        conf_dir = mkdtemp(delete=True)
        self.agent_conf_dir = conf_dir
        shared_state = State(os.path.join(conf_dir, "state.json"))
        register(ServiceName.MODE, Mode(shared_state))
        register(ServiceName.DATASTORE_TAGS, DatastoreTags(shared_state))
        self.multi_agent = MultiAgent(2200, AgentConfig.DEFAULT_CONFIG_PATH,
                                      AgentConfig.DEFAULT_CONFIG_FILE)

        self.agent = AgentConfig(["--config-path", conf_dir])
Beispiel #16
0
def q_table_loader(path):
    """Load a Q-table from a tab-separated file.

    The file's first line is treated as a header (pandas default). For every
    data row, the first column is parsed into a ``State`` key and the remaining
    columns become that state's list of values.
    """
    table = pd.read_csv(path, sep='\t')
    return {
        State(table.iloc[row, 0]): list(table.iloc[row, 1:])
        for row in range(table.shape[0])
    }
Beispiel #17
0
    def _register_services(self):
        """Register the agent config, a LOCKED_VMS set and a file-backed MODE."""
        config = self._config
        register = common.services.register

        register(ServiceName.AGENT_CONFIG, config)
        register(ServiceName.LOCKED_VMS, ExclusiveSet())

        state_path = os.path.join(config.options.config_path,
                                  config.DEFAULT_STATE_FILE)
        register(ServiceName.MODE, Mode(State(state_path)))
Beispiel #18
0
    def __init__(self):
        """Create the shared State, wire the X1 client factory and start X2."""
        self.state = State()

        # Queues handed to the X1 client factory.
        commands = DeferredQueue()
        x1_messages = DeferredQueue()
        self.cmd_queue = commands
        self.x1_queue = x1_messages
        self.x1CliFac = X1ClientFactory(commands, x1_messages, self.state)
        self.x1tcp = None  # presumably the X1 TCP connection; set elsewhere

        self.x2CliFac = X2ServerFactory()
        self._start_x2_server()
    def validate(self):
        """
        Validate for early stopping.

        Computes the per-sample average loss and error over
        ``val_images``/``val_codes`` in batches of ``args.batch_size``. If the
        error improves on ``self.val_error`` (or none is recorded yet), it is
        stored and an early-stopping checkpoint is written to
        ``args.state_file + '.es'``. Does nothing (besides logging) when the
        validation set is empty.
        """

        self.model.eval()
        log('[Training] %d set classifier to eval' % self.epoch)
        assert self.model.training is False

        num_samples = self.val_images.shape[0]
        batch_size = self.args.batch_size
        if num_samples == 0:
            log('[Training] %d: no validation samples, skipping' % self.epoch)
            return

        loss = 0
        error = 0
        # Exact batches; the last one may be smaller. Clipping indices instead
        # would pad it with copies of the final sample and bias the averages.
        for start in range(0, num_samples, batch_size):
            perm = numpy.arange(start, min(start + batch_size, num_samples))
            batch_images = common.torch.as_variable(self.val_images[perm],
                                                    self.args.use_gpu)
            batch_classes = common.torch.as_variable(self.val_codes[perm],
                                                     self.args.use_gpu)
            # Channels-last images -> channels-first for the model.
            batch_images = batch_images.permute(0, 3, 1, 2)

            output_classes = self.model(batch_images)

            # Weight each batch by its size to get a per-sample mean.
            # NOTE(review): assumes self.loss / self.error return batch means.
            e = self.loss(batch_classes, output_classes)
            loss += e.item() * len(perm)
            e = self.error(batch_classes, output_classes)
            error += e.item() * len(perm)

        loss /= num_samples
        error /= num_samples
        log('[Training] %d: val %g (%g)' % (self.epoch, loss, error))

        if self.val_error is None or error < self.val_error:
            self.val_error = error
            State.checkpoint(self.model, self.scheduler.optimizer, self.epoch,
                             self.args.state_file + '.es')
            log('[Training] %d: early stopping checkpoint' % self.epoch)
Beispiel #20
0
 def setUp(self):
     """Build an X1 client protocol on a fake transport driven by a test clock."""
     self.tr = proto_helpers.StringTransportWithDisconnection()
     self.clock = task.Clock()
     factory = X1ClientFactory(DeferredQueue(), DeferredQueue(), State())
     protocol = factory.buildProtocol(('127.0.0.1', 0))
     self.proto = protocol
     # Route the protocol's delayed calls through the controllable clock.
     protocol.callLater = self.clock.callLater
     self.tr.protocol = protocol
     config.pingEnable = False
     self.cmd_data = []
    def test_persistent(self):
        """Tags written through self.tags survive a reload from the same file."""
        expected = {
            "datastore1": {"tag1", "tag2"},
            "datastore2": {"tag3", "tag2"},
        }
        self.tags.set(expected)

        # A fresh State/DatastoreTags pair must read back the same mapping.
        reloaded = DatastoreTags(State(self.file_location)).get()
        assert_that(expected, equal_to(reloaded))
Beispiel #22
0
    def main(self):
        """
        Main which should be overwritten.

        Loads the test images (as float32, channels-last; a channel axis is
        added for gray images) and the labels in column ``args.label_index``.
        It then rebuilds the classifier, restores its weights from
        ``args.state_file``, switches it to eval mode and runs ``self.test()``.
        """

        self.test_images = utils.read_hdf5(self.args.test_images_file).astype(
            numpy.float32)
        log('[Testing] read %s' % self.args.test_images_file)

        # For handling both color and gray images.
        if len(self.test_images.shape) < 4:
            self.test_images = numpy.expand_dims(self.test_images, axis=3)
            log('[Testing] no color images, adjusted size')
        self.resolution = self.test_images.shape[2]
        log('[Testing] resolution %d' % self.resolution)

        # Builtin int: the numpy.int alias was removed in NumPy 1.24.
        self.test_codes = utils.read_hdf5(self.args.test_codes_file).astype(int)
        self.test_codes = self.test_codes[:, self.args.label_index]
        log('[Testing] read %s' % self.args.test_codes_file)

        N_class = numpy.max(self.test_codes) + 1
        network_units = list(map(int, self.args.network_units.split(',')))
        log('[Testing] using %d input channels' % self.test_images.shape[3])
        self.model = models.Classifier(
            N_class,
            resolution=(self.test_images.shape[3], self.test_images.shape[1],
                        self.test_images.shape[2]),
            architecture=self.args.network_architecture,
            activation=self.args.network_activation,
            batch_normalization=not self.args.network_no_batch_normalization,
            start_channels=self.args.network_channels,
            dropout=self.args.network_dropout,
            units=network_units)

        assert os.path.exists(
            self.args.state_file
        ), 'state file %s not found' % self.args.state_file
        state = State.load(self.args.state_file)
        log('[Testing] read %s' % self.args.state_file)

        self.model.load_state_dict(state.model)
        if self.args.use_gpu and not cuda.is_cuda(self.model):
            log('[Testing] model is not CUDA')
            self.model = self.model.cuda()
        log('[Testing] loaded model')

        self.model.eval()
        log('[Testing] set classifier to eval')

        self.test()
Beispiel #23
0
    def _register_services(self):
        """Register the agent config, LOCKED_VMS, the worker pool and a file-backed MODE."""
        config = self._config
        register = common.services.register

        register(ServiceName.AGENT_CONFIG, config)
        register(ServiceName.LOCKED_VMS, ExclusiveSet())

        # Worker pool sized by config.workers, wrapped in a RequestIdExecutor.
        pool = RequestIdExecutor(ThreadPoolExecutor(config.workers))
        register(ThreadPoolExecutor, pool)

        state_path = os.path.join(config.options.config_path,
                                  config.DEFAULT_STATE_FILE)
        register(ServiceName.MODE, Mode(State(state_path)))
    def test_persist_state(self):
        """State values read back after set/overwrite and survive a reload."""
        key = "state"

        state = State(self.file_location)
        # Setting and then overwriting the value must read back each time.
        for expected in ("normal", "maintenance"):
            state.set(key, expected)
            assert_that(state.get(key), equal_to(expected))

        # Unknown keys read back as None.
        assert_that(state.get("non_state"), is_(None))

        # A new State on the same file sees the last persisted value.
        assert_that(State(self.file_location).get(key), equal_to("maintenance"))
    def load_model(self):
        """
        Load model.

        Builds ``self.model`` as a DecoderClassifier made of two parts:
        * an AlternativeOneHotDecoder over the font database from
          ``args.database_file``, read as float32 with shape
          ``(N_font, N_class, H, W)`` and flattened to
          ``(N_font * N_class, H, W)``;
        * a Classifier restored from ``args.classifier_file``, moved to the GPU
          when requested and put into eval mode.
        """

        database = utils.read_hdf5(self.args.database_file).astype(numpy.float32)
        log('[Attack] read %s' % self.args.database_file)

        self.N_font = database.shape[0]
        self.N_class = database.shape[1]
        # NOTE(review): only shape[2] is used as the classifier resolution, so
        # glyphs are presumably square -- confirm.
        resolution = database.shape[2]

        database = database.reshape((database.shape[0] * database.shape[1], database.shape[2], database.shape[3]))
        database = torch.from_numpy(database)
        if self.args.use_gpu:
            database = database.cuda()
        database = torch.autograd.Variable(database, False)

        N_theta = self.test_theta.shape[1]
        log('[Attack] using %d N_theta' % N_theta)
        decoder = models.AlternativeOneHotDecoder(database, self.N_font, self.N_class, N_theta)
        decoder.eval()

        # Up to 7 transformation parameters -> 1 channel, otherwise 3.
        # NOTE(review): presumably the extra parameters encode color -- confirm.
        image_channels = 1 if N_theta <= 7 else 3
        network_units = list(map(int, self.args.network_units.split(',')))
        log('[Attack] using %d input channels' % image_channels)
        classifier = models.Classifier(self.N_class, resolution=(image_channels, resolution, resolution),
                                       architecture=self.args.network_architecture,
                                       activation=self.args.network_activation,
                                       batch_normalization=not self.args.network_no_batch_normalization,
                                       start_channels=self.args.network_channels,
                                       dropout=self.args.network_dropout,
                                       units=network_units)

        assert os.path.exists(self.args.classifier_file), 'state file %s not found' % self.args.classifier_file
        state = State.load(self.args.classifier_file)
        log('[Attack] read %s' % self.args.classifier_file)

        classifier.load_state_dict(state.model)
        if self.args.use_gpu and not cuda.is_cuda(classifier):
            log('[Attack] classifier is not CUDA')
            classifier = classifier.cuda()
        log('[Attack] loaded classifier')

        # Inference mode for the attacked classifier.
        classifier.eval()
        log('[Attack] set classifier to eval')

        self.model = models.DecoderClassifier(decoder, classifier)
Beispiel #26
0
    def run(self):
        """Run the simulation main loop.

        Starts from ``State(1)`` (presumably a 1-dimensional state -- confirm)
        at time 0. While ``time < 5`` (checked before each step) it:
        * advances ``time`` by ``self.delta_time``;
        * gets the acceleration for the current state from ``self.model``;
        * integrates one step with ``self.kinematics.compute_new_state``;
        * prints the time, the acceleration and the first velocity and
          position components.
        """

        state = State(1)
        time = 0

        while time < 5:  # TODO: the stop time is hard-coded; make it configurable.

            time = time + self.delta_time

            # Update state (physics): acceleration from the model, then one
            # integration step of the kinematics.
            acceleration = self.model.compute_acceleration(state)
            state = self.kinematics.compute_new_state(state, acceleration, self.delta_time)

            # Python 2 print statement: space-separated values, one line per step.
            print time, acceleration, state.velocity[0], state.position[0]
    def setUp(self):
        """Prepare a host config, a registrant and the services it relies on."""
        self.hostname, self.host_port = "localhost", 1234
        self.availability_zone_id = "test"
        self.host_addr = ServerAddress(self.hostname, self.host_port)
        self.chairman_list = []
        self.agent_id = "foo"
        datastores = [Datastore("bar")]
        networks = [Network("nw1")]
        self.host_config = HostConfig(self.agent_id, self.availability_zone_id,
                                      datastores, self.host_addr, networks)
        self.registrant = ChairmanRegistrant(self.chairman_list)

        # Host handler stub whose get_host_config_no_logging returns host_config.
        handler = MagicMock()
        handler.get_host_config_no_logging.return_value = GetConfigResponse(
            hostConfig=self.host_config)
        common.services.register(Host.Iface, handler)
        self.request = RegisterHostRequest("foo", self.host_config)

        # NOTE(review): tempfile.mktemp() is deprecated (race-prone).
        self.state_file = tempfile.mktemp()
        common.services.register(ServiceName.MODE, Mode(State(self.state_file)))
    def update_state(config, values, state, update_values):
        """Create or extend a start/stop state from new ``values``.

        ``values`` are sortable items whose second element is a flag
        (``'start'`` or, presumably, ``'stop'``). After sorting, they are
        normalised so flags alternate:
        * a repeated ``'start'`` keeps only the latest one;
        * a repeated non-start flag keeps only the earliest one;
        * a leading non-start item is dropped.

        If there is no state, or it ``is_worked_out()``, a new State is built
        from the normalised values, or None is returned when nothing remains.
        Otherwise the existing state's values are merged with ``values``,
        normalised, and passed to ``state.update``.
        """

        def check_values(unchecked_values):
            # Normalise sorted items into an alternating start/stop sequence.
            cleaned = []
            last_flag = ''
            for item in sorted(unchecked_values):
                flag = item[1]
                if flag != last_flag:
                    cleaned.append(item)
                    last_flag = flag
                elif last_flag == 'start':
                    # Repeated start: the later one replaces the earlier.
                    cleaned[-1] = item
            # The sequence must begin with a start.
            if cleaned and cleaned[0][1] != 'start':
                cleaned.pop(0)
            return cleaned

        from common.state import State

        # Existing, still-active state: merge and update it.
        if state is not None and not state.is_worked_out():
            merged = check_values(state.values() + values)
            return state.update(merged, update_values)

        # Otherwise start a fresh state, if anything is left after checking.
        checked = check_values(values)
        if not checked:
            return None
        return State(update_values(checked, []),
                     expired_time=config.property('expire.delay'))
    def main(self):
        """
        Main which should be overwritten.

        Reads the test images (only their shape is used; a channel axis is
        added for gray images) and builds a LearnedDecoder. It then restores
        the decoder's weights from ``args.decoder_file``, moves it to the GPU
        when requested, logs its estimated size and calls ``self.sample()``.
        """

        test_images = utils.read_hdf5(self.args.test_images_file)
        log('[Sampling] read %s' % self.args.test_images_file)

        if len(test_images.shape) < 4:
            test_images = numpy.expand_dims(test_images, axis=3)

        network_units = list(map(int, self.args.network_units.split(',')))
        self.decoder = models.LearnedDecoder(
            self.args.latent_space_size,
            resolution=(test_images.shape[3], test_images.shape[1],
                        test_images.shape[2]),
            architecture=self.args.network_architecture,
            start_channels=self.args.network_channels,
            activation=self.args.network_activation,
            batch_normalization=not self.args.network_no_batch_normalization,
            units=network_units)
        log(self.decoder)

        # Name the missing file in the failure, like the other loaders do.
        assert os.path.exists(
            self.args.decoder_file
        ), 'state file %s not found' % self.args.decoder_file
        state = State.load(self.args.decoder_file)
        log('[Sampling] loaded %s' % self.args.decoder_file)

        self.decoder.load_state_dict(state.model)
        log('[Sampling] loaded decoder')

        if self.args.use_gpu and not cuda.is_cuda(self.decoder):
            self.decoder = self.decoder.cuda()

        log('[Sampling] model needs %gMiB' %
            ((cuda.estimate_size(self.decoder)) / (1024 * 1024)))
        self.sample()
    def load_model_and_scheduler(self):
        """
        Build the classifier and its ADAM scheduler, resuming if possible.

        If ``args.state_file`` exists, the run is resumed:
        * model weights, optimizer state and epoch are restored, and training
          continues at the stored epoch + 1;
        * the training/testing statistics files are read, and must then exist.

        Otherwise a fresh model and scheduler are created and ``self.epoch``
        stays 0. In both cases the model is moved to the GPU (when
        ``args.use_gpu`` is set) before any optimizer is built.
        """

        args = self.args
        params = dict(
            lr=args.lr,
            lr_decay=args.lr_decay,
            lr_min=0.0000001,
            weight_decay=args.weight_decay,
        )

        images = self.train_images
        log('[Training] using %d input channels' % images.shape[3])
        units = [int(u) for u in args.network_units.split(',')]
        self.model = models.Classifier(
            self.N_class,
            resolution=(images.shape[3], images.shape[1], images.shape[2]),
            architecture=args.network_architecture,
            activation=args.network_activation,
            batch_normalization=not args.network_no_batch_normalization,
            start_channels=args.network_channels,
            dropout=args.network_dropout,
            units=units)

        self.epoch = 0
        if not os.path.exists(args.state_file):
            # Fresh start.
            if args.use_gpu and not cuda.is_cuda(self.model):
                self.model = self.model.cuda()
                log('[Training] model is not CUDA')
            log('[Training] did not load model, using new one')

            self.scheduler = ADAMScheduler(self.model.parameters(), **params)
            self.scheduler.initialize()  # !
        else:
            # Resume from the saved state.
            state = State.load(args.state_file)
            log('[Training] loaded %s' % args.state_file)

            self.model.load_state_dict(state.model)

            # Must happen before the optimizer is constructed, so that it
            # holds the GPU parameters.
            if args.use_gpu and not cuda.is_cuda(self.model):
                self.model = self.model.cuda()
                log('[Training] model is not CUDA')
            log('[Training] loaded model')

            optimizer = torch.optim.Adam(self.model.parameters(), params['lr'])
            optimizer.load_state_dict(state.optimizer)
            self.scheduler = ADAMScheduler(optimizer, **params)

            self.epoch = state.epoch + 1
            self.scheduler.update(self.epoch)

            assert os.path.exists(args.training_file) and os.path.exists(
                args.testing_file)
            self.train_statistics = utils.read_hdf5(args.training_file)
            log('[Training] read %s' % args.training_file)
            self.test_statistics = utils.read_hdf5(args.testing_file)
            log('[Training] read %s' % args.testing_file)

            if utils.display():
                self.plot()

        log(self.model)
Beispiel #31
0
 def setUp(self):
     """Give each test a State bound to a fresh temporary path."""
     # NOTE(review): tempfile.mktemp() is deprecated (race-prone); it only
     # returns a name, no file is created.
     path = tempfile.mktemp()
     self.file_location = path
     self.state = State(path)
Beispiel #32
0
 def setUp(self):
     """Register a DatastoreTags service over a State at a fresh temporary path."""
     # NOTE(review): tempfile.mktemp() is deprecated (race-prone).
     self.file_location = tempfile.mktemp()
     state = State(self.file_location)
     tags = DatastoreTags(state)
     self.state, self.tags = state, tags
     common.services.register(ServiceName.DATASTORE_TAGS, tags)
Beispiel #33
0
 def setUp(self):
     """Create a Mode over a State stored in a fresh temporary file."""
     import os

     # mkstemp() creates the file AND returns an open OS-level descriptor;
     # only the path is needed, so close it instead of leaking one per test.
     fd, self.file_location = tempfile.mkstemp()
     os.close(fd)
     self.state = State(self.file_location)
     self.mode = Mode(self.state)