Example #1
    def batches(self):
        # reshuffle, then yield paired source/target half-batches
        self._shuffle()
        for bidx in range(self.n_batches):
            s_idxs = compute_batch_idxs(bidx * self.half_batch_size,
                                        self.half_batch_size, self.n_s_samples)
            t_idxs = compute_batch_idxs(bidx * self.half_batch_size,
                                        self.half_batch_size, self.n_t_samples)
            yield dict(xs=self.xs[s_idxs],
                       ys=self.ys[s_idxs],
                       ws=self.ws[s_idxs],
                       xt=self.xt[t_idxs],
                       yt=self.yt[t_idxs],
                       wt=self.wt[t_idxs])
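None of the examples on this page include the two helpers themselves. The following is a minimal sketch of compute_n_batches and compute_batch_idxs, assuming the wrap-around fill behavior the call sites imply (plus the fill='none' variant used in Example #4); the real implementations may differ:

import numpy as np

def compute_n_batches(n_samples, batch_size):
    # number of batches needed to cover all samples (the last may be partial)
    return int(np.ceil(n_samples / float(batch_size)))

def compute_batch_idxs(start, batch_size, n_samples, fill='wrap'):
    # indices [start, start + batch_size), clipped to the dataset size
    end = start + batch_size
    idxs = list(range(start, min(end, n_samples)))
    if end > n_samples and fill == 'wrap':
        # pad the final partial batch by wrapping around to the front,
        # so every batch fed to the graph has exactly batch_size rows
        idxs += list(range(end - n_samples))
    return np.array(idxs)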
Example #2
    def predict(self, inputs, lengths):
        sess = tf.get_default_session()
        n_samples = len(inputs)
        n_batches = utils.compute_n_batches(n_samples, self.batch_size)
        probs = np.zeros((n_samples, self.max_len, self.output_dim))
        for bidx in range(n_batches):
            idxs = utils.compute_batch_idxs(bidx * self.batch_size,
                                            self.batch_size, n_samples)
            # dropout disabled (keep prob of 1) at inference time
            probs[idxs] = sess.run(self.probs, feed_dict={
                self.inputs: inputs[idxs],
                self.lengths: lengths[idxs],
                self.dropout_keep_prop_ph: 1.
            })

        preds = np.argmax(probs, axis=-1)
        return probs, preds
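Note that with the wrap fill assumed in the sketch above, the final batch recomputes a few leading samples whenever n_samples is not a multiple of self.batch_size; the assignment probs[idxs] = ... simply overwrites those rows with identical values, which is harmless at inference time.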
Example #3
    def generate(self, z=None, c=None, batch_size=100):
        # setup
        sess = tf.get_default_session()

        # sample noise if not provided; if only one of z / c is given,
        # sample the other to match its length
        if z is None and c is None:
            z, c, _ = self._sample_noise(batch_size)
        elif z is None:
            z, _, _ = self._sample_noise(len(c))
        elif c is None:
            _, c, _ = self._sample_noise(len(z))

        # at this point, z dictates number of samples
        n_samples = len(z)
        n_batches = compute_n_batches(n_samples, batch_size)

        # allocate return containers
        gx = np.zeros((n_samples, self.input_dim))
        scores = np.zeros((n_samples, 1))

        # formulate outputs
        outputs = [self.gx, self.gen_scores]

        # run the batches
        for bidx in range(n_batches):
            idxs = compute_batch_idxs(bidx * batch_size, batch_size, n_samples)
            feed = {self.z: z[idxs], self.c: c[idxs]}
            fetched = sess.run(outputs, feed_dict=feed)

            # unpack
            gx[idxs] = fetched[0]
            scores[idxs] = fetched[1]

        # return the relevant info
        return dict(gx=gx, scores=scores)
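A hypothetical call site (the gan object and its construction are assumed); in TF1, a tf.Session used as a context manager becomes the default session that generate() retrieves via tf.get_default_session():

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    samples = gan.generate(batch_size=256)  # z and c are sampled internally
    print(samples['gx'].shape, samples['scores'].shape)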
Example #4
    def predict(self, x, tgt=True, batch_size=100):
        # setup
        sess = tf.get_default_session()
        n_samples = len(x)
        n_batches = compute_n_batches(n_samples, batch_size)
        probs = np.zeros((n_samples, self.output_dim))

        # decide between src and tgt probs
        probs_output = self.tgt_task_probs if tgt else self.src_task_probs
        x_ph = self.xt if tgt else self.xs

        # compute probs; fill='none' leaves the last batch partial
        # instead of wrapping, so no rows are recomputed
        for bidx in range(n_batches):
            idxs = compute_batch_idxs(bidx * batch_size,
                                      batch_size,
                                      n_samples,
                                      fill='none')
            # fetch the tensor directly: running a one-element list would
            # return a list, which cannot be assigned into probs[idxs]
            probs[idxs] = sess.run(probs_output,
                                   feed_dict={
                                       x_ph: x[idxs],
                                       self.dropout_keep_prob_ph: 1.
                                   })

        return probs
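Using the helper sketch after Example #1, the two fill modes differ only in how the final partial batch is handled:

print(compute_batch_idxs(8, 4, 10))               # [8 9 0 1] -- wraps to fill
print(compute_batch_idxs(8, 4, 10, fill='none'))  # [8 9]     -- left partial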
Example #5
    def batches(self):
        # reshuffle, then yield wrap-filled batches of x
        self._shuffle()
        for bidx in range(self.n_batches):
            idxs = compute_batch_idxs(bidx * self.batch_size, self.batch_size,
                                      self.n_samples)
            yield dict(x=self.x[idxs])
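Such a dataset object would typically be consumed like this (the Dataset constructor and train_step are hypothetical):

dataset = Dataset(x)                  # any object exposing batches()
for epoch in range(n_epochs):
    for batch in dataset.batches():   # reshuffles at the start of each pass
        model.train_step(batch['x'])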
Example #6
    def train(self,
              data,
              n_epochs=100,
              batch_size=100,
              writer=None,
              val_writer=None,
              stop_early=False):
        sess = tf.get_default_session()

        if self.batch_size is not None:
            batch_size = self.batch_size

        n_samples = len(data['x_train'])
        n_batches = utils.compute_n_batches(n_samples, batch_size)
        n_val_samples = len(data['x_val'])
        n_val_batches = utils.compute_n_batches(n_val_samples, batch_size)

        # keep the two most recent validation losses for early stopping
        last_val_losses = collections.deque([np.inf] * 2)

        for epoch in range(n_epochs):

            # shuffle train set
            idxs = np.random.permutation(len(data['x_train']))
            data['x_train'] = data['x_train'][idxs]
            data['y_train'] = data['y_train'][idxs]

            # train
            total_loss = 0
            for bidx in range(n_batches):
                idxs = utils.compute_batch_idxs(bidx * batch_size, batch_size,
                                                n_samples)
                feed_dict = {
                    self.inputs: data['x_train'][idxs],
                    self.targets: data['y_train'][idxs],
                }
                outputs_list = [
                    self.loss, self.summary_op, self.global_step, self.train_op
                ]
                loss, summary, step, _ = sess.run(outputs_list,
                                                  feed_dict=feed_dict)
                total_loss += loss
                if writer is not None:
                    writer.add_summary(summary, step)
                sys.stdout.write(
                    '\repoch: {} / {} batch: {} / {} loss: {}'.format(
                        epoch + 1, n_epochs, bidx + 1, n_batches,
                        total_loss / (bidx + 1)))
            print('\n')

            # val
            total_loss = 0
            for bidx in range(n_val_batches):
                idxs = utils.compute_batch_idxs(bidx * batch_size, batch_size,
                                                n_val_samples)
                feed_dict = {
                    self.inputs: data['x_val'][idxs],
                    self.targets: data['y_val'][idxs],
                    self.dropout_keep_prop_ph: 1.
                }
                outputs_list = [self.loss, self.summary_op, self.global_step]
                loss, summary, step = sess.run(outputs_list,
                                               feed_dict=feed_dict)
                total_loss += loss
                if val_writer is not None:
                    val_writer.add_summary(summary, step)
                sys.stdout.write(
                    '\rval epoch: {} / {} batch: {} / {} loss: {}'.format(
                        epoch + 1, n_epochs, bidx + 1, n_val_batches,
                        total_loss / (bidx + 1)))
            print('\n')
            if stop_early:
                if all(total_loss > v for v in last_val_losses):
                    break
                last_val_losses.popleft()
                last_val_losses.append(total_loss)
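The early-stopping rule above halts training once an epoch's validation loss exceeds both of the two preceding epochs' losses. The same window logic in isolation:

import collections
import numpy as np

last = collections.deque([np.inf] * 2)
for loss in [5.0, 4.0, 4.5, 4.6]:
    if all(loss > v for v in last):
        print('stopping at loss', loss)  # fires at 4.6 (> 4.0 and > 4.5)
        break
    last.popleft()
    last.append(loss)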
Example #7
    def train(
            self, 
            data, 
            n_epochs=100, 
            writer=None, 
            val_writer=None):
        sess = tf.get_default_session()
        
        n_samples = len(data['train_x'])
        n_batches = utils.compute_n_batches(n_samples, self.batch_size)
        n_val_samples = len(data['val_x'])
        n_val_batches = utils.compute_n_batches(n_val_samples, self.batch_size)
        
        for epoch in range(n_epochs):

            # shuffle train set
            idxs = np.random.permutation(len(data['train_x']))
            data['train_x'] = data['train_x'][idxs]
            data['train_y'] = data['train_y'][idxs]
            data['train_lengths'] = data['train_lengths'][idxs]

            # train
            total_loss = 0
            for bidx in range(n_batches):
                idxs = utils.compute_batch_idxs(bidx * self.batch_size, self.batch_size, n_samples)
                feed_dict = {
                    self.inputs:data['train_x'][idxs],
                    self.targets:data['train_y'][idxs],
                    self.lengths:data['train_lengths'][idxs]
                }
                outputs_list = [self.loss, self.summary_op, self.global_step, self.train_op]
                loss, summary, step, _ = sess.run(outputs_list, feed_dict=feed_dict)
                total_loss += loss
                if writer is not None:
                    writer.add_summary(summary, step)
                sys.stdout.write('\repoch: {} / {} batch: {} / {} loss: {}'.format(
                    epoch+1, n_epochs, bidx+1, n_batches, 
                    total_loss / (self.batch_size * (bidx+1))))
            self.validate(data['train_x'], data['train_y'], data['train_lengths'], writer, epoch)
            print('\n')

            # val
            total_loss = 0
            for bidx in range(n_val_batches):
                idxs = utils.compute_batch_idxs(bidx * self.batch_size, self.batch_size, n_val_samples)
                feed_dict = {
                    self.inputs:data['val_x'][idxs],
                    self.targets:data['val_y'][idxs],
                    self.lengths:data['val_lengths'][idxs],
                    self.dropout_keep_prop_ph: 1.
                }
                outputs_list = [self.loss, self.summary_op, self.global_step]
                loss, summary, step = sess.run(outputs_list, feed_dict=feed_dict)
                total_loss += loss
                if val_writer is not None:
                    val_writer.add_summary(summary, step)
                sys.stdout.write('\rval epoch: {} / {} batch: {} / {} loss: {}'.format(
                    epoch+1, n_epochs, bidx+1, n_val_batches, 
                    total_loss / (self.batch_size * (bidx+1))))
            self.validate(data['val_x'], data['val_y'], data['val_lengths'], val_writer, epoch)
            print('\n')