# Build the training split (restricted to the 'animal' supercategory, resized to
# HEIGHT x WIDTH, without normalization) and wrap both splits in DataLoaders.
# NOTE(review): val_dataset is not created in this fragment; presumably it is
# constructed earlier in the file — confirm.
train_dataset = CocoStuffDataSet(mode='train', supercategories=['animal'],
                                 height=HEIGHT, width=WIDTH, do_normalize=False)
val_loader = DataLoader(val_dataset, args.batch_size, shuffle=False)
train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True)

# The number of segmentation classes is taken from the training dataset.
NUM_CLASSES = train_dataset.numClasses
print("Number of classes: {}".format(NUM_CLASSES))

# Shapes handed to the discriminator: a 3-channel image and a per-class
# segmentation map, both at HEIGHT x WIDTH.
image_shape = (3, HEIGHT, WIDTH)
segmentation_shape = (NUM_CLASSES, HEIGHT, WIDTH)

# The discriminator is only built for adversarial training; otherwise the
# Trainer receives None.
discriminator = None
generator = get_generator(args.generator_name, NUM_CLASSES, args.use_bn)
if args.train_gan:
    discriminator = GAN(NUM_CLASSES, segmentation_shape, image_shape)

# Parentheses already allow line continuation, so no backslashes are needed.
trainer = Trainer(generator, discriminator, train_loader, val_loader,
                  gan_reg=args.gan_reg, weight_clip=args.weight_clip,
                  grad_clip=args.grad_clip, noise_scale=args.noise_scale,
                  disc_lr=args.disc_lr, gen_lr=args.gen_lr,
                  train_gan=args.train_gan, experiment_dir=experiment_dir,
                  resume=args.load_model, load_iter=args.load_iter)

if args.mode == "train":
    trainer.train(num_epochs=args.epochs, print_every=args.print_every,
                  eval_every=args.eval_every)
elif args.mode == 'eval':
    # Evaluating requires loaded weights. Use an explicit check instead of
    # `assert`, which is stripped when Python runs with -O.
    if not args.load_model:
        raise ValueError("Need to load model to evaluate it")
    # Evaluation only: print the confusion matrix and mean IoU on the val split.
    print(trainer.get_confusion_matrix(val_loader))
    print('mIOU {}'.format(trainer.evaluate_meanIOU(val_loader)))