示例#1
0
def lossfun_one_batch(model, gen_model, dis_model, opt, fea_opt, opt_gen ,opt_dis, params, batch, epoch = 100):
    """Compute the loss(es) for one mini-batch; in the triplet branch this
    also runs the backward pass and the optimizer updates.

    Args:
        model: embedding network mapping raw inputs to feature vectors.
        gen_model: generator network (synthesizes a harder negative).
        dis_model: discriminator/embedding network for generated samples.
        opt, fea_opt, opt_gen, opt_dis: optimizers assumed to be set up on
            the corresponding models (fea_opt presumably on ``model`` —
            TODO confirm against the caller).
        params: config object; reads ``loss``, ``tradeoff`` and ``alpha``.
        batch: tuple ``(x_data, c_data)``; ``c_data`` is unpacked but unused.
        epoch (int): current epoch; before epoch 20 only the plain triplet
            loss is optimized (warm-up).

    Returns:
        For ``params.loss == "angular"``: the angular loss Variable.
        For ``params.loss == "triplet"``: the pair ``(loss_gen, loss_m)``.
        Implicitly ``None`` for any other value of ``params.loss``.
    """
    # the first half of a batch are the anchors and the latters
    # are the positive examples corresponding to each anchor
    lambda1 = 1.0  # weight of the L2 regression term in loss_gen
    lambda2 = 1.0  # weight of the adversarial term in loss_gen
    if params.loss == "angular":
        x_data, c_data = batch
        x_data = model.xp.asarray(x_data)  # move to the model's array module (NumPy/CuPy)
    
        y = model(x_data)
        # anchors in the first half of the batch, positives in the second
        y_a, y_p = F.split_axis(y, 2, axis=0)
        return angular_mc_loss_m(y_a, y_p, params.tradeoff,params.alpha)
    elif params.loss == "triplet":
        x_data, c_data = batch
        x_data = model.xp.asarray(x_data)
        batch = model(x_data)  # NOTE: rebinds `batch` from raw data to embeddings
        batchsize = len(batch)
        # embeddings are laid out as [anchors | positives | negatives]
        a, p, n = F.split_axis(batch, 3, axis=0)
        t_loss = F_tloss(a, p, n, params.alpha)
        # feed each (a, p, n) triplet side by side (feature axis) to the
        # generator so it can synthesize a harder negative per triplet
        batch_concat = F.concat([a, p, n], axis = 1)

        fake = gen_model(batch_concat)
        batch_fake = F.concat([a, p, fake], axis=0)
        embedding_fake = dis_model(batch_fake)

        loss_hard = l2_hard(batch_fake,batch)
        loss_reg = l2_norm(batch_fake,batch)
        loss_adv = adv_loss(embedding_fake)
        loss_gen = loss_hard + lambda1*loss_reg + lambda2*loss_adv
        loss_m = triplet_loss(embedding_fake)

        
        if epoch < 20:
            # warm-up phase: optimize only the plain triplet loss
            t_loss.backward()
            fea_opt.update()
        else:            
            # joint phase: update embedding, generator and discriminator.
            # NOTE(review): loss_gen.backward() and loss_m.backward() run
            # back-to-back with no cleargrads() in between, so gradients
            # from both losses accumulate in any shared parameters before
            # the updates — confirm this accumulation is intended.
            loss_gen.backward()
            loss_m.backward()
            opt.update()
            opt_gen.update()
            opt_dis.update()
        # clear accumulated gradients so the next batch starts clean
        model.cleargrads()
        gen_model.cleargrads()
        dis_model.cleargrads()

        # loss_gen/loss_m are reported even in the warm-up phase, although
        # only t_loss was optimized there — presumably for monitoring.
        chainer.reporter.report({'loss_gen': loss_gen})
        chainer.reporter.report({'loss_dis': loss_m})
        return loss_gen, loss_m
