Example #1
File: test.py Project: yongxinw/zsl
# Third-party imports used by this example; AutoEncoder, CUBDataset, and
# ZSLPrediction are project-specific classes from yongxinw/zsl and are assumed
# to be importable from the repository's own modules.
import numpy as np
import torch
import torch.nn.functional as F
from imageio import imsave  # the original file may use scipy.misc.imsave instead
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm


class Tester(object):
    def __init__(self, args):
        super(Tester, self).__init__()

        self.args = args

        self.model = AutoEncoder(args)
        self.model.load_state_dict(torch.load(args.checkpoint))
        self.model.cuda()
        self.model.eval()

        self.result = {}

        self.train_dataset = CUBDataset(split='train')
        self.test_dataset = CUBDataset(split='test')
        self.val_dataset = CUBDataset(split='val')

        self.train_loader = DataLoader(dataset=self.train_dataset,
                                       batch_size=args.batch_size)
        self.test_loader = DataLoader(dataset=self.test_dataset,
                                      batch_size=args.batch_size)
        self.val_loader = DataLoader(dataset=self.val_dataset,
                                     batch_size=100,
                                     shuffle=True)

        train_cls = self.train_dataset.get_classes('train')
        test_cls = self.test_dataset.get_classes('test')
        print("Load class")
        print(train_cls)
        print(test_cls)

        self.zsl = ZSLPrediction(train_cls, test_cls)

    # def tSNE(self):

    def conse_prediction(self, mode='test'):
        def pred(recon_x, z_tilde, output):
            cls_score = output.detach().cpu().numpy()
            print(cls_score)
            pred = self.zsl.conse_wordembedding_predict(
                cls_score, self.args.conse_top_k)
            return pred

        self.get_features(mode=mode, pred_func=pred)

        if (mode + '_pred') in self.result:

            target = self.result[mode + '_label']
            pred = self.result[mode + '_pred']
            print(target)
            print(pred)
            acc = np.sum(target == pred)
            print(acc)
            total = target.shape[0]
            print(total)
            return acc / float(total)
        else:
            raise NotImplementedError

    def knn_prediction(self, mode='test'):
        self.get_features(mode=mode, pred_func=None)

        if (mode + '_feature') in self.result:
            features = self.result[mode + '_feature']
            labels = self.result[mode + '_label']
            print(labels)
            self.zsl.construct_nn(features,
                                  labels,
                                  k=5,
                                  metric='cosine',
                                  sample_num=5)
            pred = self.zsl.nn_predict(features)

            acc = np.sum(labels == pred)
            total = labels.shape[0]

            return acc / float(total)
        else:
            raise NotImplementedError

    def tSNE(self, mode='train'):
        self.get_features(mode=mode, pred_func=None)

        total_num = self.result[mode + '_feature'].shape[0]

        random_index = np.random.permutation(total_num)

        random_index = random_index[:30]

        self.zsl.tSNE_visualization(self.result[mode + '_feature'][random_index, :],
                                    self.result[mode + '_label'][random_index],
                                    mode=mode,
                                    file_name=self.args.tsne_out)

    def get_features(self, mode='test', pred_func=None):
        self.model.eval()
        if pred_func is None and (mode + '_feature') in self.result:
            print("Use cached result")
            return
        if pred_func is not None and (mode + '_pred') in self.result:
            print("Use cached result")
            return

        if mode == 'train':
            loader = self.train_loader
        elif mode == 'test':
            loader = self.test_loader
        else:
            raise ValueError("get_features expects mode 'train' or 'test'")

        all_z = []
        all_label = []
        all_pred = []

        for data in tqdm(loader):
            # if idx == 3:
            #     break

            images = Variable(data['image64_crop'].cuda())
            target = Variable(data['class_id'].cuda())

            recon_x, z_tilde, output = self.model(images)
            target = target.detach().cpu().numpy()

            output = F.softmax(output, dim=1)

            all_label.append(target)
            all_z.append(z_tilde.detach().cpu().numpy())

            if pred_func is not None:
                pred = pred_func(recon_x, z_tilde, output)
                all_pred.append(pred)

        self.result[mode + '_feature'] = np.vstack(all_z)  # all features
        # print(all_label)
        self.result[mode + '_label'] = np.hstack(all_label)  # all test label

        if pred_func is not None:
            self.result[mode + '_pred'] = np.hstack(all_pred)
            print(self.result[mode + '_pred'].shape)
        print(self.result[mode + '_feature'].shape)
        print(self.result[mode + '_label'].shape)

    def validation_recon(self):
        self.model.eval()
        for idx, data in enumerate(self.val_loader):
            if idx == 1:
                break

            images = Variable(data['image64_crop'].cuda())
            recon_x, z_tilde, output = self.model(images)

            all_recon_images = recon_x.detach().cpu().numpy()  #N x 3 x 64 x 64
            all_origi_images = data['image64_crop'].numpy()  #N x 3 x 64 x 64

            for i in range(all_recon_images.shape[0]):
                # save the original image and its reconstruction under matching names
                imsave(
                    './recon/orig' + str(i) + '.png',
                    np.transpose(np.squeeze(all_origi_images[i, :, :, :]),
                                 [1, 2, 0]))
                imsave(
                    './recon/recon' + str(i) + '.png',
                    np.transpose(np.squeeze(all_recon_images[i, :, :, :]),
                                 [1, 2, 0]))

    def test_nn_image(self):
        self.get_features(mode='test', pred_func=None)
        self.get_features(mode='train', pred_func=None)

        N = 100
        random_index = np.random.permutation(
            self.result['test_feature'].shape[0])[:N]

        from sklearn.neighbors import NearestNeighbors

        neigh = NearestNeighbors()
        neigh.fit(self.result['train_feature'])

        test_feature = self.result['test_feature'][random_index, :]
        _, pred_index = neigh.kneighbors(test_feature, 1)

        for i in range(N):
            test_index = random_index[i]

            data = self.test_dataset[test_index]
            image = data['image64_crop'].numpy()  #1 x 3 x 64 x 64
            print(image.shape)
            imsave('./nn_image/test' + str(i) + '.png',
                   np.transpose(np.squeeze(image), [1, 2, 0]))

            train_index = pred_index[i][0]
            print(train_index)
            data = self.train_dataset[train_index]
            image = data['image64_crop'].numpy()  #1 x 3 x 64 x 64
            print(image.shape)
            imsave('./nn_image/train' + str(i) + '.png',
                   np.transpose(np.squeeze(image), [1, 2, 0]))
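
A minimal driver sketch for the Tester above, assuming an argparse namespace that carries the fields the class actually reads (checkpoint, batch_size, conse_top_k, tsne_out); the flag names and defaults below are placeholders, and AutoEncoder(args) may require additional fields not shown here.

import argparse

# Hypothetical entry point; flag names mirror the attributes Tester reads and
# the defaults are assumptions, not settings from the original project.
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', type=str, required=True)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--conse_top_k', type=int, default=10)
parser.add_argument('--tsne_out', type=str, default='tsne.png')
args = parser.parse_args()

tester = Tester(args)
print('kNN accuracy:', tester.knn_prediction(mode='test'))
print('ConSE accuracy:', tester.conse_prediction(mode='test'))
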
Example #2
# Third-party imports used by this example; AutoEncoder, Discriminator,
# CUBDataset, and log_density_igaussian are project-specific and are assumed
# to be importable from the repository's own modules.
import os
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from tqdm import tqdm


class Trainer(object):
    def __init__(self, args):
        # load network
        self.G = AutoEncoder(args)
        self.D = Discriminator(args)
        self.G.weight_init()
        self.D.weight_init()
        self.G.cuda()
        self.D.cuda()
        self.criterion = nn.MSELoss()

        # load data
        self.train_dataset = CUBDataset(split='train')
        self.valid_dataset = CUBDataset(split='val')
        self.train_loader = DataLoader(dataset=self.train_dataset, batch_size=args.batch_size)
        self.valid_loader = DataLoader(dataset=self.valid_dataset, batch_size=args.batch_size)

        # Optimizers
        self.G_optim = optim.Adam(self.G.parameters(), lr=args.lr_G)
        self.D_optim = optim.Adam(self.D.parameters(), lr=0.5 * args.lr_D)
        self.G_scheduler = StepLR(self.G_optim, step_size=30, gamma=0.5)
        self.D_scheduler = StepLR(self.D_optim, step_size=30, gamma=0.5)

        # Parameters
        self.epochs = args.epochs
        self.batch_size = args.batch_size
        self.z_var = args.z_var 
        self.sigma = args.sigma
        self.lambda_1 = args.lambda_1
        self.lambda_2 = args.lambda_2

        log_dir = os.path.join(args.log_dir, datetime.now().strftime("%m_%d_%H_%M_%S"))
        # if not os.path.isdir(log_dir):
            # os.makedirs(log_dir)
        self.writter = SummaryWriter(log_dir)

    def train(self):
        global_step = 0
        self.G.train()
        self.D.train()
        ones = Variable(torch.ones(self.batch_size, 1).cuda())
        zeros = Variable(torch.zeros(self.batch_size, 1).cuda())

        for epoch in range(self.epochs):
            self.G_scheduler.step()
            self.D_scheduler.step()
            print("training epoch {}".format(epoch))
            all_num = 0.0
            acc_num = 0.0
            images_index = 0
            for data in tqdm(self.train_loader):
                images = Variable(data['image64'].cuda())
                target_image = Variable(data['image64'].cuda()) 
                target = Variable(data['class_id'].cuda())
                recon_x, z_tilde, output = self.G(images)
                z = Variable((self.sigma*torch.randn(z_tilde.size())).cuda())
                log_p_z = log_density_igaussian(z, self.z_var).view(-1, 1)
                ones = Variable(torch.ones(images.size()[0], 1).cuda())
                zeros = Variable(torch.zeros(images.size()[0], 1).cuda())

                # ======== Train Discriminator ======== #
                D_z = self.D(z)
                D_z_tilde = self.D(z_tilde)
                D_loss = F.binary_cross_entropy_with_logits(D_z+log_p_z, ones) + \
                    F.binary_cross_entropy_with_logits(D_z_tilde+log_p_z, zeros)

                total_D_loss = self.lambda_1*D_loss
                self.D_optim.zero_grad()
                total_D_loss.backward(retain_graph=True)
                self.D_optim.step()

                # ======== Train Generator ======== #
                recon_loss = F.mse_loss(recon_x, target_image, reduction='sum').div(self.batch_size)
                G_loss = F.binary_cross_entropy_with_logits(D_z_tilde+log_p_z, ones)
                class_loss = F.cross_entropy(output, target)
                total_G_loss = recon_loss + self.lambda_1*G_loss + self.lambda_2*class_loss
                self.G_optim.zero_grad()
                total_G_loss.backward()
                self.G_optim.step()

                # ======== Compute Classification Accuracy ======== #
                values, indices = torch.max(output, 1)
                acc_num += torch.sum((indices == target)).cpu().item()
                all_num += len(target)
                
                # ======== Log by TensorBoardX
                global_step += 1
                if (global_step + 1) % 10 == 0:
                    self.writter.add_scalar('train/recon_loss', recon_loss.cpu().item(), global_step)
                    self.writter.add_scalar('train/G_loss', G_loss.cpu().item(), global_step)
                    self.writter.add_scalar('train/D_loss', D_loss.cpu().item(), global_step)
                    self.writter.add_scalar('train/classify_loss', class_loss.cpu().item(), global_step)
                    self.writter.add_scalar('train/total_G_loss', total_G_loss.cpu().item(), global_step)
                    self.writter.add_scalar('train/acc', acc_num/all_num, global_step)
                    if images_index < 5 and torch.rand(1) < 0.5:
                        self.writter.add_image('train_output_{}'.format(images_index), recon_x[0], global_step)
                        self.writter.add_image('train_target_{}'.format(images_index), target_image[0], global_step)
                        images_index += 1
            if epoch % 2 == 0:
                self.validate(global_step)
                # validate() puts G and D into eval mode; switch back for the next epoch
                self.G.train()
                self.D.train()

    def validate(self, global_step):
        self.G.eval()
        self.D.eval()
        acc_num = 0.0
        all_num = 0.0
        recon_loss = 0.0
        images_index = 0
        for data in tqdm(self.valid_loader):
            images = Variable(data['image64'].cuda())
            target_image = Variable(data['image64'].cuda()) 
            target = Variable(data['class_id'].cuda())
            recon_x, z_tilde, output = self.G(images)
            values, indices = torch.max(output, 1)
            acc_num += torch.sum((indices == target)).cpu().item()
            all_num += len(target)
            recon_loss += F.mse_loss(recon_x, target_image, reduction='sum').cpu().item()
            if images_index < 5:
                self.writter.add_image('valid_output_{}'.format(images_index), recon_x[0], global_step)
                self.writter.add_image('valid_target_{}'.format(images_index), target_image[0], global_step)
                images_index += 1

        self.writter.add_scalar('valid/acc', acc_num/all_num, global_step)
        self.writter.add_scalar('valid/recon_loss', recon_loss/all_num, global_step)
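
Similarly, a minimal driver sketch for the Trainer above; the hyperparameter names mirror the attributes the class reads, and every default value is an assumption rather than a setting from the original project.

import argparse

# Hypothetical entry point for Trainer; defaults are placeholders, and
# AutoEncoder(args) / Discriminator(args) may read additional fields.
parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--lr_G', type=float, default=1e-3)
parser.add_argument('--lr_D', type=float, default=1e-3)
parser.add_argument('--z_var', type=float, default=2.0)
parser.add_argument('--sigma', type=float, default=1.0)
parser.add_argument('--lambda_1', type=float, default=1.0)
parser.add_argument('--lambda_2', type=float, default=1.0)
parser.add_argument('--log_dir', type=str, default='./logs')
args = parser.parse_args()

Trainer(args).train()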