コード例 #1
0
def train(num_epochs, train_loader, FVI):
	"""Train the functional-VI segmentation model ``FVI``.

	Runs ``num_epochs + 1`` epochs, indexed ``args.start_epoch`` through
	``args.start_epoch + num_epochs`` inclusive.

	Fine-tuning switch: the first time the epoch index reaches ``args.ft_start``,
	training switches to the global ``ft_loader`` (printed as fine-tuning on full
	size images). If additionally ``args.f_prior == 'cnn_gp'``, ``FVI.prior`` is
	replaced by an ``f_prior_BNN`` built for ``(H, W)``.

	Per-epoch logging: the batch-averaged ELBO (printed as "Average Train Loss")
	and the batch-averaged training error are printed and saved with
	``np.savetxt``.

	Validation and checkpoints: every ``args.save_results`` epochs, and at
	``args.final_epoch``, the model is evaluated on ``val_loader``. The validation
	metrics are saved, and the model and optimizer state dicts are written to disk.

	Uses module-level globals: ``args``, ``device``, ``optimizer``, ``ft_loader``,
	``val_loader``, ``num_classes``, ``H``, ``W``, ``exp_name``, ``fELBO``,
	``numpy_metrics``, ``adjust_learning_rate``, ``test``, ``tqdm``, ``np``,
	``torch``.

	Args:
		num_epochs: number of epochs to run after ``args.start_epoch``
			(the caller in this file passes ``args.n_epochs``).
		train_loader: iterable of ``(X, Y)`` batches, used until fine-tuning
			starts.
		FVI: the model. Calling it must return ``(f_samples, q_mean, q_cov,
			prior_mean, prior_cov)``, and it must provide ``predict(x, S=...)``.
	"""
	FVI.train()

	ft_started = False
	for s in range(args.start_epoch, args.start_epoch + num_epochs + 1):
		train_loss = 0.
		train_error = 0.
		# Re-enter train mode each epoch; test() called at the end of an epoch
		# may leave the model in eval mode (assumption — test() not shown here).
		FVI.train()
		if s >= args.ft_start and not ft_started:
			# Only needed once fine-tuning begins.
			from lib.priors import f_prior_BNN
			print('Now fine-tuning on full size images')
			train_loader = ft_loader
			if args.f_prior == 'cnn_gp':
				FVI.prior = f_prior_BNN((H, W), device, num_channels_output=num_classes)
			ft_started = True
		for X, Y in tqdm(train_loader):
			x_t = X.to(device)
			y_t = Y.to(device)
			N_t = x_t.size(0)

			optimizer.zero_grad()

			f_samples, q_mean, q_cov, prior_mean, prior_cov = FVI(x_t)

			# The ELBO is maximised, so the minimised loss is its negation.
			loss = - fELBO(y_t, f_samples, q_mean, q_cov, prior_mean, prior_cov, print_loss=True)

			loss.backward()
			# Accumulates the ELBO itself (not the negated loss).
			train_loss += -loss.item()
			optimizer.step()

			# The training-error estimate is for logging only, so do not
			# record an autograd graph for the 20-sample prediction.
			with torch.no_grad():
				preds = FVI.predict(x_t, S=20)
			_, _, train_acc_curr = numpy_metrics(preds.cpu().view(N_t, -1).numpy(), y_t.view(N_t, -1).data.cpu().numpy())
			train_error += 1 - train_acc_curr
			# Called once per batch with the epoch index, as in the original schedule.
			adjust_learning_rate(args.lr, 0.998, optimizer, s, args.final_epoch)
			del x_t, y_t, f_samples, q_mean, q_cov, prior_mean, prior_cov, preds
		train_loss /= len(train_loader)
		train_error /= len(train_loader)
		print('Epoch: {} || Average Train Error: {:.5f} || Average Train Loss: {:.5f}'.format(s, train_error, train_loss))
		np.savetxt('{}_{}_epoch_{}_average_train_loss.txt'.format(args.dataset, exp_name, s), [train_loss])
		np.savetxt('{}_{}_epoch_{}_average_train_error.txt'.format(args.dataset, exp_name, s), [train_error])

		# Periodic validation + checkpointing (also forced at the final epoch).
		if s % args.save_results == 0 or s == args.final_epoch:
			val_error, val_mIOU = test(FVI, val_loader, num_classes, args.dataset, exp_name, plot_imgs=False)
			print('Epoch: {} || Validation Error: {:.5f} || Validation Mean IOU: {:.5f}'.format(s, val_error, val_mIOU))
			torch.save(FVI.state_dict(), 'model_{}_{}.bin'.format(args.dataset, exp_name))
			torch.save(optimizer.state_dict(), 'optimizer_{}_{}.bin'.format(args.dataset, exp_name))
			np.savetxt('{}_{}_epoch_{}_val_error.txt'.format(args.dataset, exp_name, s), [val_error])
			np.savetxt('{}_{}_epoch_{}_val_mIOU.txt'.format(args.dataset, exp_name, s), [val_mIOU])
