Code Example #1
    def __init__(self, config, args):
        self.config = config
        # Merge command-line arguments into the config and derive the log directory name.
        for k, v in list(args.__dict__.items()):
            setattr(self.config, k, v)
        setattr(self.config, 'save_dir', '{}_log'.format(self.config.dataset))

        # Discriminator and generator networks, both placed on the GPU.
        self.dis = mnist_model.Discriminative(
            noise_size=config.noise_size, num_label=config.num_label).cuda()
        self.gen = mnist_model.Generator(image_size=config.image_size,
                                         noise_size=config.noise_size).cuda()

        self.dis_optimizer = optim.Adam(self.dis.parameters(),
                                        lr=config.dis_lr,
                                        betas=(0.5, 0.9999))
        self.gen_optimizer = optim.Adam(self.gen.parameters(),
                                        lr=config.gen_lr,
                                        betas=(0.0, 0.9999))

        # Pretrained PixelCNN used to score per-image log-likelihoods.
        self.pixelcnn = pixelcnn.PixelCNN(nr_resnet=3,
                                          disable_third=True,
                                          dropout_p=0.0,
                                          n_channel=1,
                                          image_wh=28).cuda()
        self.pixelcnn.load_state_dict(torch.load(config.pixelcnn_path))

        self.d_criterion = nn.CrossEntropyLoss()

        (self.labeled_loader, self.unlabeled_loader, self.unlabeled_loader2,
         self.dev_loader, self.special_set) = data.get_mnist_loaders(config)

        if not os.path.exists(self.config.save_dir):
            os.makedirs(self.config.save_dir)

        # Score every dev image with the PixelCNN and collect the per-image
        # log-probabilities.
        log_probs_list = []
        for dev_images, _ in self.dev_loader.get_iter():
            dev_images = Variable(dev_images.cuda(), requires_grad=False)
            # Normalize to [-1, 1] and reshape to NCHW for the PixelCNN.
            dev_images = (dev_images - 0.5) / 0.5
            dev_images = dev_images.view(-1, 1, 28, 28)
            with torch.no_grad():
                logits = self.pixelcnn(dev_images)
            # The negative discretized mixture-of-logistics loss gives the
            # per-image log-probability.
            log_probs = -pixelcnn_loss.discretized_mix_logistic_loss_c1(
                dev_images.permute(0, 2, 3, 1),
                logits.permute(0, 2, 3, 1),
                sum_all=False)
            log_probs = log_probs.data.cpu().numpy()
            log_probs_list.append(log_probs)
        log_probs = np.concatenate(log_probs_list, axis=0)

        self.unl_ploss_stats = (log_probs.min(), log_probs.max(),
                                log_probs.mean(), log_probs.var())
        # Threshold at the 10th percentile of dev-set log-probabilities;
        # images scoring below it are treated as low-density samples.
        cut_point = int(log_probs.shape[0] * 0.1)
        self.ploss_th = float(np.partition(log_probs, cut_point)[cut_point])
        print('ploss_th', self.ploss_th)

        print(self.dis)
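
The ploss_th computation at the end of this constructor is plain NumPy. A minimal standalone sketch of the same 10th-percentile cut-off, using synthetic scores in place of the PixelCNN output (an assumption made purely for illustration):

import numpy as np

# Synthetic per-image log-probabilities standing in for the PixelCNN scores.
log_probs = np.random.randn(1000)

# Index of the 10th percentile: roughly 10% of the images score below it.
cut_point = int(log_probs.shape[0] * 0.1)

# np.partition places the cut_point-th smallest value at that index
# without fully sorting the array.
ploss_th = float(np.partition(log_probs, cut_point)[cut_point])

low_density = log_probs < ploss_th
print(ploss_th, low_density.mean())  # threshold and the fraction below it (~0.1)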
Code Example #2
    def __init__(self, config, args):
        self.config = config
        for k, v in args.__dict__.items():
            setattr(self.config, k, v)
        setattr(self.config, 'save_dir', '{}_log'.format(self.config.dataset))

        self.dis = mnist_model.Discriminative(noise_size=config.noise_size, num_label=config.num_label).cuda()
        self.gen = mnist_model.Generator(image_size=config.image_size, noise_size=config.noise_size).cuda()

        self.dis_optimizer = optim.Adam(self.dis.parameters(), lr=config.dis_lr, betas=(0.5, 0.9999))
        self.gen_optimizer = optim.Adam(self.gen.parameters(), lr=config.gen_lr, betas=(0.0, 0.9999))

        # PixelCNN likelihood scoring is disabled in this variant.
        # self.pixelcnn = pixelcnn.PixelCNN(nr_resnet=3, disable_third=True, dropout_p=0.0, n_channel=1,
        #                                   image_wh=28).cuda()
        # self.pixelcnn.load_state_dict(torch.load(config.pixelcnn_path))

        self.d_criterion = nn.CrossEntropyLoss()

        (self.labeled_loader, self.unlabeled_loader, self.unlabeled_loader2,
         self.dev_loader, self.special_set) = data.get_mnist_loaders(config)

        if not os.path.exists(self.config.save_dir):
            os.makedirs(self.config.save_dir)
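
All three constructors fold the parsed command-line arguments into a shared config object with setattr. A self-contained sketch of that pattern (the Config class and the --dis_lr flag are hypothetical stand-ins for the project's own config and CLI):

import argparse

class Config:
    # Hypothetical defaults standing in for the project's config object.
    dataset = 'mnist'
    dis_lr = 1e-4
    gen_lr = 1e-4

parser = argparse.ArgumentParser()
parser.add_argument('--dis_lr', type=float, default=3e-4)
args = parser.parse_args([])  # empty list so the sketch runs without a real CLI

config = Config()
# Every parsed argument overrides (or adds) an attribute on the config.
for k, v in args.__dict__.items():
    setattr(config, k, v)
setattr(config, 'save_dir', '{}_log'.format(config.dataset))

print(config.dis_lr, config.save_dir)  # 0.0003 mnist_log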
Code Example #3
    def __init__(self, config, args):
        super().__init__()
        self.config = config
        # Merge command-line arguments into the config, if any were given.
        if args is not None:
            for k, v in args.__dict__.items():
                setattr(self.config, k, v)
        # Log directory: <log_root>/<dataset>_<suffix>.
        setattr(
            self.config, 'save_dir',
            '{}/{}_{}'.format(self.config.log_root, self.config.dataset,
                              self.config.suffix))
        # Optional directory of a previous run to warm-start from.
        if hasattr(config, 'inherit'):
            setattr(
                self.config, 'inherit_dir',
                '{}/{}_{}'.format(self.config.log_root, self.config.dataset,
                                  self.config.inherit))

        if not os.path.exists(self.config.save_dir):
            os.makedirs(self.config.save_dir)
        # Open a per-run text log in append mode.
        log_path = os.path.join(
            self.config.save_dir,
            '{}_{}_log.txt'.format(self.config.dataset, self.config.suffix))
        self.logger = open(log_path, 'a')

        # Dump every config attribute to both the log file and stdout.
        disp_str = ''
        for attr in sorted(dir(self.config), key=lambda x: len(x)):
            if not attr.startswith('__'):
                disp_str += '{} : {}\n'.format(attr,
                                               getattr(self.config, attr))
        self.logger.write(disp_str)
        self.logger.flush()
        sys.stdout.write(disp_str)
        sys.stdout.flush()

        # Discriminator, conditional generator, encoder and sampler networks,
        # all placed on the GPU.
        self.dis = mnist_model.Discriminative(
            noise_size=config.noise_size, num_label=config.num_label).cuda()
        self.gen = mnist_model.Generator(image_size=config.image_size,
                                         noise_size=config.noise_size,
                                         num_label=config.num_label).cuda()
        self.enc = mnist_model.Encoder(image_size=config.image_size,
                                       noise_size=config.noise_size,
                                       num_label=config.num_label,
                                       output_params=True).cuda()
        self.smp = mnist_model.Sampler(noise_size=config.noise_size).cuda()

        # One Adam optimizer per network, each with its own learning rate.
        self.dis_optimizer = optim.Adam(self.dis.parameters(),
                                        lr=config.dis_lr,
                                        betas=(0.5, 0.9999))
        self.gen_optimizer = optim.Adam(self.gen.parameters(),
                                        lr=config.gen_lr,
                                        betas=(0.0, 0.9999))
        self.enc_optimizer = optim.Adam(self.enc.parameters(),
                                        lr=config.enc_lr,
                                        betas=(0.0, 0.9999))
        self.smp_optimizer = optim.Adam(self.smp.parameters(),
                                        lr=config.smp_lr,
                                        betas=(0.5, 0.9999))

        self.d_criterion = nn.CrossEntropyLoss()
        self.mse_loss = nn.MSELoss()
        self.bce_loss = nn.BCELoss()

        # Resume from this run's own checkpoints; if none exist, warm-start
        # from the inherited run and restart the iteration counter.
        iter = self.load_checkpoint(self.config.save_dir)
        if iter == 0 and hasattr(config, 'inherit'):
            self.load_checkpoint(self.config.inherit_dir)
            self.iter_cnt = 0

        (self.labeled_loader, self.unlabeled_loader, self.unlabeled_loader2,
         self.dev_loader) = data.get_mnist_loaders(config)
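
load_checkpoint is a project method whose body is not shown; the constructor only relies on it returning the iteration of the newest checkpoint in a directory, or 0 when the directory holds none. A hedged sketch of that warm-start pattern (the helper and the checkpoint file layout below are assumptions, not the project's actual format):

import os
import torch

def load_checkpoint(model, save_dir):
    # Hypothetical helper: restore the newest *.pth file in save_dir and
    # return its stored iteration, or 0 if no checkpoint is present.
    if not os.path.isdir(save_dir):
        return 0
    ckpts = sorted(f for f in os.listdir(save_dir) if f.endswith('.pth'))
    if not ckpts:
        return 0
    state = torch.load(os.path.join(save_dir, ckpts[-1]))
    model.load_state_dict(state['model'])
    return state.get('iter', 0)

def resume_or_inherit(model, save_dir, inherit_dir=None):
    # Mirrors the constructor: fall back to the inherited run's checkpoints
    # only when the run's own directory has nothing to resume from, and
    # restart the iteration counter in that case.
    start_iter = load_checkpoint(model, save_dir)
    if start_iter == 0 and inherit_dir is not None:
        load_checkpoint(model, inherit_dir)
        start_iter = 0
    return start_iter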