示例#1
0
    # Build the training dataset, requested with mode='train' and only the
    # 'animal' supercategory, at the configured HEIGHT x WIDTH, with
    # do_normalize switched off.
    train_dataset = CocoStuffDataSet(mode='train',
                                     supercategories=['animal'],
                                     height=HEIGHT,
                                     width=WIDTH,
                                     do_normalize=False)
    # NOTE(review): val_dataset is not defined in the visible part of this
    # file -- presumably it is constructed just above; confirm.
    # Only the training loader shuffles; the validation loader does not.
    val_loader = DataLoader(val_dataset, args.batch_size, shuffle=False)
    train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True)
    # The class count is read from the dataset and drives both the generator
    # and the segmentation shape below.
    NUM_CLASSES = train_dataset.numClasses
    print("Number of classes: {}".format(NUM_CLASSES))
    # presumably channel-first (C, H, W) layouts: 3 image channels and one
    # channel per class for the segmentation -- TODO confirm against GAN.
    image_shape = (3, HEIGHT, WIDTH)
    segmentation_shape = (NUM_CLASSES, HEIGHT, WIDTH)

    # The discriminator stays None unless args.train_gan is set, so Trainer
    # receives None for it in that case.
    discriminator = None
    generator = get_generator(args.generator_name, NUM_CLASSES, args.use_bn)
    if args.train_gan:
        discriminator = GAN(NUM_CLASSES, segmentation_shape, image_shape)
    # Hyperparameters and checkpoint options are forwarded from the parsed
    # command-line args; resume/load_iter come from args.load_model and
    # args.load_iter. The trailing backslashes inside the parentheses are
    # redundant line continuations (no comment may follow them).
    trainer = Trainer(generator, discriminator, train_loader, val_loader, \
                    gan_reg=args.gan_reg, weight_clip=args.weight_clip, grad_clip=args.grad_clip, \
                    noise_scale=args.noise_scale, disc_lr=args.disc_lr, gen_lr=args.gen_lr, train_gan= args.train_gan, \
                    experiment_dir=experiment_dir, resume=args.load_model, load_iter=args.load_iter)

    # Dispatch on the requested mode: run training, or evaluate on the
    # validation loader.
    if args.mode == "train":
        trainer.train(num_epochs=args.epochs,
                      print_every=args.print_every,
                      eval_every=args.eval_every)
    elif args.mode == 'eval':
        # NOTE(review): this guard runs only after the Trainer has been
        # constructed, and `assert` is stripped under `python -O`, so the
        # check disappears in optimized mode.
        assert (args.load_model), "Need to load model to evaluate it"
        # just do evaluation
        print(trainer.get_confusion_matrix(val_loader))
        print('mIOU {}'.format(trainer.evaluate_meanIOU(val_loader)))