Пример #1
0
def main():
    """Train the feature classifier, then evaluate it on the held-out split.

    Reads paths and hyper-parameters from ``get_config()``, splits the feature
    data into shuffled train/test batches, trains a ``Classifier`` through
    ``TrainerClass`` and finally runs the trainer's test pass in eval mode.
    """
    cfg = get_config()
    bs = cfg.batch_size

    # Shuffled train/test split of the pre-extracted features.
    splits = Dataload().class_train_test(path_feat_data=cfg.feat_path,
                                         batch_size=bs,
                                         shuffle=True)
    train_x, train_y, test_x, test_y = splits

    clf = Classifier(batch_size=bs)
    runner = TrainerClass(clf, model_dir=cfg.model_dir, log_dir=cfg.log_dir)

    # The trainer consumes (features, labels) pairs.
    paired_train = list(zip(train_x, train_y))
    runner.train(paired_train, lr=cfg.lr, n_epoch=cfg.epochs)

    clf.eval()
    runner.test((test_x, test_y))
Пример #2
0
def main():
    """Denoise the training features with a pretrained autoencoder and evaluate
    the pretrained classifier on the denoised features.

    Both models are rebuilt with the same input width, restored from the paths in
    ``get_config()`` and put in eval mode before inference.
    """
    config = get_config()

    loader = Dataload()
    # NOTE(review): train_ratio=1 presumably places every sample in the "train"
    # split; the returned test split is not used below.
    train, label_train, test, label_test = loader.class_train_test(path_feat_data=config.feat_path,
                                                                   batch_size=config.batch_size,
                                                                   train_ratio=1,
                                                                   shuffle=True)

    # Feature width both models were built with (kept from the original 13*3).
    input_dim = 13 * 3

    ae = AE(input_dim)
    # map_location="cpu": the modules are constructed on the CPU, and
    # load_state_dict copies the weights into them anyway; this only lets
    # checkpoints saved from a GPU run load on a CPU-only host.
    ae.load_state_dict(torch.load(config.ae_model_path, map_location="cpu"))
    ae.eval()

    classifier = Classifier(input_dim, num_classes=109, batch_size=50)
    classifier.load_state_dict(torch.load(config.class_model_path, map_location="cpu"))
    classifier.eval()

    trainer = TrainerClass(classifier)
    # Pure inference: do not build an autograd graph through the autoencoder.
    # NOTE(review): class_train_test was called with batch_size; if `train` is a
    # list of batches rather than one tensor, ae(train) needs a per-batch loop.
    with torch.no_grad():
        unnoise_data = ae(train)
    trainer.test((unnoise_data, label_train))