コード例 #2
0
ファイル: run_mcd_seg.py プロジェクト: zebrajack/FVI_CV
def train(num_epochs, train_loader):
	"""Train the module-level segmentation ``model``.

	Runs ``num_epochs + 1`` epochs, indexed ``args.start_epoch`` through
	``args.start_epoch + num_epochs`` inclusive.

	Fine-tuning switch: the first time the epoch index reaches ``args.ft_start``,
	training switches to the global ``ft_loader`` (printed as fine-tuning on full
	size images).

	Loss: the model returns ``(mean, logvar_aleatoric)``. The logits passed to
	``loss_seg`` are ``mean * exp(-logvar_aleatoric)``, and the minimised loss is
	the negation of ``loss_seg``'s return value.

	Per-epoch logging: the batch-averaged ``loss_seg`` value (printed as "Average
	Train Loss") and the batch-averaged training error (from the argmax over
	dim 1 of the rescaled logits) are printed and saved with ``np.savetxt``.

	Validation and checkpoints: every ``args.save_results`` epochs, and at
	``args.final_epoch``, the model is evaluated on ``val_loader``. The validation
	metrics are saved, and the model and optimizer state dicts are written to disk.

	Uses module-level globals: ``model``, ``args``, ``device``, ``optimizer``,
	``ft_loader``, ``val_loader``, ``num_classes``, ``exp_name``, ``loss_seg``,
	``numpy_metrics``, ``adjust_learning_rate``, ``test``, ``tqdm``, ``np``,
	``torch``.

	Args:
		num_epochs: number of epochs to run after ``args.start_epoch``
			(the caller in this project passes ``args.n_epochs``).
		train_loader: iterable of ``(X, Y)`` batches, used until fine-tuning
			starts.
	"""
	model.train()

	ft_started = False
	for s in range(args.start_epoch, args.start_epoch + num_epochs + 1):
		train_loss = 0.
		train_error = 0.
		# Re-enter train mode each epoch; test() called at the end of an epoch
		# may leave the model in eval mode (assumption — test() not shown here).
		model.train()
		if s >= args.ft_start and not ft_started:
			print('Now fine-tuning on full size images')
			train_loader = ft_loader
			ft_started = True
		for X, Y in tqdm(train_loader):
			x_t = X.to(device)
			y_t = Y.to(device)
			N_t = x_t.size(0)

			optimizer.zero_grad()

			mean, logvar_aleatoric = model(x_t)

			# Element-wise rescale of the logits by exp(-logvar_aleatoric).
			rescaled_logits = mean * torch.exp(-logvar_aleatoric)
			# loss_seg is maximised, so the minimised loss is its negation.
			loss = - loss_seg(y_t, rescaled_logits, print_loss=True)

			loss.backward()
			# Accumulates loss_seg's value itself (not the negated loss).
			train_loss += -loss.item()
			optimizer.step()
			# argmax over dim 1 turns the logits into hard label predictions.
			_, _, train_acc_curr = numpy_metrics(rescaled_logits.argmax(1).data.cpu().view(N_t, -1).numpy(), y_t.view(N_t, -1).data.cpu().numpy())
			train_error += 1 - train_acc_curr
			# Called once per batch with the epoch index, as in the original schedule.
			adjust_learning_rate(args.lr, 0.998, optimizer, s, args.final_epoch)
			# Release every per-batch model output before the next forward pass.
			del x_t, y_t, mean, logvar_aleatoric, rescaled_logits
		train_loss /= len(train_loader)
		train_error /= len(train_loader)
		print('Epoch: {} || Average Train Error: {:.5f} || Average Train Loss: {:.5f}'.format(s, train_error, train_loss))
		np.savetxt('{}_{}_epoch_{}_average_train_loss.txt'.format(args.dataset, exp_name, s), [train_loss])
		np.savetxt('{}_{}_epoch_{}_average_train_error.txt'.format(args.dataset, exp_name, s), [train_error])

		# Periodic validation + checkpointing (also forced at the final epoch).
		if s % args.save_results == 0 or s == args.final_epoch:
			val_error, val_mIOU = test(model, val_loader, num_classes, args.dataset, exp_name, plot_imgs=False)
			print('Epoch: {} || Validation Error: {:.5f} || Validation Mean IOU: {:.5f}'.format(s, val_error, val_mIOU))
			torch.save(model.state_dict(), 'model_{}_{}.bin'.format(args.dataset, exp_name))
			torch.save(optimizer.state_dict(), 'optimizer_{}_{}.bin'.format(args.dataset, exp_name))
			np.savetxt('{}_{}_epoch_{}_val_error.txt'.format(args.dataset, exp_name, s), [val_error])
			np.savetxt('{}_{}_epoch_{}_val_mIOU.txt'.format(args.dataset, exp_name, s), [val_mIOU])
