# Example #1
# (score: 0)
    def optimize(self, inputs, extra_inputs=None, callback=None):
        """Train with minibatch SGD until convergence or max epochs.

        Runs ``self._train_op`` over minibatches drawn from ``inputs`` for
        up to ``self._max_epochs`` epochs, stopping early once the
        epoch-to-epoch loss improvement drops below ``self._tolerance``.
        Optionally reports progress and invokes ``self._callback`` /
        ``callback`` with per-epoch statistics.
        """
        if len(inputs) == 0:
            # Assumes that we should always sample mini-batches
            raise NotImplementedError

        f_loss = self._opt_fun["f_loss"]
        if extra_inputs is None:
            extra_inputs = tuple()

        # Full-dataset argument tuple used for loss evaluation each epoch.
        full_args = tuple(inputs) + extra_inputs
        last_loss = f_loss(*full_args)
        start_time = time.time()

        dataset = BatchDataset(
            inputs, self._batch_size, extra_inputs=extra_inputs)
        sess = tf.get_default_session()

        for epoch in range(self._max_epochs):
            if self._verbose:
                logger.log("Epoch %d" % (epoch))
                progbar = pyprind.ProgBar(len(inputs[0]))

            # Collect garbage around the batch loop to keep memory bounded
            # across long training runs.
            gc.collect()
            for minibatch in dataset.iterate(update=True):
                feed = dict(zip(self._input_vars, minibatch))
                sess.run(self._train_op, feed)
                if self._verbose:
                    progbar.update(len(minibatch[0]))
            gc.collect()

            if self._verbose and progbar.active:
                progbar.stop()

            new_loss = f_loss(*full_args)

            if self._verbose:
                logger.log("Epoch: %d | Loss: %f" % (epoch, new_loss))
            if self._callback or callback:
                params = (self._target.get_param_values(trainable=True)
                          if self._target else None)
                callback_args = dict(
                    loss=new_loss,
                    params=params,
                    itr=epoch,
                    elapsed=time.time() - start_time,
                )
                if self._callback:
                    self._callback(callback_args)
                if callback:
                    callback(**callback_args)

            # Early stop once the loss has effectively plateaued.
            if abs(last_loss - new_loss) < self._tolerance:
                break
            last_loss = new_loss
    def optimize(self, inputs, extra_inputs=None, callback=None):
        """Train with minibatch SGD, running a one-time init train op first.

        Identical to the plain minibatch loop except that, if
        ``self._init_train_op`` is set, it is executed on the very first
        minibatch (and then discarded) before ``self._train_op`` takes over.
        Stops early when the per-epoch loss improvement falls below
        ``self._tolerance``.
        """
        if len(inputs) == 0:
            # Assumes that we should always sample mini-batches
            raise NotImplementedError

        f_loss = self._opt_fun["f_loss"]
        if extra_inputs is None:
            extra_inputs = tuple()

        # Full-dataset argument tuple used for loss evaluation each epoch.
        full_args = tuple(inputs) + extra_inputs
        last_loss = f_loss(*full_args)
        start_time = time.time()

        dataset = BatchDataset(
            inputs, self._batch_size, extra_inputs=extra_inputs)
        sess = tf.get_default_session()

        for epoch in range(self._max_epochs):
            if self._verbose:
                logger.log("Epoch %d" % (epoch))
                progbar = pyprind.ProgBar(len(inputs[0]))

            for minibatch in dataset.iterate(update=True):
                feed = dict(zip(self._input_vars, minibatch))
                if self._init_train_op is not None:
                    # Warm-up op runs exactly once, on the first batch ever.
                    sess.run(self._init_train_op, feed)
                    self._init_train_op = None  # only use it once
                else:
                    sess.run(self._train_op, feed)

                if self._verbose:
                    progbar.update(len(minibatch[0]))

            if self._verbose and progbar.active:
                progbar.stop()

            new_loss = f_loss(*full_args)

            if self._verbose:
                logger.log("Epoch: %d | Loss: %f" % (epoch, new_loss))
            if self._callback or callback:
                params = (self._target.get_param_values(trainable=True)
                          if self._target else None)
                callback_args = dict(
                    loss=new_loss,
                    params=params,
                    itr=epoch,
                    elapsed=time.time() - start_time,
                )
                if self._callback:
                    self._callback(callback_args)
                if callback:
                    callback(**callback_args)

            # Early stop once the loss has effectively plateaued.
            if abs(last_loss - new_loss) < self._tolerance:
                break
            last_loss = new_loss
    def optimize_gen(self,
                     inputs,
                     extra_inputs=None,
                     callback=None,
                     yield_itr=None):
        """Minibatch-train while periodically yielding control to the caller.

        Behaves like ``optimize`` but as a generator: when ``yield_itr`` is
        given, the loop yields after roughly every ``yield_itr + 1`` gradient
        steps so the caller can interleave other work with training.
        Training stops early once the epoch loss improves by less than
        ``self._tolerance``.
        """
        if len(inputs) == 0:
            # Assumes that we should always sample mini-batches
            raise NotImplementedError

        f_opt = self._opt_fun["f_opt"]
        f_loss = self._opt_fun["f_loss"]

        if extra_inputs is None:
            extra_inputs = tuple()

        last_loss = f_loss(*(tuple(inputs) + extra_inputs))
        start_time = time.time()

        dataset = BatchDataset(
            inputs, self._batch_size, extra_inputs=extra_inputs)

        step = 0
        for epoch in pyprind.prog_bar(list(range(self._max_epochs))):
            for minibatch in dataset.iterate(update=True):
                f_opt(*minibatch)
                if yield_itr is not None and step % (yield_itr + 1) == 0:
                    yield
                step += 1

            # NOTE: this variant evaluates the loss through self.loss()
            # rather than calling f_loss on the full dataset directly.
            new_loss = self.loss(inputs, extra_inputs)
            if self._verbose:
                logger.log("Epoch %d, loss %s" % (epoch, new_loss))

            if self._callback or callback:
                params = (self._target.get_param_values(trainable=True)
                          if self._target else None)
                callback_args = dict(
                    loss=new_loss,
                    params=params,
                    itr=epoch,
                    elapsed=time.time() - start_time,
                )
                if self._callback:
                    self._callback(callback_args)
                if callback:
                    callback(**callback_args)

            # Early stop once the loss has effectively plateaued.
            if abs(last_loss - new_loss) < self._tolerance:
                break
            last_loss = new_loss
    def optimize_gen(self, inputs, extra_inputs=None, callback=None,
                     yield_itr=None):
        """Generator form of minibatch training with cooperative yielding.

        Runs ``f_opt`` on each minibatch; when ``yield_itr`` is given, yields
        after roughly every ``yield_itr + 1`` gradient steps so the caller
        can interleave work with training.  Stops early when the per-epoch
        loss improvement falls below ``self._tolerance``.
        """
        if len(inputs) == 0:
            # Assumes that we should always sample mini-batches
            raise NotImplementedError

        f_opt = self._opt_fun["f_opt"]
        f_loss = self._opt_fun["f_loss"]

        if extra_inputs is None:
            extra_inputs = tuple()

        # Full-dataset argument tuple used for loss evaluation each epoch.
        full_args = tuple(inputs) + extra_inputs
        last_loss = f_loss(*full_args)
        start_time = time.time()

        dataset = BatchDataset(
            inputs, self._batch_size, extra_inputs=extra_inputs)

        step = 0
        for epoch in pyprind.prog_bar(list(range(self._max_epochs))):
            for minibatch in dataset.iterate(update=True):
                f_opt(*minibatch)
                if yield_itr is not None and step % (yield_itr + 1) == 0:
                    yield
                step += 1

            new_loss = f_loss(*full_args)
            if self._verbose:
                logger.log("Epoch %d, loss %s" % (epoch, new_loss))

            if self._callback or callback:
                params = (self._target.get_param_values(trainable=True)
                          if self._target else None)
                callback_args = dict(
                    loss=new_loss,
                    params=params,
                    itr=epoch,
                    elapsed=time.time() - start_time,
                )
                if self._callback:
                    self._callback(callback_args)
                if callback:
                    callback(**callback_args)

            # Early stop once the loss has effectively plateaued.
            if abs(last_loss - new_loss) < self._tolerance:
                break
            last_loss = new_loss
    def optimize(self, inputs, extra_inputs=None, callback=None):
        """Run minibatch optimization until convergence or max epochs.

        Args:
            inputs: Sequence of input arrays to sample minibatches from;
                must be non-empty.
            extra_inputs: Optional extra arguments appended to every loss
                evaluation (not minibatched). Defaults to an empty tuple.
            callback: Optional callable invoked after each epoch with
                keyword args ``loss``, ``params``, ``itr`` and ``elapsed``.

        Raises:
            NotImplementedError: If ``inputs`` is empty (minibatch sampling
                is assumed to always be possible).
        """
        if len(inputs) == 0:
            # Assumes that we should always sample mini-batches
            raise NotImplementedError

        f_opt = self._opt_fun["f_opt"]
        f_loss = self._opt_fun["f_loss"]

        if extra_inputs is None:
            extra_inputs = tuple()

        last_loss = f_loss(*(tuple(inputs) + extra_inputs))

        start_time = time.time()

        dataset = BatchDataset(inputs,
                               self._batch_size,
                               extra_inputs=extra_inputs)

        # Fix: `xrange` is Python 2-only and raises NameError on Python 3;
        # the sibling implementations in this file already use `range`.
        for epoch in range(self._max_epochs):
            if self._verbose:
                logger.log("Epoch %d" % epoch)
            for batch in dataset.iterate(update=True):
                f_opt(*batch)

            new_loss = f_loss(*(tuple(inputs) + extra_inputs))

            if self._callback or callback:
                elapsed = time.time() - start_time
                callback_args = dict(
                    loss=new_loss,
                    params=self._target.get_param_values(
                        trainable=True) if self._target else None,
                    itr=epoch,
                    elapsed=elapsed,
                )
                if self._callback:
                    self._callback(callback_args)
                if callback:
                    callback(**callback_args)

            # Early stop once the per-epoch improvement is below tolerance.
            if abs(last_loss - new_loss) < self._tolerance:
                break
            last_loss = new_loss