示例#2
0
    def update(self, s, i):
        """Update decoder state.

        Args:
            s (any): current (hidden, cell) states; if ``None`` the LSTM
                     starts from its default zero state.
            i: input label. The original docstring says ``int``, but
               ``self.embed(i)`` and ``len(i)`` below suggest a tensor /
               sequence of labels — confirm the caller's actual type.
        Return:
            (hy, cy, dy): updated hidden state, cell state and LSTM output.
        """
        # Embed the input label(s) and append the auxiliary context self.hx
        # along the feature axis.
        x = torch.cat((self.embed(i), self.hx), dim=1)
        # If the incoming state carries n_layers * 2 layer entries (e.g. it
        # came from a bidirectional encoder), merge each adjacent pair of
        # layer states into one by concatenating along the feature axis.
        # NOTE(review): F.concat/F.stack use Chainer-style `axis=` while the
        # rest of the method uses torch — confirm which backend `F` is.
        if s is not None and len(s[0]) == self.n_layers * 2:
            s = list(s)
            for m in (0, 1):  # m == 0: hidden states, m == 1: cell states
                ss = []
                for n in six.moves.range(0, len(s[m]), 2):
                    ss.append(F.concat((s[m][n], s[m][n + 1]), axis=1))
                s[m] = F.stack(ss, axis=0)

        # Non-empty input: add a leading time dimension; otherwise fall back
        # to a single-element list (presumably for a packed-sequence LSTM —
        # TODO confirm self.lstm accepts both forms).
        if len(i) != 0:
            xs = torch.unsqueeze(x, 0)
        else:
            xs = [x]

        if s is not None:
            dy, (hy, cy) = self.lstm(xs, (s[0], s[1]))
        else:
            dy, (hy, cy) = self.lstm(xs)

        return hy, cy, dy
示例#3
0
def cal_vat_v(dnn, x):
    """Return the virtual-adversarial perturbation ``r_adv`` for inputs ``x``.

    A random unit direction is scaled by ``xi``, the KL divergence between
    the network's clean and perturbed outputs is backpropagated, and the
    (normalized) gradient w.r.t. the perturbation, scaled by ``eps``, is
    returned as the adversarial direction.
    """
    xi = 10
    eps = 1

    n_samples, n_dims = x.shape

    # Random unit direction per sample.
    direction = trans_unit_v(np.random.randn(n_samples, n_dims))

    # Small initial perturbation, wrapped so its gradient is tracked.
    r = Variable(np.array(xi * direction, dtype=np.float32))

    # One forward pass over [clean; perturbed] stacked along the batch axis.
    outputs = dnn(F.concat((x, x + r), axis=0))
    clean_out, perturbed_out = F.split_axis(outputs, 2, axis=0)

    # Detach the clean output (.data) so gradients flow only through the
    # perturbed half, then backprop the divergence to populate r.grad.
    divergence = kldiv1(clean_out.data, perturbed_out)
    dnn.cleargrads()
    divergence.backward()

    return eps * trans_unit_v(r.grad)
示例#4
0
    def forward(self, x):
        """Pyramid-pooling forward pass.

        Each path average-pools the input, applies its module, upsamples the
        result back to the input resolution, and the results are concatenated
        with the input along the channel axis.

        Args:
            x: input feature map of shape (N, C, H, W).

        Returns:
            Tensor of shape (N, C + sum of path output channels, H, W).
        """
        import torch  # local import: torch.cat is needed for the channel concat

        output_slices = [x]
        h, w = x.shape[2:]

        for module, pool_size in zip(self.path_module_list, self.pool_sizes):
            # stride 1, no padding: yields an (h-k+1, w-k+1) map
            out = F.avg_pool2d(x, pool_size, 1, 0)
            out = module(out)
            # F.upsample is deprecated; F.interpolate is its replacement and
            # align_corners=False matches the old default behavior
            out = F.interpolate(out, size=(h, w), mode='bilinear',
                                align_corners=False)
            output_slices.append(out)

        # BUG FIX: torch.nn.functional has no `concat`; the channel-wise
        # concatenation must be torch.cat with dim=1.
        return torch.cat(output_slices, dim=1)