コード例 #3
0
	# Keyword arguments forwarded to the FVI_seg constructor, taken from the
	# command-line args (plus the target device).
	keys = ('device', 'x_inducing_var', 'f_prior', 'n_inducing', 'add_cov_diag', 'standard_cross_entropy')
	values = (device, args.x_inducing_var, args.f_prior, args.n_inducing, args.add_cov_diag, args.standard_cross_entropy)
	fvi_args = dict(zip(keys, values))

	# Model is built for the crop size (H_crop, W_crop); optimised with SGD
	# (momentum 0.9, weight decay 1e-4).
	FVI = FVI_seg(x_size=(H_crop, W_crop), num_classes=num_classes, **fvi_args).to(device)
	optimizer = torch.optim.SGD(FVI.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)

	# Optionally resume model + optimizer state from <base_dir>/FVI_CV/.
	# NOTE(review): the file names match the pattern train() saves to the
	# current working directory; this presumably assumes cwd == <base_dir>/FVI_CV.
	if args.load:
		model_load_dir = os.path.join(args.base_dir, 'FVI_CV/model_{}_{}.bin'.format(args.dataset, exp_name))
		optimizer_load_dir = os.path.join(args.base_dir, 'FVI_CV/optimizer_{}_{}.bin'.format(args.dataset, exp_name)) 
		FVI.load_state_dict(torch.load(model_load_dir))
		optimizer.load_state_dict(torch.load(optimizer_load_dir))
		print('Loading FVI segmentation model..')

	if args.training_mode:
		print('Training FVI segmentation for {} epochs'.format(args.n_epochs))
		train(args.n_epochs, train_loader, FVI)
	# Test evaluation always loads the fixed checkpoint from models_test/,
	# overwriting any weights trained or loaded above.
	if args.test_mode:
		print('Evaluating FVI segmentation on test set')
		model_load_dir = os.path.join(args.base_dir, 'FVI_CV/models_test/model_{}_fvi_seg_test.bin'.format(args.dataset))
		FVI.load_state_dict(torch.load(model_load_dir))
		error, mIOU = test(FVI, test_loader, num_classes, args.dataset, exp_name, mkdir=True)
		print('Test Error: {:.5f} || Test Mean IOU: {:.5f}'.format(error, mIOU))
		# Epoch index -1 in the file name marks test-set (not per-epoch) results.
		np.savetxt('{}_{}_epoch_{}_test_error.txt'.format(args.dataset, exp_name, -1), [error])
		np.savetxt('{}_{}_epoch_{}_test_mIOU.txt'.format(args.dataset, exp_name, -1), [mIOU])
	# Runtime benchmarking on the same fixed test checkpoint.
	# NOTE(review): the meaning of the literal 50 is defined by run_runtime_seg
	# (not shown) — presumably a repetition/batch count; confirm.
	if args.test_runtime_mode:
		model_load_dir = os.path.join(args.base_dir, 'FVI_CV/models_test/model_{}_fvi_seg_test.bin'.format(args.dataset))
		FVI.load_state_dict(torch.load(model_load_dir))
		run_runtime_seg(FVI, test_loader, exp_name, 50)
コード例 #4
0
    # Train for args.n_epochs using the module-level train() helper.
    if args.training_mode:
        print('Training deterministic segmentation for {} epochs'.format(
            args.n_epochs))
        train(args.n_epochs, train_loader)
    # Test evaluation always loads the fixed checkpoint from models_test/,
    # overwriting any weights trained above.
    if args.test_mode:
        print('Evaluating {} segmentation on test set'.format(args.model_type))
        model_load_dir = os.path.join(
            args.base_dir,
            'FVI_CV/models_test/model_{}_mcd_seg_test.bin'.format(
                args.dataset))
        model.load_state_dict(torch.load(model_load_dir))
        # NOTE(review): plot_imgs / mkdir semantics live in test() (not shown);
        # presumably they write prediction images into a created directory.
        error, mIOU = test(model,
                           test_loader,
                           num_classes,
                           args.dataset,
                           exp_name,
                           plot_imgs=True,
                           mkdir=True)
        # test() returns an error rate; accuracy is reported as 1 - error.
        print('Test Accuracy: {:.5f} || Test Mean IOU: {:.5f}'.format(
            1. - error, mIOU))
        # Epoch index -1 in the file name marks test-set (not per-epoch) results.
        np.savetxt(
            '{}_{}_epoch_{}_test_accuracy.txt'.format(args.dataset, exp_name,
                                                      -1), [1. - error])
        np.savetxt(
            '{}_{}_epoch_{}_test_mIOU.txt'.format(args.dataset, exp_name, -1),
            [mIOU])
    if args.test_runtime_mode:
        model_load_dir = os.path.join(
            args.base_dir,
            'FVI_CV/models_test/model_{}_mcd_seg_test.bin'.format(