Ejemplo n.º 1
def train(lamda=100, lr_decay=0.2, period=50, ckpt='.', viz=False):
    """Train the discriminator and the generator with a GAN loss plus L1 loss.

    Each batch performs one discriminator update followed by one generator
    update. The function relies on names defined outside of it: ``netG``,
    ``netD``, ``trainerG``, ``trainerD``, ``train_data``, ``ctx``,
    ``num_epochs``, ``batch_size``, ``pool_size``, ``GAN_loss``, ``L1_loss``,
    ``facc`` and ``ImagePool``.

    Args:
        lamda: Weight of the L1 term in the generator loss.
        lr_decay: Factor both learning rates are multiplied by whenever
            ``epoch % period == 0`` (this includes epoch 0).
        period: Number of epochs between two learning-rate decays.
        ckpt: Directory the parameter files are written to whenever
            ``epoch % 100 == 0`` (this includes epoch 0).
        viz: Unused; kept so that existing callers keep working.
    """
    image_pool = ImagePool(pool_size)
    metric = mx.metric.CustomMetric(facc)
    for epoch in range(num_epochs):
        epoch_tic = time.time()
        btic = time.time()
        for batch_idx, batch in enumerate(train_data):
            real_in = batch.data[0].as_in_context(ctx)
            real_out = batch.data[1].as_in_context(ctx)
            num_samples = real_in.shape[0]

            # ---- Discriminator update ----
            # The generator runs outside autograd.record() here, so this
            # forward pass is not recorded for the discriminator loss.
            fake_out = netG(real_in)
            fake_concat = image_pool.query(nd.Concat(real_in, fake_out, dim=1))
            with autograd.record():
                # Train with fake images. The discriminator receives the
                # input x stacked with the generated image along axis 1.
                # NOTE(review): presumably this pairing is what makes the
                # setup a conditional GAN -- confirm against the paper.
                output = netD(fake_concat)
                fake_label = nd.zeros(output.shape, ctx=ctx)
                errD_fake = GAN_loss(output, fake_label)
                # NOTE(review): the metric only sees the fake branch and is
                # never reset, so the reported accuracy is cumulative over
                # all epochs -- confirm this is intended.
                metric.update([fake_label,], [output,])
                # Train with real images.
                # NOTE(review): the real pair is routed through image_pool as
                # well; if the pool hands back stored pairs, real and fake
                # samples could get mixed -- verify ImagePool.query.
                real_concat = image_pool.query(nd.Concat(real_in, real_out, dim=1))
                output = netD(real_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                errD_real = GAN_loss(output, real_label)
                # The discriminator loss is halved to slow its updates
                # relative to the generator (as described in the paper,
                # per the original author's note).
                errD = (errD_fake + errD_real) * 0.5
            # Without backward() + step() the recorded loss never changes
            # the discriminator parameters.
            errD.backward()
            trainerD.step(num_samples)

            # ---- Generator update ----
            with autograd.record():
                fake_out = netG(real_in)
                # The image pool is deliberately not used here; real_in (x)
                # is the conditioning input handed to the discriminator.
                fake_concat = nd.Concat(real_in, fake_out, dim=1)
                output = netD(fake_concat)
                errG = GAN_loss(output, real_label) + lamda * L1_loss(real_out, fake_out)
            errG.backward()
            trainerG.step(num_samples)

            if batch_idx % 10 == 0:
                name, acc = metric.get()
                # btic is reset after every batch, so the elapsed time covers
                # exactly one batch and needs no extra scaling factor.
                speed = batch_size / (time.time() - btic)
                logging.info('Epoch {}, lr {:.6f}, D loss: {:3f}, G loss:{:3f}, binary training acc: {:2f}, at iter {}, Speed: {} samples/s'.format(
                    epoch, trainerD.learning_rate, errD.mean().asscalar(), errG.mean().asscalar(), acc, batch_idx, speed))
            btic = time.time()
        if epoch % period == 0:
            trainerD.set_learning_rate(trainerD.learning_rate * lr_decay)
            trainerG.set_learning_rate(trainerG.learning_rate * lr_decay)
        if epoch % 100 == 0:
            print('[+]saving checkpoints to {}'.format(ckpt))
            netG.save_parameters(join(ckpt, 'pixel_netG_epoch_{}.params'.format(epoch)))
            netD.save_parameters(join(ckpt, 'pixel_netD_epoch_{}.params'.format(epoch)))
        name, epoch_acc = metric.get()
        logging.info('\n[+]binary training accuracy at epoch %d %s=%f' % (epoch, name, epoch_acc))
        logging.info('[+]time: {:3f}'.format(time.time() - epoch_tic))
Ejemplo n.º 2
    def hybrid_forward(self, F, x, *args, **kwargs):
        """Compute the score map and the combined geometry map for ``x``.

        The geometry map joins two parts along axis 1: the output of
        ``geo_branch`` multiplied by ``text_scale``, and the output of
        ``theta_branch`` shifted by -0.5 and multiplied by pi / 2.

        Returns:
            Tuple ``(score_map, geometry_map)``.
        """
        scores = self.score_branch(x)

        # Scaled geometry part.
        distances = self.geo_branch(x)
        distances = distances * self.text_scale

        # Angle part: the shift by 0.5 centres the branch output before it
        # is scaled by pi / 2.
        centered_theta = self.theta_branch(x) - 0.5
        angles = centered_theta * np.pi / 2.

        return scores, F.Concat(distances, angles, dim=1)
Ejemplo n.º 3
    def hybrid_forward(self, F, x, *args, **kwargs):
        """Merge the four ``base_forward`` features and run the head.

        The deepest feature is concatenated with the next one, convolved,
        resized to 1/8 of the input height/width, merged with the following
        feature, convolved, resized to 1/4, merged with the first feature,
        convolved once more and finally passed to ``self.head``.

        Returns:
            Tuple ``(score, geometry)`` as unpacked from ``self.head``.
        """
        # NOTE(review): reading x.shape needs an input that exposes a
        # concrete shape -- confirm this is never called with a symbolic x.
        height, width = x.shape[2:]

        feat1, feat2, feat3, feat4 = self.base_forward(x)

        # First merge: the deepest feature is used as-is (no resize).
        merged = F.Concat(feat4, feat3, dim=1)
        stage1 = self.conv_stage1_3(self.conv_stage1(merged))

        # Second merge at 1/8 of the input resolution.
        upsampled = F.contrib.BilinearResize2D(stage1, height // 8, width // 8)
        merged = F.Concat(upsampled, feat2, dim=1)
        stage2 = self.conv_stage2_3(self.conv_stage2(merged))

        # Third merge at 1/4 of the input resolution; this Concat relies on
        # the operator's default axis.
        upsampled = F.contrib.BilinearResize2D(stage2, height // 4, width // 4)
        merged = F.Concat(upsampled, feat1)
        stage3 = self.conv_stage3_3(self.conv_stage3(merged))

        score, geometry = self.head(stage3)
        return score, geometry
Ejemplo n.º 4
 def forward(self, x, y):
     """Combine the two input paths and return the softmax of ``fc3``.

     ``x`` goes through conv1, conv2, pool1, conv3, conv4, pool2, conv5,
     conv6, conv7, pool3 and fc1 in that order. ``y`` goes through fc2 and
     lstm1, after which only the last index along axis 1 is kept. Both
     results are concatenated along axis 1, passed through fc3 and
     normalised with ``nd.softmax`` (default arguments).
     """
     x_path = (
         'conv1', 'conv2', 'pool1',
         'conv3', 'conv4', 'pool2',
         'conv5', 'conv6', 'conv7', 'pool3',
         'fc1',
     )
     for layer_name in x_path:
         x = getattr(self, layer_name)(x)

     sequence = self.lstm1(self.fc2(y))
     # Keep only the final entry along axis 1.
     # NOTE(review): presumably the last time step -- confirm the layout.
     last_entry = sequence[:, -1, :]

     merged = nd.Concat(x, last_entry, dim=1)
     return nd.softmax(self.fc3(merged))
Ejemplo n.º 5
 def forward(self, x):
     """Run four parallel paths on ``x`` and join their outputs on axis 1.

     Path 1 is a single layer; paths 2 and 3 apply two layers in sequence;
     path 4 applies ``p4_pool_1`` followed by ``p4_conv_2``.
     """
     branch_outputs = [
         self.p1_conv_1(x),
         self.p2_conv_2(self.p2_conv_1(x)),
         self.p3_conv_2(self.p3_conv_1(x)),
         self.p4_conv_2(self.p4_pool_1(x)),
     ]
     return nd.Concat(*branch_outputs, dim=1)