def main():
    """Fine-tune a transferred backbone + classifier head with FixMatch.

    Every step draws one labeled batch and one unlabeled batch (a weak and a
    strong view of each unlabeled image). The loss is the labeled cross-entropy
    plus ``--lambd`` times the cross-entropy of the strong-view logits against
    pseudo-labels from the weak view, masked to pseudo-labels whose confidence is
    >= ``--threshold``.

    After every epoch a resumable checkpoint is written to ``--checkpoint-path``.
    Whenever validation accuracy beats the best so far (initially 25%), the
    weights are also written to ``--best-path``.

    Raises:
        FileNotFoundError: when resuming (``--train-from-start 0``) and
            ``--checkpoint-path`` does not exist.
    """
    # Example:
    # python3 train_fixmatch.py --checkpoint-path ./checkpoint_path/model.pth --batch-size 1 --num-epochs 1 --num-steps 1 --train-from-start 1 --dataset-folder ./dataset
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint-path', type=str, default="./checkpoints/model_fm_transfer.pth.tar")
    parser.add_argument('--transfer-path', type=str, default="./checkpoints/model_transfer.pth.tar")
    parser.add_argument('--best-path', type=str, default="./checkpoints/model_barlow_best.pth.tar")
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--num-epochs', type=int, default=10)
    parser.add_argument('--num-steps', type=int, default=10)
    parser.add_argument('--train-from-start', type=int, default=1)
    parser.add_argument('--dataset-folder', type=str, default="./dataset")
    parser.add_argument('--new-dataset-folder', type=str, default="./dataset")
    parser.add_argument('--learning-rate', type=float, default=0.01)
    parser.add_argument('--threshold', type=float, default=0.5)
    parser.add_argument('--mu', type=int, default=7)
    # float so fractional unlabeled-loss weights are accepted (integers still parse).
    parser.add_argument('--lambd', type=float, default=1)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight-decay', type=float, default=0.001)
    parser.add_argument('--layers', type=int, default=18)
    parser.add_argument('--fine-tune', type=int, default=1)
    parser.add_argument('--new-data', type=int, default=0)
    args = parser.parse_args()

    dataset_folder = args.dataset_folder
    batch_size_labeled = args.batch_size
    mu = args.mu
    batch_size_unlabeled = mu * args.batch_size  # mu unlabeled images per labeled one
    batch_size_val = 256
    n_epochs = args.num_epochs
    n_steps = args.num_steps
    num_classes = 800
    threshold = args.threshold
    learning_rate = args.learning_rate
    momentum = args.momentum
    lamd = args.lambd
    # NOTE(review): used below as a softmax temperature for the pseudo-labels;
    # in the FixMatch paper 0.95 is the confidence threshold (here --threshold).
    # Confirm this is intended.
    tau = 0.95
    weight_decay = args.weight_decay
    checkpoint_path = args.checkpoint_path
    train_from_start = args.train_from_start
    fine_tune = bool(args.fine_tune)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_transform, val_transform = get_transforms()

    if args.new_data == 0:
        labeled_train_dataset = CustomDataset(root=args.dataset_folder, split="train", transform=train_transform)
    else:
        labeled_train_dataset = CustomDataset(root=args.new_dataset_folder, split="train_new", transform=train_transform)
    # TODO: real normalisation statistics for the FixMatch weak/strong transform.
    unlabeled_train_dataset = CustomDataset(root=dataset_folder,
                                            split="unlabeled",
                                            transform=TransformFixMatch(mean=0, std=0))
    val_dataset = CustomDataset(root=dataset_folder, split="val", transform=val_transform)

    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=batch_size_labeled, shuffle=True, num_workers=4)
    unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=batch_size_unlabeled, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=False, num_workers=4)

    labeled_iter = iter(labeled_train_loader)
    unlabeled_iter = iter(unlabeled_train_loader)

    model = wide_resnet50_2(pretrained=False, num_classes=num_classes)
    classifier = Classifier(ip=2048, dp=0)
    start_epoch = 0

    checkpoint = torch.load(args.transfer_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    classifier.load_state_dict(checkpoint['classifier_state_dict'])

    # Move to the target device *before* building the optimizer, so that optimizer
    # state restored from a checkpoint below lands on the same device as the params.
    model = model.to(device)
    classifier = classifier.to(device)

    param_groups = [dict(params=classifier.parameters(), lr=learning_rate)]
    if fine_tune:
        param_groups.append(dict(params=model.parameters(), lr=learning_rate))

    optimizer = torch.optim.SGD(param_groups,
                                lr=learning_rate,
                                momentum=momentum,
                                nesterov=True,
                                weight_decay=weight_decay)

    scheduler = get_cosine_schedule_with_warmup(optimizer, 0, num_training_steps=n_epochs * n_steps)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model)
        classifier = torch.nn.DataParallel(classifier)

    if train_from_start == 0:
        if not os.path.isfile(checkpoint_path):
            raise FileNotFoundError(f"Error: no checkpoint found at {checkpoint_path}")
        print("Restoring model from checkpoint")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        # The per-epoch checkpoint below stores 'epoch' as (finished epoch + 1),
        # i.e. exactly the next epoch to run.
        start_epoch = checkpoint['epoch']
        # Keys match those written by save_checkpoint at the end of each epoch.
        model.load_state_dict(checkpoint['model_state_dict'])
        classifier.load_state_dict(checkpoint['classifier_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])

    losses = Average()
    losses_l = Average()
    losses_u = Average()
    mask_probs = Average()
    best_val_accuracy = 25.0  # TODO: minimum accuracy before a "best" model is saved

    def next_batch(iterator, loader):
        """Return ``(batch, iterator)``, restarting ``loader`` when it is exhausted."""
        try:
            return next(iterator), iterator
        except StopIteration:
            iterator = iter(loader)
            return next(iterator), iterator

    for epoch in tqdm(range(start_epoch, n_epochs)):
        # The head always trains. The backbone is in train mode only when it is
        # being fine-tuned; otherwise eval mode keeps its normalisation layers frozen.
        model.train(fine_tune)
        classifier.train()

        for batch_idx in tqdm(range(n_steps)):
            (img_lab, targets_lab), labeled_iter = next_batch(labeled_iter, labeled_train_loader)
            (unlab, _), unlabeled_iter = next_batch(unlabeled_iter, unlabeled_train_loader)
            img_weak, img_strong = unlab[0], unlab[1]

            img_lab = img_lab.to(device)
            targets_lab = targets_lab.to(device)
            img_weak = img_weak.to(device)
            img_strong = img_strong.to(device)

            # Actual labeled batch size: the loader keeps the last (possibly
            # smaller) batch, so batch_size_labeled may overshoot.
            n_lab = img_lab.size(0)
            img_cat = torch.cat((img_lab, img_weak, img_strong), dim=0)
            # Backbone gradients are only needed when the backbone is optimized.
            with torch.set_grad_enabled(fine_tune):
                features = model(img_cat)
            logits_cat = classifier(features)
            logits_lab = logits_cat[:n_lab]
            # Weak and strong views have the same count, so chunk splits them exactly.
            logits_weak, logits_strong = torch.chunk(logits_cat[n_lab:], chunks=2, dim=0)

            pseudo_label = torch.softmax(logits_weak.detach() / tau, dim=1)
            max_probs, targets_unlab = torch.max(pseudo_label, dim=1)
            mask = max_probs.ge(threshold).float()

            loss_labeled = F.cross_entropy(logits_lab, targets_lab, reduction='mean')
            # Averaged over all unlabeled samples; masked-out ones contribute 0.
            loss_unlabeled = (F.cross_entropy(logits_strong, targets_unlab, reduction='none') * mask).mean()
            loss_total = loss_labeled + lamd * loss_unlabeled

            losses.update(loss_total.item())
            losses_l.update(loss_labeled.item())
            losses_u.update(loss_unlabeled.item())
            mask_probs.update(mask.mean().item())

            optimizer.zero_grad()
            loss_total.backward()
            optimizer.step()
            scheduler.step()

            if batch_idx % 25 == 0:
                print(f"Epoch number: {epoch}, loss: {losses.avg}, loss lab: {losses_l.avg}, loss unlab: {losses_u.avg}, mask: {mask_probs.avg}, loss_here: {loss_total.item()}, best accuracy: {best_val_accuracy:.2f}", flush=True)

        save_checkpoint({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'classifier_state_dict': classifier.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
        }, checkpoint_path)

        model.eval()
        classifier.eval()
        val_loss = 0.0
        val_size = 0
        total = 0
        correct = 0
        with torch.no_grad():
            for batch in val_loader:
                logits_val = classifier(model(batch[0].to(device)))
                labels = batch[1].to(device)
                val_loss += F.cross_entropy(logits_val, labels).item()
                _, predicted = torch.max(logits_val, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                val_size += 1
        mean_val_loss = val_loss / max(val_size, 1)
        val_accuracy = 100 * correct / total if total else 0.0
        print(f"Val loss: {mean_val_loss}, Accuracy: {val_accuracy:.2f}%", flush=True)
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            best_val_loss = mean_val_loss
            print(f"Saving the best model with {best_val_accuracy:.2f}% accuracy and {best_val_loss:.2f} loss", flush=True)
            save_checkpoint({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'classifier_state_dict': classifier.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'best_val_accuracy': best_val_accuracy,
                'best_val_loss': best_val_loss
            }, args.best_path)
Пример #4
0
    def _disentanglement_metric(self,
                                dataset,
                                method_names,
                                sample_size,
                                n_epochs=6000,
                                dataset_size=1000,
                                hidden_dim=256,
                                use_non_linear=False):
        """Compute the Higgins disentanglement metric for several encoders.

        For each requested method ("VAE" uses ``self.model``; "PCA" and "ICA" are
        fitted here on a random subset of ``dataset.imgs``), train/test sets of
        ``(z_b_diff, y)`` pairs are generated with ``self._compute_z_b_diff_y``.
        Four classifier families are then trained to predict ``y``: a torch
        linear and a non-linear ``Classifier``, scikit-learn logistic regression,
        and a random forest.

        Args:
            dataset: object exposing ``imgs`` (3-D or 4-D array) and ``lat_sizes``.
            method_names: iterable of "VAE", "PCA" and/or "ICA".
            sample_size: forwarded to ``self._compute_z_b_diff_y``.
            n_epochs: full-batch epochs for the linear torch classifier; the
                non-linear one uses ``round(n_epochs / 2)``.
            dataset_size: number of training points generated per method;
                ``int(dataset_size * 0.5) + 1`` test points are generated.
            hidden_dim: hidden width passed to ``Classifier``.
            use_non_linear: not used; both linear and non-linear classifiers are
                always trained. Kept for backward compatibility.

        Returns:
            dict ``{"logreg"|"linear"|"nonlinear"|"rf": {method: test accuracy}}``.
            Torch classifiers give 0-d tensors; scikit-learn ones give numpy floats.

        Raises:
            ValueError: if an entry of ``method_names`` is not recognised.
        """

        def flatten_imgs():
            """Flatten ``dataset.imgs`` to one row per image for PCA/ICA.

            Returns ``(flat_imgs, multi_channel)``; ``multi_channel`` is True for
            4-D inputs whose last axis has more than one entry. The reshape assumes
            square images laid out (N, H, H, C) for 4-D and (N, H, H) for 3-D
            input — TODO confirm against the datasets in use.
            """
            if dataset.imgs.ndim == 4:
                data_imgs = dataset.imgs[:, :, :, :]
                print(f"Shape of data images: {data_imgs.shape}")
                flat = np.reshape(
                    data_imgs,
                    (data_imgs.shape[0],
                     data_imgs.shape[3] * data_imgs.shape[1]**2))
            else:
                data_imgs = dataset.imgs
                flat = np.reshape(dataset.imgs,
                                  (data_imgs.shape[0], data_imgs.shape[1]**2))
            multi_channel = len(data_imgs.shape) > 3 and data_imgs.shape[3] > 1
            return flat, multi_channel

        def random_subset(flat_imgs, max_size):
            """Draw at most ``max_size`` rows uniformly with replacement.

            Fitting on the full dataset does not fit in memory.
            """
            size = min(max_size, len(flat_imgs))
            idx = np.random.randint(len(flat_imgs), size=size)
            return flat_imgs[idx, :]

        def torch_accuracy(scores, targets):
            """Fraction of rows whose arg-max score equals the target (0-d tensor)."""
            _, prediction = scores.max(1)
            return (prediction == targets).sum().float() / len(targets)

        # Fit (or collect) every encoder to compare, keyed by method name.
        methods = {}
        runtimes = {}
        for method_name in tqdm(
                method_names,
                desc=
                "Iterating over methods for the Higgins disentanglement metric"
        ):
            if method_name == "VAE":
                methods["VAE"] = self.model

            elif method_name == "PCA":
                start = time.time()
                print("Training PCA...")
                pca = decomposition.PCA(n_components=self.model.latent_dim,
                                        whiten=True,
                                        random_state=self.seed)
                imgs_pca, multi_channel = flatten_imgs()
                pca.fit(random_subset(imgs_pca,
                                      3500 if multi_channel else 25000))
                methods["PCA"] = pca

                self.logger.info("Done")

                runtimes[method_name] = time.time() - start

            elif method_name == "ICA":
                start = time.time()
                print("Training ICA...")
                ica = decomposition.FastICA(n_components=self.model.latent_dim,
                                            max_iter=400,
                                            random_state=self.seed)
                imgs_ica, multi_channel = flatten_imgs()
                ica.fit(random_subset(imgs_ica,
                                      1000 if multi_channel else 2500))
                methods["ICA"] = ica

                self.logger.info("Done")

                runtimes[method_name] = time.time() - start

            else:
                raise ValueError("Unknown method : {}".format(method_name))

        if self.use_wandb:
            # Logging is best-effort: report the failure but never abort the metric.
            try:
                wandb.log(runtimes)
            except Exception as err:
                print(f"wandb.log failed: {err}")

        data_train = {method: ([], []) for method in methods}
        data_test = {method: ([], []) for method in methods}

        # latent dim = length of z_b_diff for any method = input size of the classifier
        latent_dim = self.model.latent_dim

        # Generate dataset_size training points per method and
        # int(dataset_size * 0.5) + 1 test points (about half as many).
        n_test_cutoff = int(dataset_size * 0.5)
        for i in tqdm(range(dataset_size),
                      desc="Generating datasets for Higgins metric"):
            data = self._compute_z_b_diff_y(methods, sample_size, dataset)
            for method in methods:
                data_train[method][0].append(data[method][0])
                data_train[method][1].append(data[method][1])
            if i <= n_test_cutoff:
                data = self._compute_z_b_diff_y(methods, sample_size, dataset)
                for method in methods:
                    data_test[method][0].append(data[method][0])
                    data_test[method][1].append(data[method][1])

        test_acc = {"logreg": {}, "linear": {}, "nonlinear": {}, "rf": {}}
        for model_class in ["linear", "nonlinear", "logreg", "rf"]:
            if model_class in ["linear", "nonlinear"]:
                is_linear = model_class == "linear"
                n_epochs_cls = n_epochs if is_linear else round(n_epochs / 2)
                weight_decay = 0 if is_linear else 1e-4
                model = Classifier(latent_dim,
                                   hidden_dim,
                                   len(dataset.lat_sizes),
                                   use_non_linear=not is_linear)

                model.to(self.device)
                model.train()

                # log softmax with NLL loss
                criterion = torch.nn.NLLLoss()

                for method in tqdm(
                        methods.keys(),
                        desc="Training classifiers for the Higgins metric"):
                    # Build a fresh optimizer/scheduler for every method. Reusing one
                    # would carry Adagrad/Adam state and plateau-reduced learning
                    # rates over from the previous method, and keep ICA's Adam for
                    # any method trained after it.
                    if method == "ICA":
                        optim = torch.optim.Adam(
                            model.parameters(),
                            lr=1 if is_linear else 0.001,
                            weight_decay=weight_decay)
                    else:
                        optim = torch.optim.Adagrad(
                            model.parameters(),
                            lr=0.01 if is_linear else 0.001,
                            weight_decay=weight_decay)
                    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                        optim, 'min', patience=5000, min_lr=0.00001)

                    X_train, Y_train = data_train[method]
                    X_train = torch.tensor(X_train,
                                           dtype=torch.float32).to(self.device)
                    Y_train = torch.tensor(Y_train,
                                           dtype=torch.long).to(self.device)

                    X_test, Y_test = data_test[method]
                    X_test = torch.tensor(X_test,
                                          dtype=torch.float32).to(self.device)
                    Y_test = torch.tensor(Y_test,
                                          dtype=torch.long).to(self.device)

                    print(f'Training the classifier for model {method}')
                    for e in tqdm(
                            range(n_epochs_cls),
                            desc=
                            "Iterating over epochs while training the Higgins classifier"
                    ):
                        model.train()
                        optim.zero_grad()

                        scores_train = model(X_train)
                        loss = criterion(scores_train, Y_train)
                        loss.backward()
                        optim.step()
                        scheduler.step(loss)

                        if (e + 1) % 2000 == 0:
                            model.eval()
                            with torch.no_grad():
                                scores_train = model(X_train)
                                scores_test = model(X_test)
                                test_loss = criterion(scores_test, Y_test)
                                tqdm.write(
                                    f'In this epoch {e+1}/{n_epochs_cls}, Training loss: {loss.item():.4f}, Test loss: {test_loss.item():.4f}'
                                )
                                train_acc = torch_accuracy(scores_train, Y_train)
                                test_acc[model_class][method] = torch_accuracy(
                                    scores_test, Y_test)
                                tqdm.write(
                                    f'Accuracy of {method} on training set: {train_acc.item():.4f}, test set: {test_acc[model_class][method].item():.4f}'
                                )
                            model.train()

                    model.eval()
                    with torch.no_grad():
                        train_acc = torch_accuracy(model(X_train), Y_train)
                        test_acc[model_class][method] = torch_accuracy(
                            model(X_test), Y_test)
                        print(
                            f'Accuracy of {method} on training set: {train_acc.item():.4f}, test set: {test_acc[model_class][method].item():.4f}'
                        )

                    # Presumably re-initialises the weights so the next method
                    # starts from scratch (weight_reset is defined elsewhere).
                    model.apply(weight_reset)

            elif model_class in ["logreg", "rf"]:

                for method in tqdm(
                        methods.keys(),
                        desc="Training classifiers for the Higgins metric"):
                    if model_class == "logreg":
                        classifier = linear_model.LogisticRegression(
                            max_iter=500, random_state=self.seed)
                    else:
                        classifier = sklearn.ensemble.RandomForestClassifier(
                            n_estimators=150)
                    X_train, Y_train = data_train[method]
                    X_test, Y_test = data_test[method]
                    classifier.fit(X_train, Y_train)
                    train_acc = np.mean(classifier.predict(X_train) == Y_train)
                    test_acc[model_class][method] = np.mean(
                        classifier.predict(X_test) == Y_test)
                    print(
                        f'Accuracy of {method} on training set: {train_acc:.4f}, test set: {test_acc[model_class][method].item():.4f}'
                    )

        return test_acc
def main():
    """Dump Barlow Twins backbone features and classifier logits to disk.

    Builds the backbone: ``--wide`` non-zero selects wide_resnet50_2 with 2048
    features, otherwise resnet18 with 512. Restores the backbone and the
    classifier head from ``--checkpoint-path`` and runs both in eval mode over the
    labeled "train" split and the "unlabeled" split, in dataset order. Saves four
    tensors under ``--out-path``, which is created if missing:

    * lab_rep_model.pt / lab_rep_clf.pt: labeled features / logits
    * unlab_rep_model.pt / unlab_rep_clf.pt: unlabeled features / logits

    ``--final`` is parsed but not used here.
    """
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--checkpoint-path',
        type=str,
        default="./checkpoints/model_transfer_barlow_best.pth.tar")
    parser.add_argument('--dataset-folder', type=str, default="./dataset")
    parser.add_argument('--out-path', type=str, default="./representations/")
    parser.add_argument('--batch-size', type=int, default=512)
    parser.add_argument('--wide', type=int, default=0)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--final', type=int, default=0)
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    unlabeled_train_dataset = CustomDataset(root=args.dataset_folder,
                                            split="unlabeled",
                                            transform=eval_transform)
    unlabeled_dataloader = DataLoader(unlabeled_train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=False,
                                      num_workers=4)

    labeled_train_dataset = CustomDataset(root=args.dataset_folder,
                                          split="train",
                                          transform=eval_transform)
    labeled_dataloader = DataLoader(labeled_train_dataset,
                                    batch_size=args.batch_size,
                                    shuffle=False,
                                    num_workers=4)

    # A single flag selects both the backbone and the head width, so the head's
    # input size always matches the backbone's feature size.
    if args.wide:
        backbone, num_ftrs = wide_resnet50_2(pretrained=False), 2048
    else:
        backbone, num_ftrs = resnet18(pretrained=False), 512
    model = lightly.models.BarlowTwins(backbone, num_ftrs=num_ftrs).backbone
    classifier = Classifier(ip=num_ftrs, dp=args.dropout)

    checkpoint = torch.load(args.checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    classifier.load_state_dict(checkpoint['classifier_state_dict'])

    model = model.to(device)
    classifier = classifier.to(device)

    model.eval()
    classifier.eval()

    def extract(loader):
        """Return (backbone features, classifier logits) for every batch of
        ``loader``, each concatenated along dim 0.

        Batches are collected in lists and concatenated once, instead of
        re-copying a growing tensor on every batch. An empty loader yields empty
        tensors.
        """
        feats, logits = [], []
        for batch in tqdm(loader):
            img = batch[0].to(device)
            batch_feats = model(img)
            feats.append(batch_feats)
            logits.append(classifier(batch_feats))
        if not feats:
            return torch.tensor([]).to(device), torch.tensor([]).to(device)
        return torch.cat(feats, dim=0), torch.cat(logits, dim=0)

    if args.out_path:
        os.makedirs(args.out_path, exist_ok=True)

    with torch.no_grad():
        label_rep_model, label_rep_clf = extract(labeled_dataloader)

        print("Writing labeled representations to file", flush=True)
        torch.save(label_rep_model, os.path.join(args.out_path, "lab_rep_model.pt"))
        torch.save(label_rep_clf, os.path.join(args.out_path, "lab_rep_clf.pt"))
        # Free the labeled tensors before the (larger) unlabeled pass.
        del label_rep_model, label_rep_clf

        unlabel_rep_model, unlabel_rep_clf = extract(unlabeled_dataloader)

        print("Writing unlabeled representations to file", flush=True)
        torch.save(unlabel_rep_model, os.path.join(args.out_path, "unlab_rep_model.pt"))
        torch.save(unlabel_rep_clf, os.path.join(args.out_path, "unlab_rep_clf.pt"))
Пример #6
0
class Evaluator(CheckpointRunner):
    """Evaluation runner for the model type selected by ``options.model.name``.

    * ``"pixel2mesh"``: per-class Chamfer distance, plus F1 scores at distance
      thresholds 1e-4 ("tau") and 2e-4 ("2tau").
    * ``"classifier"``: top-1 and top-5 accuracy.

    Metrics are accumulated in ``AverageMeter`` objects during :meth:`evaluate`
    and reported through ``self.logger`` and ``self.summary_writer``.
    """

    def __init__(self, options, logger: Logger, writer, shared_model=None):
        super().__init__(options,
                         logger,
                         writer,
                         training=False,
                         shared_model=shared_model)

    # noinspection PyAttributeOutsideInit
    def init_fn(self, shared_model=None, **kwargs):
        """Set up metric helpers and the model, reusing ``shared_model`` when given.

        Raises:
            NotImplementedError: if no ``shared_model`` is given and
                ``options.model.name`` is neither "pixel2mesh" nor "classifier".
        """
        if self.options.model.name == "pixel2mesh":
            # Renderer for visualization
            self.renderer = MeshRenderer(self.options.dataset.camera_f,
                                         self.options.dataset.camera_c,
                                         self.options.dataset.mesh_pos)
            # Initialize distance module
            self.chamfer = ChamferDist()
            # create ellipsoid
            self.ellipsoid = Ellipsoid(self.options.dataset.mesh_pos)
            # use weighted mean evaluation metrics or not
            self.weighted_mean = self.options.test.weighted_mean
        else:
            self.renderer = None
        self.num_classes = self.options.dataset.num_classes

        if shared_model is not None:
            self.model = shared_model
        else:
            if self.options.model.name == "pixel2mesh":
                # create model
                self.model = P2MModel(self.options.model, self.ellipsoid,
                                      self.options.dataset.camera_f,
                                      self.options.dataset.camera_c,
                                      self.options.dataset.mesh_pos)
            elif self.options.model.name == "classifier":
                self.model = Classifier(self.options.model,
                                        self.options.dataset.num_classes)
            else:
                raise NotImplementedError("Your model is not found")
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.gpus).cuda()

        # Evaluate step count, useful in summary
        self.evaluate_step_count = 0
        self.total_step_count = 0

    def models_dict(self):
        """Return the models to checkpoint, keyed by name."""
        return {'model': self.model}

    def evaluate_f1(self, dis_to_pred, dis_to_gt, pred_length, gt_length,
                    thresh):
        """F-score from per-point distances at threshold ``thresh``.

        recall = (#entries of ``dis_to_gt`` below thresh) / ``gt_length``;
        precision = (#entries of ``dis_to_pred`` below thresh) / ``pred_length``.
        The 1e-8 term avoids division by zero when both are 0.
        """
        recall = np.sum(dis_to_gt < thresh) / gt_length
        prec = np.sum(dis_to_pred < thresh) / pred_length
        return 2 * prec * recall / (prec + recall + 1e-8)

    def evaluate_chamfer_and_f1(self, pred_vertices, gt_points, labels):
        """Accumulate per-class Chamfer distance and F1 (tau, 2*tau) for a batch."""
        # Ground-truth point clouds have different lengths per sample, so the
        # Chamfer distance cannot be batched.
        batch_size = pred_vertices.size(0)
        pred_length = pred_vertices.size(1)
        for i in range(batch_size):
            gt_length = gt_points[i].size(0)
            label = labels[i].cpu().item()
            d1, d2, _, _ = self.chamfer(pred_vertices[i].unsqueeze(0),
                                        gt_points[i].unsqueeze(0))
            # Distances are used in the units ChamferDist returns (no rescaling).
            d1, d2 = d1.cpu().numpy(), d2.cpu().numpy()
            self.chamfer_distance[label].update(np.mean(d1) + np.mean(d2))
            self.f1_tau[label].update(
                self.evaluate_f1(d1, d2, pred_length, gt_length, 1E-4))
            self.f1_2tau[label].update(
                self.evaluate_f1(d1, d2, pred_length, gt_length, 2E-4))

    def evaluate_accuracy(self, output, target):
        """Update ``self.acc_1`` / ``self.acc_5`` with this batch's top-1/top-5 accuracy."""
        top_k = [1, 5]
        maxk = max(top_k)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        for k in top_k:
            # reshape, not view: `correct` follows the transposed (non-contiguous)
            # layout of `pred`, which view(-1) rejects.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            acc = correct_k.mul_(1.0 / batch_size)
            if k == 1:
                self.acc_1.update(acc)
            elif k == 5:
                self.acc_5.update(acc)

    def evaluate_step(self, input_batch):
        """Run inference on one batch without gradients and update the metrics."""
        self.model.eval()

        # Run inference
        with torch.no_grad():
            images = input_batch['images']

            out = self.model(images)

            if self.options.model.name == "pixel2mesh":
                pred_vertices = out["pred_coord"][-1]
                gt_points = input_batch["points_orig"]
                if isinstance(gt_points, list):
                    gt_points = [pts.cuda() for pts in gt_points]
                self.evaluate_chamfer_and_f1(pred_vertices, gt_points,
                                             input_batch["labels"])
            elif self.options.model.name == "classifier":
                self.evaluate_accuracy(out, input_batch["labels"])

        return out

    # noinspection PyAttributeOutsideInit
    def evaluate(self):
        """Evaluate on ``self.dataset`` and log the final metric summary."""
        self.logger.info("Running evaluations...")

        # clear evaluate_step_count, but keep total count uncleared
        self.evaluate_step_count = 0

        test_data_loader = DataLoader(self.dataset,
                                      batch_size=self.options.test.batch_size *
                                      self.options.num_gpus,
                                      num_workers=self.options.num_workers,
                                      pin_memory=self.options.pin_memory,
                                      shuffle=self.options.test.shuffle,
                                      collate_fn=self.dataset_collate_fn)

        if self.options.model.name == "pixel2mesh":
            self.chamfer_distance = [
                AverageMeter() for _ in range(self.num_classes)
            ]
            self.f1_tau = [AverageMeter() for _ in range(self.num_classes)]
            self.f1_2tau = [AverageMeter() for _ in range(self.num_classes)]
        elif self.options.model.name == "classifier":
            self.acc_1 = AverageMeter()
            self.acc_5 = AverageMeter()

        # Iterate over all batches in an epoch
        for step, batch in enumerate(test_data_loader):
            # Send input to GPU
            batch = {
                k: v.cuda() if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()
            }

            # Run evaluation step
            out = self.evaluate_step(batch)

            # Tensorboard logging every summary_steps steps
            if self.evaluate_step_count % self.options.test.summary_steps == 0:
                self.evaluate_summaries(batch, out)

            # add later to log at step 0
            self.evaluate_step_count += 1
            self.total_step_count += 1

        for key, val in self.get_result_summary().items():
            scalar = val
            if isinstance(val, AverageMeter):
                scalar = val.avg
            self.logger.info("Test [%06d] %s: %.6f" %
                             (self.total_step_count, key, scalar))
            self.summary_writer.add_scalar("eval_" + key, scalar,
                                           self.total_step_count + 1)

    def average_of_average_meters(self, average_meters):
        """Combine per-class meters into one ``AverageMeter``.

        ``weighted_avg`` is total sum / total count over all meters; ``avg`` is the
        unweighted mean of the per-class averages. ``self.weighted_mean`` selects
        which one becomes ``.avg``; the other is stored in ``.val``. An empty list
        gives 0 for both.
        """
        s = sum(meter.sum for meter in average_meters)
        c = sum(meter.count for meter in average_meters)
        weighted_avg = s / c if c > 0 else 0.
        if average_meters:
            avg = sum(meter.avg for meter in average_meters) / len(average_meters)
        else:
            avg = 0.
        ret = AverageMeter()
        if self.weighted_mean:
            ret.val, ret.avg = avg, weighted_avg
        else:
            ret.val, ret.avg = weighted_avg, avg
        return ret

    def get_result_summary(self):
        """Return the current metrics by name (empty for unknown model names)."""
        if self.options.model.name == "pixel2mesh":
            return {
                "cd": self.average_of_average_meters(self.chamfer_distance),
                "f1_tau": self.average_of_average_meters(self.f1_tau),
                "f1_2tau": self.average_of_average_meters(self.f1_2tau),
            }
        if self.options.model.name == "classifier":
            return {
                "acc_1": self.acc_1,
                "acc_5": self.acc_5,
            }
        return {}

    def evaluate_summaries(self, input_batch, out_summary):
        """Log progress/metrics and write label histograms (and renders, if any)."""
        n_steps = len(self.dataset) // (self.options.num_gpus *
                                        self.options.test.batch_size)
        metrics = ", ".join(
            key + " " + (str(val) if isinstance(val, AverageMeter) else "%.6f" % val)
            for key, val in self.get_result_summary().items())
        self.logger.info("Test Step %06d/%06d (%06d) " %
                         (self.evaluate_step_count, n_steps,
                          self.total_step_count) + metrics)

        self.summary_writer.add_histogram("eval_labels",
                                          input_batch["labels"].cpu().numpy(),
                                          self.total_step_count)
        if self.renderer is not None:
            # Do visualization for the first 2 images of the batch
            render_mesh = self.renderer.p2m_batch_visualize(
                input_batch, out_summary, self.ellipsoid.faces)
            self.summary_writer.add_image("eval_render_mesh", render_mesh,
                                          self.total_step_count)