Example #1
def get_best_arc(n_samples=10, verbose=False):
    global args
    global controller
    global netG
    
    controller.eval()
    netG.eval()

    arcs = []
    inception_scores = []
    for i in range(n_samples):
        with torch.no_grad():
            controller()  # perform forward pass to generate a new architecture
        sample_arc = controller.sample_arc
        arcs.append(sample_arc)

        with torch.no_grad():
            score, _ = utils.get_inception_score(netG, sample_arc, args, 50)
        inception_scores.append(score)

        if verbose:
            print_arc(sample_arc)
            logging.info('score=' + str(score))
            logging.info('-' * 80)

    best_iter = np.argmax(inception_scores)
    best_arc = arcs[best_iter]
    best_score = inception_scores[best_iter]

    controller.train()
    netG.train()
    return best_arc, best_score
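A minimal usage sketch (not part of the original repository): get_best_arc relies on the module-level globals args, controller, and netG being initialized beforehand, so once training is done a caller might simply pick and log the winning architecture like this.

best_arc, best_score = get_best_arc(n_samples=10, verbose=True)
logging.info('best inception score = ' + str(best_score))
print_arc(best_arc)  # same helper used inside get_best_arc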
Example #2
        set_grad(g_net, True)
        set_grad(d_net, False)
        g_net.zero_grad()
        z = torch.randn(batch_size, nz)
        if cuda:
            z = z.cuda()
        fake_x = g_net(z)
        fake_labels = d_net(fake_x)
        fake_label = fake_labels.mean()
        fake_label.backward(mone)
        g_optimizer.step()
        g_loss = -fake_label

    # visualize inception score
    z = torch.randn(2500, nz)
    inception_scores = utils.get_inception_score(g_net, z)
    inception_score = np.array([inception_scores[0]])
    win = visutils.visualize_loss(epoch, inception_score, env, win)

    print "epoch is:[{}|{}],index is:[{}|{}],d_loss:{},g_loss:{},IS:{}".\
        format(epoch,epoch_num,i,len(dataloader),d_loss,g_loss,inception_score)

    if epoch % 10 == 0:
        z = torch.randn([batch_size, nz])
        if cuda:
            z = z.cuda()
        fake_x = g_net(z)
        vutils.save_image(fake_x.cpu().detach(),
                          '%s/fake_samples_epoch_%03d.png' %
                          (result_directory, epoch),
                          normalize=True)
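set_grad and mone are used above but not defined in this excerpt. A common pattern (an assumption, not the repository's own code) is for set_grad to toggle requires_grad on every parameter, and for mone to be a scalar tensor holding -1 so that backward(mone) maximizes the critic's output:

import torch

def set_grad(net, requires_grad):
    # freeze or unfreeze all parameters of a network
    for p in net.parameters():
        p.requires_grad_(requires_grad)

mone = torch.tensor(-1.0)  # flips the gradient sign in fake_label.backward(mone)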
Example #3
def train(g_net):
    # small_dataset = smallCifarDataset(root_dir)
    # dataloader = torch.utils.data.DataLoader(small_dataset,batch_size=batch_size,shuffle=True,drop_last=True)
    data = dset.CIFAR10(root="/home/lrh/dataset/cifar-10",
                        train=False,
                        download=True,
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5),
                                                 (0.5, 0.5, 0.5)),
                        ]))

    dataloader = torch.utils.data.DataLoader(data,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             drop_last=True)

    g_optimizer = optim.Adam(g_net.parameters(), lr=lr, betas=beta)
    for i, data in enumerate(dataloader, 0):
        # with shape [batch_size, 3, 32, 32]
        real_x = data[0]
        vutils.save_image(real_x.detach(),
                          '%s/real_samples.png' % (result_directory),
                          normalize=True,
                          nrow=20)
        break
    for epoch in range(9000, epoch_num):
        z = torch.randn(real_x.size(0), nz)
        #to GPU
        if cuda:
            real_x = real_x.cuda()
            z = z.cuda()

        g_optimizer.zero_grad()
        isopt = True
        pi = None
        for j in range(g_steps):
            fake_x = g_net(z)
            #[batch_size,3,32,32]
            loss, pi = loss_fn(real_x, fake_x, isopt, pi)
            loss.backward()
            g_optimizer.step()
            isopt = False
        if epoch % 100 == 0:
            z = torch.randn(64, nz)
            if cuda:
                z = z.cuda()
            fake_x = g_net(z)
            vutils.save_image(fake_x.cpu().detach(),
                              '%s/fake_samples_epoch_%03d.png' %
                              (result_directory, epoch),
                              normalize=True)

            fake_x = fake_x.cpu().detach().numpy()

            IS = utils.get_inception_score(g_net)
            print "epoch is:[{}|{}],index is:[{}|{}],g_loss:{},IS is:{}".\
                format(epoch,epoch_num,\
                i,len(dataloader),loss,IS)
        if epoch % 1000 == 0:
            torch.save(g_net.state_dict(),
                       '%s/gnet_%03d.pkl' % (result_directory, epoch))
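Because the generator's state_dict is saved every 1000 epochs above, a later run can restore one of those checkpoints before sampling or scoring. A sketch, where the Generator constructor and the chosen epoch are assumptions rather than code from this repository:

g_net = Generator()  # hypothetical constructor for the same architecture
g_net.load_state_dict(torch.load('%s/gnet_%03d.pkl' % (result_directory, 1000)))
g_net.eval()  # evaluation mode before sampling or computing the inception score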
Example #4
def train_controller(epoch, baseline=None):
    global args
    global netG
    global controller
    global controller_optimizer

    logging.info('Epoch ' + str(epoch) + ': Training controller')

    netG.eval()

    controller.zero_grad()
    for i in range(args.controller_train_steps * args.controller_num_aggregate):
        start = time.time()

        controller()  # perform forward pass to generate a new architecture
        sample_arc = controller.sample_arc
        
        with torch.no_grad():
            score, _ = utils.get_inception_score(netG, sample_arc, args, 50)
        
        # the score is computed under no_grad above, so no gradients are backpropped through the reward
        reward = score / 12
        reward += args.controller_entropy_weight * controller.sample_entropy

        if baseline is None:
            baseline = score
        else:
            baseline -= (1 - args.controller_bl_dec) * (baseline - reward)
            # detach to make sure that gradients are not backpropped through the baseline
            baseline = baseline.detach()

        loss = -1 * controller.sample_log_prob * (reward - baseline)

        if args.controller_skip_weight is not None:
            loss += args.controller_skip_weight * controller.skip_penaltys

        # Average gradient over controller_num_aggregate samples
        loss = loss / args.controller_num_aggregate

        loss.backward(retain_graph=True)

        end = time.time()

        # Aggregate gradients for controller_num_aggregate iterations, then update weights
        if (i + 1) % args.controller_num_aggregate == 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(controller.parameters(), args.child_grad_bound)
            controller_optimizer.step()
            controller.zero_grad()

            if (i + 1) % (2 * args.controller_num_aggregate) == 0:
                learning_rate = controller_optimizer.param_groups[0]['lr']
                display = 'ctrl_step=' + str(i // args.controller_num_aggregate) + \
                          '\tloss=%.3f' % (loss.item()) + \
                          '\tent=%.2f' % (controller.sample_entropy.item()) + \
                          '\tlr=%.4f' % (learning_rate) + \
                          '\t|g|=%.4f' % (grad_norm) + \
                          '\tacc=%.4f' % (score) + \
                          '\tbl=%.2f' % (baseline) + \
                          '\ttime=%.2fit/s' % (1. / (end - start))
                logging.info(display)

    netG.train()
    return baseline
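The baseline above is an exponential moving average of the reward, which reduces the variance of the REINFORCE update. As a standalone illustration of the same update rule (the decay 0.99 and the reward values are made up for the example):

bl_dec = 0.99
baseline = None
for reward in [5.0, 6.2, 5.8]:
    if baseline is None:
        baseline = reward  # initialize on the first sample
    else:
        baseline -= (1 - bl_dec) * (baseline - reward)  # EMA update
    advantage = reward - baseline  # what the controller's log-probability is scaled by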
Example #5
def train(d_net, g_net):
    #prepare true samples

    data = dset.CIFAR10(root=conf.root_path,
                        download=False,
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5),
                                                 (0.5, 0.5, 0.5)),
                        ]))
    dataloader = torch.utils.data.DataLoader(data,
                                             batch_size=conf.batch_size,
                                             shuffle=True,
                                             drop_last=True)
    #generate fake samples
    d_optimizer = optim.Adam(d_net.parameters(), lr=conf.lr, betas=conf.beta)
    g_optimizer = optim.Adam(g_net.parameters(), lr=conf.lr, betas=conf.beta)
    d_lo = g_lo = 0
    criterion = nn.BCELoss()
    if conf.fixz:
        size = conf.batch_size, 100, 1, 1
        z = sample(size)
    for epoch in range(conf.epoch_num):
        for i, data in enumerate(dataloader, 0):
            #with shape [batch_size,3,32,32]
            real_x = data[0]

            #to GPU
            if conf.cuda:
                real_x = real_x.cuda()

            d_optimizer.zero_grad()
            num_samples = real_x.size(0)
            # BCELoss expects float targets; move them to the GPU only when cuda is enabled
            label = torch.full((num_samples, ), 1.0)
            if conf.cuda:
                label = label.cuda()
            real_logits = d_net(real_x)
            d_l_r = criterion(real_logits, label)
            d_l_r.backward()
            if not conf.fixz:
                size = num_samples, 100, 1, 1
                z = sample(size)
            if conf.cuda:
                z = z.cuda()
            fake_x = g_net(z)
            label.fill_(0)
            fake_logits = d_net(fake_x)
            d_l_f = criterion(fake_logits, label)
            d_l_f.backward()
            d_optimizer.step()
            d_lo = d_l_r + d_l_f

            g_optimizer.zero_grad()
            fake_x = g_net(z)
            fake_logits = d_net(fake_x)
            label.fill_(1)
            g_l = criterion(fake_logits, label)
            g_l.backward()
            g_optimizer.step()
            g_lo = g_l

            if conf.debug:
                print("epoch is:[{}|{}], index is:[{}|{}], d_loss:{}, g_loss:{}".format(
                    epoch, conf.epoch_num, i, len(dataloader), d_lo, g_lo))
        # after each epoch, we visualize the result
        if conf.debug:
            for para in g_net.parameters():
                print(torch.mean(para.grad))
        fake_x = g_net(z)
        print("d_loss:{}, g_loss:{}".format(d_lo, g_lo))
        vutils.save_image(fake_x.cpu().detach(),
                          '%s/fake_samples_epoch_%03d.png' %
                          (conf.result_directory, epoch),
                          normalize=True)

        inception_scores = utils.get_inception_score(g_net)
        inception_score = np.array([inception_scores[0]])
        print(inception_score)
        conf.win = visutils.visualize_loss(epoch, inception_score, conf.env,
                                           conf.win)

        if epoch % 50 == 0:
            torch.save(g_net.state_dict(),
                       '%s/gnet_%03d.pkl' % (conf.result_directory, epoch))
            torch.save(d_net.state_dict(),
                       '%s/dnet_%03d.pkl' % (conf.result_directory, epoch))
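sample(size) is called above but not shown in this excerpt; for GAN latent vectors a standard normal draw is the usual convention, so one plausible sketch is:

import torch

def sample(size):
    # draw a latent batch z ~ N(0, I) with the requested shape, e.g. (batch_size, 100, 1, 1)
    return torch.randn(size)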