示例#5
0
 # compute output by deep neural net
 y_batch = net(x_batch)
 # split the stacked outputs along the batch axis: rows [0, m1) -> y1,
 # [m1, m1+m2) -> y2, remainder -> y3 (presumably the VAT-perturbed copies
 # of the first m1+m2 inputs — confirm how x_batch is assembled upstream)
 y_batch_separable = F.split_axis( y_batch, [m1, m1+m2], axis=0 )
 y1 = y_batch_separable[0]
 y2 = y_batch_separable[1]
 y3 = y_batch_separable[2]
 
 # define VAT loss 
 # target_p is built from .data (detached), so gradients flow only
 # through y3 in loss1
 target_p = np.r_[y1.data, y2.data]
 loss1 = kldiv1(target_p, y3)
 
 # define pseudo cross entropy loss
 # y1 is the labeled part; tl_train_batch presumably holds its targets
 loss2 = kldiv2(y1, tl_train_batch)
 
 # define entropy loss
 loss3 = ent(F.concat((y1,y2),axis=0))
 
 # define total loss function (which should be optimized)
 # lambda1/lambda2 come from the enclosing scope (loss trade-off weights)
 loss = loss1 + lambda1*loss2 + lambda2*loss3 
 
 # compute gradient
 optimizer.zero_grad()
 loss.backward()
 
 # update trainable parameters in deep neural net
 optimizer.step()
 
 # update counter
 iteration += 1
 count +=1
 
 
示例#6
0
    def __call__(self, x):
        """Run the 6-stage pose network on image batch ``x``.

        Stage 1 uses the ``conv5_*_CPM_L{1,2}`` links; stages 2-6 refine the
        previous stage's two branch outputs concatenated with the shared
        feature map, using ``Mconv{1..7}_stage{s}_L{1,2}``.

        Returns:
            (pafs, heatmaps): lists with one branch-1 (PAF) and one branch-2
            (heatmap) output per stage, in stage order.
        """
        heatmaps = []
        pafs = []

        # Shared VGG-style feature extractor (op order unchanged).
        h = F.relu(self.conv1_1(x))
        h = F.relu(self.conv1_2(h))
        h = F.max_pool2d(h, 2, 2)
        h = F.relu(self.conv2_1(h))
        h = F.relu(self.conv2_2(h))
        h = F.max_pool2d(h, 2, 2)
        h = F.relu(self.conv3_1(h))
        h = F.relu(self.conv3_2(h))
        h = F.relu(self.conv3_3(h))
        h = F.relu(self.conv3_4(h))
        h = F.max_pool2d(h, 2, 2)
        h = F.relu(self.conv4_1(h))
        h = F.relu(self.conv4_2(h))
        h = F.relu(self.conv4_3_CPM(h))
        h = F.relu(self.conv4_4_CPM(h))
        feature_map = h

        def stage1_branch(branch):
            # conv5_1..conv5_4 with ReLU, then a linear conv5_5, per branch.
            hb = feature_map
            for idx in range(1, 5):
                hb = F.relu(getattr(self, 'conv5_%d_CPM_L%d' % (idx, branch))(hb))
            return getattr(self, 'conv5_5_CPM_L%d' % branch)(hb)

        def refine_branch(stage, branch, inp):
            # Mconv1..Mconv6 with ReLU, then a linear Mconv7, per branch.
            hb = inp
            for idx in range(1, 7):
                hb = F.relu(getattr(self, 'Mconv%d_stage%d_L%d' % (idx, stage, branch))(hb))
            return getattr(self, 'Mconv7_stage%d_L%d' % (stage, branch))(hb)

        # stage 1
        h1 = stage1_branch(1)  # branch1: PAFs
        h2 = stage1_branch(2)  # branch2: heatmaps
        pafs.append(h1)
        heatmaps.append(h2)

        # stages 2-6: each refines the previous branch outputs plus the
        # shared feature map (channel concat), exactly as the unrolled code.
        for stage in range(2, 7):
            h = F.concat((h1, h2, feature_map), axis=1)  # channel concat
            h1 = refine_branch(stage, 1, h)  # branch1
            h2 = refine_branch(stage, 2, h)  # branch2
            pafs.append(h1)
            heatmaps.append(h2)

        return pafs, heatmaps