# Example #6
# (score: 0)
    def optimize(self, inputs, extra_inputs=None, callback=None):
        """Minibatch-optimize with ``f_opt`` until the loss plateaus.

        Iterates up to ``self._max_epochs`` epochs over minibatches of
        ``inputs``, evaluating the full-dataset loss after each epoch and
        stopping early when the improvement drops below
        ``self._tolerance``.  ``self._callback`` / ``callback`` receive
        per-epoch statistics when provided.
        """
        if len(inputs) == 0:
            # Assumes that we should always sample mini-batches
            raise NotImplementedError

        f_opt = self._opt_fun["f_opt"]
        f_loss = self._opt_fun["f_loss"]

        if extra_inputs is None:
            extra_inputs = tuple()

        # Full-dataset argument tuple used for loss evaluation each epoch.
        full_args = tuple(inputs) + extra_inputs
        last_loss = f_loss(*full_args)
        start_time = time.time()

        dataset = BatchDataset(
            inputs, self._batch_size, extra_inputs=extra_inputs)

        for epoch in range(self._max_epochs):
            if self._verbose:
                logger.log("Epoch %d" % epoch)
            for minibatch in dataset.iterate(update=True):
                f_opt(*minibatch)

            new_loss = f_loss(*full_args)

            if self._callback or callback:
                params = (self._target.get_param_values(trainable=True)
                          if self._target else None)
                callback_args = dict(
                    loss=new_loss,
                    params=params,
                    itr=epoch,
                    elapsed=time.time() - start_time,
                )
                if self._callback:
                    self._callback(callback_args)
                if callback:
                    callback(**callback_args)

            # Early stop once the loss has effectively plateaued.
            if abs(last_loss - new_loss) < self._tolerance:
                break
            last_loss = new_loss
    def optimize(self,
                 inputs,
                 extra_inputs=None,
                 callback=None,
                 val_inputs=[None],
                 val_extra_inputs=tuple([None])):

        if len(inputs) == 0:
            # Assumes that we should always sample mini-batches
            raise NotImplementedError

        assert len(inputs) == 1

        dataset_size, _ = inputs[0].shape
        f_loss = self._opt_fun["f_loss"]

        # Plot individual costs from complexity / likelihood terms.
        try:
            use_c_loss = True
            c_loss = self._opt_fun["c_loss"]
        except KeyError:
            use_c_loss = False
        try:
            use_l_loss = True
            l_loss = self._opt_fun["l_loss"]
        except KeyError:
            use_l_loss = False

        if extra_inputs is None:
            extra_inputs = tuple()

        #last_loss = f_loss(*(tuple(inputs) + extra_inputs))

        start_time = time.time()

        train_dataset = BatchDataset(inputs,
                                     self._batch_size,
                                     extra_inputs=extra_inputs)
        if not all([vi is None for vi in val_inputs]):
            val_dataset = BatchDataset(val_inputs,
                                       self._batch_size,
                                       extra_inputs=val_extra_inputs)

        sess = tf.get_default_session()

        for epoch in range(self._max_epochs):
            if self._verbose:
                logger.log("Epoch %d" % (epoch))
                progbar = pyprind.ProgBar(len(inputs[0]))

            train_losses = []
            train_c_losses, train_l_losses = [], []
            # batch is a (matrix X, matrix Y) tuple
            for t, batch in enumerate(train_dataset.iterate(update=True)):
                sess.run(self._train_op,
                         dict(list(zip(self._input_vars, batch))))
                train_losses.append(f_loss(*batch))
                train_c_losses.append(c_loss(*batch))
                train_l_losses.append(l_loss(*batch))

                if self._verbose:
                    progbar.update(len(batch[0]))

            train_loss = np.mean(train_losses)
            train_c_loss = np.mean(train_c_losses)
            train_l_loss = np.mean(train_l_losses)

            val_losses = []
            if not all([vi is None for vi in val_inputs]):
                for t, batch in enumerate(val_dataset.iterate(update=True)):
                    val_losses.append(f_loss(*batch))

                val_loss = np.mean(val_losses)

            for interval, op in self._update_priors_ops:
                if interval != 0 and epoch % interval == 0:
                    sess.run(op)

            if self._verbose:
                if progbar.active:
                    progbar.stop()

            if self._verbose:
                logger.log("Epoch: %d | Loss: %f" % (epoch, train_loss))
            if self._callback or callback:
                elapsed = time.time() - start_time
                callback_args = dict(
                    loss=train_loss,
                    params=self._target.get_param_values(
                        trainable=True) if self._target else None,
                    itr=epoch,
                    elapsed=elapsed,
                )
                if use_c_loss:
                    callback_args['c_loss'] = train_c_loss
                if use_l_loss:
                    callback_args['l_loss'] = train_l_loss

                if val_loss is not None:
                    callback_args.update({'val_loss': val_loss})
                if self._callback:
                    self._callback(callback_args)
                if callback:
                    callback(**callback_args)