Example #1
    def train(self):
        config = self.config

        print(config.train_batch_size % len(self.unlabeled_loader))
        self.param_init()

        self.iter_cnt = 0
        iter, min_dev_incorrect = 0, 1e6
        monitor = OrderedDict()

        # ceil division: number of unlabeled batches per epoch
        batch_per_epoch = (len(self.unlabeled_loader) + config.train_batch_size - 1) // config.train_batch_size
        while True:

            if iter % batch_per_epoch == 0:
                epoch = iter // batch_per_epoch
                if epoch >= config.max_epochs:
                    break
                epoch_ratio = float(epoch) / float(config.max_epochs)
                # use another outer max to prevent any float computation precision problem
                self.dis_optimizer.param_groups[0]['lr'] = config.dis_lr * max(0., min(3. * (1. - epoch_ratio), 1.))
                self.gen_optimizer.param_groups[0]['lr'] = config.gen_lr * max(0., min(3. * (1. - epoch_ratio), 1.))

            iter_vals = self._train(iter=iter)

            for k, v in iter_vals.items():
                if k not in monitor:
                    monitor[k] = 0.
                monitor[k] += v

            if iter % config.vis_period == 0:
                self.visualize()

            if iter % config.eval_period == 0:
                train_loss, train_incorrect = self.eval(self.labeled_loader)
                dev_loss, dev_incorrect = self.eval(self.dev_loader)

                min_dev_incorrect = min(min_dev_incorrect, dev_incorrect)
                disp_str = '#{}\ttrain: {:.4f}, {} | dev: {:.4f}, {} | best: {}'.format(
                        iter, train_loss, train_incorrect, dev_loss, dev_incorrect, min_dev_incorrect)
                for k, v in monitor.items():
                    disp_str += ' | {}: {:.4f}'.format(k, v / config.eval_period)

                disp_str += ' | lr: dis {:.5f}, gen {:.5f}'.format(
                    self.dis_optimizer.param_groups[0]['lr'], self.gen_optimizer.param_groups[0]['lr'])
                monitor = OrderedDict()

                print(disp_str)

                with torch.no_grad():
                    noise = Variable(torch.Tensor(400, self.config.noise_size).uniform_().cuda())
                    images = self.gen(noise)
                    images = (images - 0.5) / 0.5  # rescale from [0, 1] to [-1, 1]
                    images = images.view(-1, 1, 28, 28)
                    logits = self.pixelcnn(images)
                    log_probs = -pixelcnn_loss.discretized_mix_logistic_loss_c1(
                        images.permute(0, 2, 3, 1),
                        logits.permute(0, 2, 3, 1),
                        sum_all=False).data.cpu()
                gen_ploss_stats = log_probs.min(), log_probs.max(), log_probs.mean(), log_probs.var()
                print('gen stats', gen_ploss_stats)
                print('unl stats', self.unl_ploss_stats)

            iter += 1
            self.iter_cnt += 1
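
The schedule above keeps both learning rates at their base value for roughly the first two thirds of training, then decays them linearly to zero. A minimal standalone sketch of the decay factor (max_epochs stands in for config.max_epochs):

def lr_factor(epoch, max_epochs):
    # 1.0 until epoch/max_epochs reaches ~2/3, then linear decay to 0; the
    # outer max guards against a slightly negative value from float rounding
    epoch_ratio = float(epoch) / float(max_epochs)
    return max(0., min(3. * (1. - epoch_ratio), 1.))

# e.g. with max_epochs=100: epoch 50 -> 1.0, epoch 80 -> 0.6, epoch 99 -> ~0.03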
Example #2
    def __init__(self, config, args):
        self.config = config
        for k, v in list(args.__dict__.items()):
            setattr(self.config, k, v)
        setattr(self.config, 'save_dir', '{}_log'.format(self.config.dataset))

        self.dis = mnist_model.Discriminative(
            noise_size=config.noise_size, num_label=config.num_label).cuda()
        self.gen = mnist_model.Generator(image_size=config.image_size,
                                         noise_size=config.noise_size).cuda()

        self.dis_optimizer = optim.Adam(self.dis.parameters(),
                                        lr=config.dis_lr,
                                        betas=(0.5, 0.9999))
        self.gen_optimizer = optim.Adam(self.gen.parameters(),
                                        lr=config.gen_lr,
                                        betas=(0.0, 0.9999))

        self.pixelcnn = pixelcnn.PixelCNN(nr_resnet=3,
                                          disable_third=True,
                                          dropout_p=0.0,
                                          n_channel=1,
                                          image_wh=28).cuda()
        self.pixelcnn.load_state_dict(torch.load(config.pixelcnn_path))

        self.d_criterion = nn.CrossEntropyLoss()

        self.labeled_loader, self.unlabeled_loader, self.unlabeled_loader2, self.dev_loader, self.special_set = data.get_mnist_loaders(
            config)

        if not os.path.exists(self.config.save_dir):
            os.makedirs(self.config.save_dir)

        log_probs_list = []
        for dev_images, _ in self.dev_loader.get_iter():
            dev_images = Variable(dev_images.cuda(), requires_grad=False)
            dev_images = (dev_images - 0.5) / 0.5
            dev_images = dev_images.view(-1, 1, 28, 28)
            with torch.no_grad():
                logits = self.pixelcnn(dev_images)
            log_probs = -pixelcnn_loss.discretized_mix_logistic_loss_c1(
                dev_images.permute(0, 2, 3, 1),
                logits.permute(0, 2, 3, 1),
                sum_all=False)
            log_probs = log_probs.data.cpu().numpy()
            log_probs_list.append(log_probs)
        log_probs = np.concatenate(log_probs_list, axis=0)

        self.unl_ploss_stats = (log_probs.min(), log_probs.max(),
                                log_probs.mean(), log_probs.var())
        cut_point = int(log_probs.shape[0] * 0.1)
        self.ploss_th = float(np.partition(log_probs, cut_point)[cut_point])
        print('ploss_th', self.ploss_th)

        print(self.dis)
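
ploss_th above is the 10th-percentile PixelCNN log-probability of the dev set: np.partition places the cut_point-th smallest element at its sorted position without fully sorting the array. A minimal sketch with synthetic scores:

import numpy as np

log_probs = np.random.randn(1000)          # stand-in for dev-set log-probs
cut_point = int(log_probs.shape[0] * 0.1)  # index of the 10th percentile
ploss_th = float(np.partition(log_probs, cut_point)[cut_point])

# np.partition(a, k)[k] equals the value a full sort would put at index k
assert ploss_th == float(np.sort(log_probs)[cut_point])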
Example #3
    def _train(self, labeled=None, vis=False, iter=0):
        config = self.config
        self.dis.train()
        self.gen.train()

        lab_images, lab_labels = next(self.labeled_loader)
        lab_images, lab_labels = Variable(lab_images.cuda()), Variable(
            lab_labels.cuda())

        unl_images, _ = next(self.unlabeled_loader)
        unl_images = Variable(unl_images.cuda())

        noise = Variable(
            torch.Tensor(unl_images.size(0),
                         config.noise_size).uniform_().cuda())
        gen_images = self.gen(noise)

        lab_logits = self.dis(lab_images)
        unl_logits = self.dis(unl_images)
        gen_logits = self.dis(gen_images.detach())

        # Standard classification loss
        lab_loss = self.d_criterion(lab_logits, lab_labels)

        unl_logsumexp = log_sum_exp(unl_logits)
        gen_logsumexp = log_sum_exp(gen_logits)

        unl_acc = torch.mean(F.sigmoid(unl_logsumexp.detach()).gt(0.5).float())
        gen_acc = torch.mean(F.sigmoid(gen_logsumexp.detach()).gt(0.5).float())

        # Standard GAN cost, treating logsumexp(logits) as the pre-sigmoid real/fake score
        true_loss = -0.5 * torch.mean(unl_logsumexp) + 0.5 * torch.mean(
            F.softplus(unl_logsumexp))
        fake_loss = 0.5 * torch.mean(F.softplus(gen_logsumexp))
        unl_loss = true_loss + fake_loss

        d_loss = lab_loss + unl_loss

        self.dis_optimizer.zero_grad()
        d_loss.backward()
        self.dis_optimizer.step()

        ##### train Gen
        unl_images, _ = next(self.unlabeled_loader2)
        unl_images = Variable(unl_images.cuda())
        noise = Variable(
            torch.Tensor(unl_images.size(0),
                         config.noise_size).uniform_().cuda())
        gen_images = self.gen(noise)

        # Feature matching: match mean discriminator features of real and generated batches
        unl_feat = self.dis(unl_images, feat=True)
        gen_feat = self.dis(gen_images, feat=True)
        fm_loss = torch.mean(
            (torch.mean(gen_feat, 0) - torch.mean(unl_feat, 0))**2)

        gen_loss = fm_loss

        # PixelCNN density penalty: after a warm-up, occasionally penalize generated
        # samples whose log-probability exceeds the dev-set threshold, pushing the
        # generator toward lower-density regions
        if iter > 9000 and random.random() < config.p_loss_prob:
            noise = Variable(
                torch.Tensor(30, config.noise_size).uniform_().cuda())
            gen_images = self.gen(noise)
            gen_images = (gen_images - 0.5) / 0.5
            gen_images = gen_images.view(-1, 1, 28, 28)
            logits = self.pixelcnn(gen_images)
            log_probs = -pixelcnn_loss.discretized_mix_logistic_loss_c1(
                gen_images.permute(0, 2, 3, 1),
                logits.permute(0, 2, 3, 1),
                sum_all=False)
            p_loss = torch.clamp(log_probs - self.ploss_th, min=0.0)
            non_zero_cnt = float((p_loss > 0).sum().data.cpu()[0])
            if non_zero_cnt > 0:
                p_loss = p_loss.sum() / non_zero_cnt * config.p_loss_weight
            else:
                p_loss = 0
        else:
            p_loss = 0

        loss = gen_loss + p_loss

        self.gen_optimizer.zero_grad()
        loss.backward()
        self.gen_optimizer.step()

        monitor_dict = OrderedDict([
            ('unl accuracy', unl_acc.data[0]),
            ('gen accuracy', gen_acc.data[0]),
            ('lab loss', lab_loss.data[0]),
            ('unl loss', unl_loss.data[0]),
            ('true loss', true_loss.data[0]),
            ('fake loss', fake_loss.data[0]),
            ('gen loss', gen_loss.data[0]),
            ('p loss', p_loss.data[0] if hasattr(p_loss, 'data') else 0.0),
        ])

        return monitor_dict
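
The unlabeled-data losses above follow the semi-supervised GAN formulation of Salimans et al. (2016), where D(x) = sigmoid(logsumexp(logits)) over the K class logits acts as the real/fake probability, rewritten with softplus for numerical stability. A minimal sketch, assuming log_sum_exp reduces (N, K) logits over the class dimension (its definition is not shown in these snippets):

import torch
import torch.nn.functional as F

def log_sum_exp(logits):
    # numerically stable log-sum-exp over the class dimension
    m, _ = logits.max(dim=1)
    return m + torch.log(torch.exp(logits - m.unsqueeze(1)).sum(dim=1))

logits = torch.randn(8, 10)
s = log_sum_exp(logits)
# -log(sigmoid(s)) == softplus(s) - s  and  -log(1 - sigmoid(s)) == softplus(s),
# so true_loss/fake_loss above are the usual GAN cross-entropy terms (scaled by 0.5):
true_term = (F.softplus(s) - s).mean()  # -E[log D(x)] on unlabeled data
fake_term = F.softplus(s).mean()        # -E[log(1 - D(G(z)))] on generated data
assert torch.allclose(true_term, -torch.log(torch.sigmoid(s)).mean(), atol=1e-5)
assert torch.allclose(fake_term, -torch.log(1 - torch.sigmoid(s)).mean(), atol=1e-5)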