Example #1
 def forward(self, equery, vmemory, ememory, mask, iteration=0):
     """Compute an attention over memory given the query."""
     # equery.shape == (..., E)
     # vmemory.shape == (..., Ms, M)
     # ememory.shape == (..., Ms, E)
     # mask.shape == (..., Ms)
     # Setup memory embedding
     eq = F.repeat(equery[..., None, :], vmemory.shape[-2], -2)  # (..., Ms, E)
     # Compute content based attention
     merged = F.concat(
         [eq, ememory, eq * ememory, F.squared_difference(eq, ememory)],
         -1)  # (..., Ms, 4*E)
     inter = self.att_linear(
         merged, n_batch_axes=len(vmemory.shape) - 1)  # (..., Ms, E)
     inter = F.tanh(inter)  # (..., Ms, E)
     inter = F.dropout(inter, DROPOUT)  # (..., Ms, E)
     # Split into sentences
     lengths = np.sum(np.any((vmemory != 0), -1), -1)  # (...,)
     mems = [s[..., :l, :]
             for s, l in zip(F.separate(inter, 0), lengths)]  # B x [(M1, E), (M2, E), ...]
     _, bimems = self.att_birnn(None, mems)  # B x [(M1, 2*E), (M2, 2*E), ...]
     bimems = F.pad_sequence(bimems)  # (..., Ms, 2*E)
     att = self.att_score(
         bimems, n_batch_axes=len(vmemory.shape) - 1)  # (..., Ms, 1)
     att = F.squeeze(att, -1)  # (..., Ms)
     if mask is not None:
         att += mask * MINUS_INF  # (..., Ms)
     return att
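The merged tensor above follows a common comparison-feature pattern: the query and memory embeddings are concatenated with their elementwise product and squared difference before the attention scorer sees them. A numpy-only sketch for a single query/memory pair (the values are made up):

import numpy as np

eq = np.array([1.0, 2.0], dtype=np.float32)  # query embedding, shape (E,)
em = np.array([0.5, 1.0], dtype=np.float32)  # memory embedding, shape (E,)
features = np.concatenate([eq, em, eq * em, (eq - em) ** 2])  # shape (4*E,)
# features == [1.0, 2.0, 0.5, 1.0, 0.5, 2.0, 0.25, 1.0]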
def l2_loss(generated, truth):
    """
    :param generated: Image generated by the Generator at any scale
    :param truth: Corresponding ground truth image
    :return: L2 loss (sum of squared differences) between the images
    """
    # Shape unpack doubles as a check that the input is a 4-D (N, C, H, W) batch
    n, c, h, w = generated.shape
    return F.sum(F.squared_difference(generated, truth))
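A quick check of what l2_loss returns, assuming it and its chainer.functions import are in scope (the image shapes and values are made up):

import numpy as np

generated = np.zeros((2, 3, 4, 4), dtype=np.float32)  # hypothetical generated batch
truth = np.ones((2, 3, 4, 4), dtype=np.float32)       # hypothetical ground truth
# Every element differs by 1, so the sum of squared differences is 2*3*4*4 = 96.
print(l2_loss(generated, truth).data)  # 96.0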
 def check_forward(self, x1_data, x2_data):
     x1 = chainer.Variable(x1_data)
     x2 = chainer.Variable(x2_data)
     y = functions.squared_difference(x1, x2)
     y_expect = (self.x1 - self.x2) ** 2
     testing.assert_allclose(y.data, y_expect, atol=1e-3)
     self.assertEqual(y.data.shape, self.in_shape)
     self.assertEqual(y.data.dtype, self.dtype)
Example #4
 def update(self, state, target):
     # Encode the discrete state index as a one-hot input vector
     state = np.identity(env.observation_space.n, dtype=np.float32)[state]
     value = self.linear(F.expand_dims(state, 0))
     value = F.reshape(value, (1, ))
     # Squared error between the predicted state value and the target
     loss = F.squared_difference(value, target)
     self.linear.cleargrads()
     loss.backward()
     self.optimizer.update()
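The first line of update turns the discrete state index into a one-hot vector by indexing an identity matrix; a small illustration (the observation-space size here is an assumption):

import numpy as np

n_states = 4  # hypothetical env.observation_space.n
state = 2
one_hot = np.identity(n_states, dtype=np.float32)[state]
# one_hot == [0., 0., 1., 0.]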
Example #6
    def __call__(self, x, x_length, ns, ns_length, label):
        """

        Args:
            x (numpy.ndarray or cupy.ndarray): sequences of vocabulary indices
                in shape (batchsize, tokens)
            x_length (numpy.ndarray or cupy.ndarray): number of tokens in each
                batch index of ``x``
            ns (numpy.ndarray or cupy.ndarray): Negative samples.
                sequences of vocabulary indices in shape (batchsize,
                n_negative_samples, tokens)
            ns_length (numpy.ndarray or cupy.ndarray): number of tokens in each
                negative sample in shape ``(batchsize, n_negative_samples)``
            label: Ignored

        Returns:
            chainer.Variable: Scalar loss (max-margin loss plus orthogonality penalty)

        """
        z = self.sent_emb(x, x_length)
        p = self.pred_topic(z)
        # reconstructed sentence embedding r: (batchsize, feature size)
        r = F.matmul(p, self.T)

        # Embed negative sampling
        bs, n_ns, _ = ns.shape
        ns = ns.reshape(bs * n_ns, -1)
        ns_length = ns_length.astype(np.float32).reshape(-1, 1)
        n = F.sum(self.sent_emb.embed(ns), axis=1) / ns_length
        if self.sent_emb.fix_embedding:
            n.unchain_backward()
        n = F.reshape(n, (bs, n_ns, -1))

        # Calculate contrastive max-margin loss
        # neg: (batchsize, n_ns)
        neg = F.sum(F.broadcast_to(F.reshape(r, (bs, 1, -1)), n.shape) * n,
                    axis=-1)
        pos = F.sum(r * z, axis=-1)
        pos = F.broadcast_to(F.reshape(pos, (bs, 1)), neg.shape)
        mask = chainer.Variable(self.xp.zeros(neg.shape, dtype=p.dtype))
        loss_pred = F.sum(F.maximum(1. - pos + neg, mask))
        reporter.report({'loss_pred': loss_pred}, self)

        t_norm = F.normalize(self.T, axis=1)
        loss_reg = self._orthogonality_penalty * F.sqrt(
            F.sum(
                F.squared_difference(
                    F.matmul(t_norm, t_norm, transb=True),
                    self.xp.eye(self.T.shape[0], dtype=np.float32))))
        reporter.report({'orthogonality_penalty': loss_reg}, self)
        loss = loss_pred + loss_reg
        reporter.report({'loss': loss}, self)
        return loss
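The regularizer at the end is the Frobenius norm of T_norm times its transpose minus the identity, which pushes the topic vectors toward orthonormality. A numpy sketch of the same quantity (the topic matrix here is random):

import numpy as np

T = np.random.randn(5, 32).astype(np.float32)  # hypothetical (n_topics, dim) topic matrix
T_norm = T / np.linalg.norm(T, axis=1, keepdims=True)
penalty = np.sqrt(((T_norm @ T_norm.T - np.eye(5, dtype=np.float32)) ** 2).sum())
# penalty shrinks toward 0 as the topic vectors become mutually orthogonal.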
Example #7
 def forward(self, stories):
     """Compute the forward inference pass for given stories."""
     self.log = dict()
     # ---------------------------
     vctx, vq, va, supps = stories  # (B, R, P, C), (B, Q), (B,), (B, I)
     # Embed stories
     # ectx = F.embed_id(vctx, wordeye, ignore_label=0) # (B, R, P, C, V)
     # eq = F.embed_id(vq, wordeye, ignore_label=0) # (B, Q, V)
     ectx = self.embed(vctx)  # (B, R, P, C, V)
     eq = self.embed(vq)  # (B, Q, V)
     # ---------------------------
     # Embed predicates
     embedded_preds = seq_rnn_embed(
         vctx, ectx, self.pred_rnn, reverse=True)  # (B, R, P, E)
     # First character, used to check if the pred is empty
     vector_preds = vctx[..., 0]  # (B, R, P)
     embedded_query = seq_rnn_embed(
         vq, eq, self.pred_rnn, reverse=True)  # (B, E)
     embedded_rules = embedded_preds[:, :, 0]  # (B, R, E) head of rule
     # ---------------------------
     # Perform iterative updates
     state = embedded_query  # (B, E)
     repeated_query = F.repeat(embedded_query[:, None], vctx.shape[1],
                               1)  # (B, R, E)
     rule_mask = np.all(vctx == 0, (2, 3))  # (B, R)
     for _ in range(supps.shape[-1]):
         # Compute attention over memory
         repeated_state = F.repeat(state[:, None], vctx.shape[1],
                                   1)  # (B, R, E)
         combined = F.concat([
             repeated_state, embedded_rules, repeated_query,
             F.squared_difference(repeated_state, embedded_rules),
             embedded_rules * repeated_state
         ], -1)  # (B, R, 5*E)
         att = F.tanh(self.att_dense1(combined,
                                      n_batch_axes=2))  # (B, R, E//2)
         att = self.att_dense2(att, n_batch_axes=2)  # (B, R, 1)
         att = F.squeeze(att, -1)  # (B, R)
         att += rule_mask * MINUS_INF  # (B, R)
         self.tolog('raw_att', att)
         att = F.softmax(att)  # (B, R)
         self.tolog('att', att)
         # Iterate state
         new_states = seq_rnn_embed(
             vector_preds,
             embedded_preds,
             self.unifier,
             initial_state=repeated_state)  # (B, R, E)
         # Update state
         # (B, R) x (B, R, E) -> (B, E)
         state = F.einsum('br,bre->be', att, new_states)  # (B, E)
     return self.out_linear(state)[:, 0]  # (B,)
 def instance_norm(self, x, gamma=None, beta=None):
     mean = F.mean(x, axis=-1)
     mean = F.mean(mean, axis=-1)
     mean = F.broadcast_to(mean[Ellipsis, None, None], x.shape)
     # Variance is the mean of the squared differences over the spatial axes
     var = F.mean(F.mean(F.squared_difference(x, mean), axis=-1), axis=-1)
     var = F.broadcast_to(var[Ellipsis, None, None], x.shape)
     std = F.sqrt(var + 1e-5)
     x_hat = (x - mean) / std
     if gamma is not None:
         gamma = F.broadcast_to(gamma[None, Ellipsis, None, None], x.shape)
         beta = F.broadcast_to(beta[None, Ellipsis, None, None], x.shape)
         return gamma * x_hat + beta
     else:
         return x_hat
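For sanity-checking the normalization above, a numpy-only sketch of the same computation, assuming NCHW input (the array is random):

import numpy as np

x = np.random.randn(2, 8, 16, 16).astype(np.float32)  # hypothetical (N, C, H, W) batch
mean = x.mean(axis=(-2, -1), keepdims=True)
var = ((x - mean) ** 2).mean(axis=(-2, -1), keepdims=True)
x_hat = (x - mean) / np.sqrt(var + 1e-5)
# Each (H, W) plane of x_hat now has roughly zero mean and unit variance.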
Example #9
 def update_state(self, oldstate, mem_att, vmemory, ememory, iteration=0):
     """Update state given old, attention and new possible states."""
     # oldstate.shape == (..., E)
     # mem_att.shape == (..., Ms)
     # vmemory.shape == (..., Ms, M)
     # ememory.shape == (..., Ms, E)
     ostate = F.repeat(oldstate[..., None, :], vmemory.shape[-2], -2)  # (..., Ms, E)
     merged = F.concat([
         ostate, ememory, ostate * ememory,
         F.squared_difference(ostate, ememory)
     ], -1)  # (..., Ms, 4*E)
     mem_inter = self.state_linear(
         merged, n_batch_axes=len(merged.shape) - 1)  # (..., Ms, E)
     mem_inter = F.tanh(mem_inter)  # (..., Ms, E)
     # (..., Ms) x (..., Ms, E) -> (..., E)
     new_state = F.einsum("...i,...ij->...j", mem_att, mem_inter)  # (..., E)
     return new_state
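The einsum at the end is an attention-weighted sum over memory slots. A numpy sketch with a single batch element (shapes and values are made up):

import numpy as np

att = np.array([[0.2, 0.8]], dtype=np.float32)         # (B, Ms)
mem = np.arange(6, dtype=np.float32).reshape(1, 2, 3)  # (B, Ms, E)
state = np.einsum('bi,bij->bj', att, mem)
# state[0] == 0.2 * mem[0, 0] + 0.8 * mem[0, 1] == [2.4, 3.4, 4.4]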
Example #10
    def save_sample_images(self, epoch, batch):

        xp = cuda.cupy
        with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):

            images = self.sample_images
            images_att = self.sample_images_att
            cuda.get_device(GPU.main_gpu).use()
            objs_one_hot = cuda.to_gpu(self.sample_objs_one_hot, GPU.main_gpu)
            descs_one_hot = cuda.to_gpu(self.sample_descs_one_hot, GPU.main_gpu)

            z, m, v = self.enc_models[0](Variable(cuda.to_gpu(images, GPU.main_gpu)), Variable(objs_one_hot), Variable(descs_one_hot), train=False)
            data_att = self.gen_models[0](z, Variable(objs_one_hot), Variable(descs_one_hot), train=False).data

            test_rec_loss = F.squared_difference(data_att[:, :3], xp.asarray(images_att))
            test_rec_loss = float(F.sum(test_rec_loss).data) / self.normer

            print('\ntest error on the random number of images: ' + str(test_rec_loss))
            self.save_image(cuda.to_cpu(data_att[:, :3]), 'sample/{0:03d}_{1:07d}_att.png'.format(epoch, batch))
            self.save_image(cuda.to_cpu(data_att[:, 3:]), 'sample/{0:03d}_{1:07d}_whole.png'.format(epoch, batch))
            if batch == 0:
                self.save_image(images, 'sample/org_pic.png')
                self.save_image(images_att, 'sample/org_att.png')
Example #11
import numpy as np
from chainer import optimizers, functions


def sinDatas(epoch, batch=10):
    out = []
    for i in range(epoch * batch):
        if i == 0:
            out.append(np.random.rand(25))
        elif i % batch == 0:
            yield np.array(out, dtype=np.float32)
            out = []
        else:
            out.append(np.random.rand(25))
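
The TestDNN chain instantiated below is not included in this snippet; a minimal sketch of what it might look like (the class name comes from the code below, but the layer sizes and activation are assumptions):

import chainer
import chainer.links as L
import chainer.functions as F


class TestDNN(chainer.Chain):
    """Hypothetical two-layer MLP standing in for the real TestDNN."""

    def __init__(self, n_units=25):
        super(TestDNN, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(None, n_units)
            self.l2 = L.Linear(None, n_units)

    def forward(self, x):
        h = F.relu(self.l1(x))
        return self.l2(h)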


if __name__ == '__main__':
    DNN = TestDNN()
    opt = optimizers.Adam()
    opt.setup(DNN)
    epoch = 100
    datas = sinDatas(epoch)

    for i in datas:
        x = i
        y = np.sin(x)

        Y = DNN.forward(x)
        # Use chainer's sum so the loss stays a Variable and backward() works
        loss = functions.sum(functions.squared_difference(Y, y))
        print("loss = {}".format(loss.data))
        # Clear accumulated gradients before backpropagating the new loss
        DNN.cleargrads()
        loss.backward()
        opt.update()
Example #12
 def forward(self, inputs, device):
     x1, x2 = inputs
     return functions.squared_difference(x1, x2),
def train(enc,
          gen,
          dis,
          optimizer_enc,
          optimizer_gen,
          optimizer_dis,
          epoch_num,
          out_image_dir=None):
    z_out_image = Variable(
        xp.random.uniform(-1, 1, (out_image_row_num * out_image_col_num,
                                  latent_size)).astype(np.float32))
    for epoch in range(1, epoch_num + 1):
        start_time = time.time()
        sum_loss_enc = sum_loss_gen = sum_loss_dis = sum_loss_rnn = 0
        np.random.shuffle(train_indices)
        for i in range(0, x_size - max_seq_length * BATCH_SIZE, BATCH_SIZE):
            batch_start_time = time.time()
            loss_enc, loss_gen, loss_dis, loss_rec, loss_rnn = train_one(
                enc, gen, dis, rnn_model, optimizer_enc, optimizer_gen,
                optimizer_dis, i)
            sum_loss_enc += loss_enc * BATCH_SIZE
            sum_loss_gen += loss_gen * BATCH_SIZE
            sum_loss_dis += loss_dis * BATCH_SIZE
            sum_loss_rnn += loss_rnn * BATCH_SIZE
            if i % image_save_interval == 0:
                with chainer.using_config('train',
                                          False), chainer.using_config(
                                              'enable_backprop', False):
                    print('')
                    print('{} {} {} {}'.format(
                        sum_loss_enc / image_save_interval,
                        sum_loss_gen / image_save_interval,
                        sum_loss_dis / image_save_interval,
                        sum_loss_rnn / image_save_interval))
                    if out_image_dir is not None:
                        cuda.get_device(main_gpu).use()
                        z, m, v, _ = enc[0](Variable(
                            cuda.to_gpu(test_batch, main_gpu)),
                                            train=False)
                        z = m
                        data = gen[0](z, train=False).data
                        test_rec_loss = F.squared_difference(
                            data, xp.asarray(test_batch))
                        test_rec_loss = float(
                            F.sum(test_rec_loss).data) / (normer)
                        image = ((cuda.to_cpu(data) + 1) * 128).clip(
                            0, 255).astype(np.uint8)
                        image = image[:out_image_row_num * out_image_col_num]
                        image = image.reshape(
                            (out_image_row_num, out_image_col_num, 3,
                             image_size, image_size)).transpose(
                                 (0, 3, 1, 4, 2)).reshape(
                                     (out_image_row_num * image_size,
                                      out_image_col_num * image_size, 3))
                        Image.fromarray(image).save(
                            '{0}/{1:03d}_{2:07d}.png'.format(
                                out_image_dir, epoch, i))
                        if i == 0:
                            org_image = ((test_batch + 1) * 128).clip(
                                0, 255).astype(np.uint8)
                            org_image = org_image[:out_image_row_num *
                                                  out_image_col_num]
                            org_image = org_image.reshape(
                                (out_image_row_num, out_image_col_num, 3,
                                 image_size, image_size)).transpose(
                                     (0, 3, 1, 4, 2)).reshape(
                                         (out_image_row_num * image_size,
                                          out_image_col_num * image_size, 3))
                            Image.fromarray(org_image).save(
                                '{0}/org.png'.format(out_image_dir))
                        sum_loss_enc = sum_loss_gen = sum_loss_dis = sum_loss_rnn = 0
            if i % model_save_interval == 0:
                serializers.save_hdf5('{0}enc.model'.format(args.output),
                                      enc[0])
                serializers.save_hdf5('{0}enc.state'.format(args.output),
                                      optimizer_enc)
                serializers.save_hdf5('{0}gen.model'.format(args.output),
                                      gen[0])
                serializers.save_hdf5('{0}gen.state'.format(args.output),
                                      optimizer_gen)
                if cost_mode == 'BEGAN':
                    serializers.save_hdf5(
                        '{0}enc_dis.model'.format(args.output), enc_dis_model)
                    serializers.save_hdf5(
                        '{0}enc_dis.state'.format(args.output),
                        optimizer_enc_dis)
                    serializers.save_hdf5(
                        '{0}gen_dis.model'.format(args.output), gen_dis_model)
                    serializers.save_hdf5(
                        '{0}gen_dis.state'.format(args.output),
                        optimizer_gen_dis)

                serializers.save_hdf5('{0}dis.model'.format(args.output),
                                      dis[0])
                serializers.save_hdf5('{0}dis.state'.format(args.output),
                                      optimizer_dis)
                serializers.save_hdf5('{0}rnn.model'.format(args.output),
                                      rnn_model)
                serializers.save_hdf5('{0}rnn.state'.format(args.output),
                                      optimizer_rnn)
            sys.stdout.write(
                '\r' + str(i // BATCH_SIZE) + '/' + str(x_size // BATCH_SIZE) +
                ' time: {0:0.2f} errors: {1:0.4f} {2:0.4f} {3:0.8f} {4:0.4f} {5:0.4f} {6:0.4f}'
                .format(time.time() - batch_start_time, loss_enc, loss_gen,
                        loss_dis, loss_rnn, loss_rec, test_rec_loss))
            sys.stdout.flush()
        print('-----------------------------------------')
        print('epoch: {} done'.format(epoch))
        print('time: {}'.format(time.time() - start_time))
Example #15
    def __call__(self, *args):
        xp = chainer.cuda.get_array_module(*args)
        in_x, in_d, in_i, in_m, in_n, in_r, in_p = args
        # x: (n, c, m, h, w)
        # d: (n, m, 3)
        # i: (n, c, m)
        # m: (n, 1, h, w)
        # n: (n, 3, h, w)
        # r: (n, 4) (ymin, xmin, ymax, xmax)
        r = in_r[0] if in_r.shape[0] == 1 else xp.concatenate(
            (xp.min(in_r[:2], 0), xp.max(in_r[2:], 0)), 1)
        in_x = in_x[..., r[0]:r[2], r[1]:r[3]]
        in_m = in_m[..., r[0]:r[2], r[1]:r[3]]
        in_n = in_n[..., r[0]:r[2], r[1]:r[3]]
        in_p = in_p[..., r[0]:r[2], r[1]:r[3]]

        b, color, m, h, w = in_x.shape
        x_mean, x_std, x_amp = masked_mean_stddev(in_x, in_m[:, None],
                                                  (1, 2, 3, 4))
        in_x /= (2.0 * x_amp)

        x = in_x
        x = self.xp.reshape(x, (b, color * m, h, w))
        if self.opt.ps_mask:
            x = self.xp.concatenate((x, in_m.astype(in_x.dtype)), 1)
        x, n_map = self.psnet(x)

        # Swap axes of color channels and measurements
        # x: (n, m, c, h, w)
        # d: (n, m, 3)
        # i: (n, m, c)
        # m: (n, 1, h, w)
        # n: (n, 3, h, w)
        in_i = in_i.swapaxes(1, 2)
        in_x = in_x.swapaxes(1, 2)

        # Subsample reconstruction images
        if chainer.config.train and self.opt.ir_num > 0:
            r = self.opt.ir_num
            if len(self.reco_inds) < r:
                self.reco_inds.extend(np.random.permutation(m).tolist())
            s = self.reco_inds[0:r]
            self.reco_inds = self.reco_inds[r:]
            s.sort()
            in_i = self.xp.take(in_i, s, 1)
            in_d = self.xp.take(in_d, s, 1)
            in_x = self.xp.take(in_x, s, 1)

        z = self.irnet(in_x, in_i, in_d, x, n_map)
        result = n_map, z, in_x, in_n, in_m

        if hasattr(self, 'prior') and self.prior(self.iterations) != 0:
            weight = self.prior(self.iterations)
            l2 = F.squared_difference(n_map, in_p)
            l2 = F.where(self.xp.broadcast_to(in_m, l2.shape), l2,
                         self.xp.zeros_like(l2))
            l2 = F.mean(l2, axis=(1, 2, 3)) * (
                np.abs(weight) / in_m.mean(axis=(1, 2, 3), dtype=l2.dtype))
            l2 *= x_mean.ravel()
            l2 = F.mean(l2)
            result += (l2, )

        if chainer.config.train:
            self.iterations += 1

        return result
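The prior term near the end is a masked mean of squared differences between the predicted and prior normal maps. A numpy-only sketch of that masking step (shapes and values are made up):

import numpy as np

pred = np.ones((1, 3, 4, 4), dtype=np.float32)   # hypothetical predicted normal map
prior = np.zeros_like(pred)                      # hypothetical prior normal map
mask = np.zeros(pred.shape, dtype=bool)
mask[..., :2, :] = True                          # pretend only the top half is valid
l2 = np.where(mask, (pred - prior) ** 2, 0.0)
l2 = l2.mean(axis=(1, 2, 3)) / mask.mean(axis=(1, 2, 3))
# l2 == [1.0]: the error is averaged only over the masked (valid) pixels.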
Example #16
import chainer.functions as F


def sum_squared_error(x0, x1):
    return F.sum(F.squared_difference(x0, x1))
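A minimal check of the helper above (the arrays are made up):

import numpy as np

x0 = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
x1 = np.zeros_like(x0)
# sum of squared differences: 1 + 4 + 9 + 16 = 30
print(sum_squared_error(x0, x1).data)  # 30.0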