def run(args):
    dataset = RorschachWrapper()

    save_args(args)  # save command line to a file for reference
    train_loader = rorschach_cgan_data_loader(args, dataset=dataset)  # get the data  # todo
    model = patchCWGANModel(args,
                            discriminator=patchCDiscriminatorNetwork(args),
                            generator=CGeneratorNetwork(args))

    # Build trainer
    trainer = Trainer(model)
    trainer.build_criterion(
        CWGANDiscriminatorLoss(penalty_weight=args.penalty_weight, model=model))
    trainer.build_optimizer('Adam', model.discriminator.parameters(),
                            lr=args.discriminator_lr)
    trainer.save_every((1, 'epochs'))
    trainer.save_to_directory(args.save_directory)
    trainer.set_max_num_epochs(args.epochs)
    trainer.register_callback(CGenerateDataCallback(args, dataset=dataset))
    trainer.register_callback(
        CGeneratorTrainingCallback(args,
                                   parameters=model.generator.parameters(),
                                   criterion=WGANGeneratorLoss(),
                                   dataset=dataset))
    trainer.bind_loader('train', train_loader, num_inputs=2)

    # Custom logging configuration so it knows to log our images
    logger = TensorboardLogger(log_scalars_every=(1, 'iteration'),
                               log_images_every=(args.log_image_frequency, 'iteration'))
    trainer.build_logger(logger, log_directory=args.save_directory)
    logger.observe_state('generated_images')
    logger.observe_state('real_images')
    logger._trainer_states_being_observed_while_training.remove('training_inputs')

    if args.cuda:
        trainer.cuda()

    # Go!
    trainer.fit()
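
# The conditional run() above reads a fixed set of fields from `args`. Below is a
# minimal sketch of how that namespace might be built with argparse; the helper
# name, flag names, and defaults are assumptions inferred from the attribute
# accesses above, not the repo's actual CLI definition.
def build_cgan_argparser():
    import argparse
    parser = argparse.ArgumentParser(description='Conditional patch WGAN training (sketch)')
    parser.add_argument('--penalty-weight', type=float, default=10.0,
                        help='gradient penalty weight for the WGAN-GP critic loss')
    parser.add_argument('--discriminator-lr', type=float, default=1e-4,
                        help='Adam learning rate for the discriminator')
    parser.add_argument('--save-directory', type=str, default='output/rorschach_cgan',
                        help='checkpoint and Tensorboard log directory')
    parser.add_argument('--epochs', type=int, default=100,
                        help='maximum number of training epochs')
    parser.add_argument('--log-image-frequency', type=int, default=100,
                        help='log generated/real images every N iterations')
    parser.add_argument('--cuda', action='store_true',
                        help='move the trainer to GPU')
    return parser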
def run(args):
    save_args(args)  # save command line to a file for reference
    train_loader = mnist_data_loader(args)  # get the data
    model = GANModel(args,
                     discriminator=DiscriminatorNetwork(args),
                     generator=GeneratorNetwork(args))

    # Build trainer
    trainer = Trainer(model)
    trainer.build_criterion(
        WGANDiscriminatorLoss(penalty_weight=args.penalty_weight, model=model))
    trainer.build_optimizer('Adam', model.discriminator.parameters(),
                            lr=args.discriminator_lr)
    trainer.save_every((1, 'epochs'))
    trainer.save_to_directory(args.save_directory)
    trainer.set_max_num_epochs(args.epochs)
    trainer.register_callback(GenerateDataCallback(args))
    trainer.register_callback(
        GeneratorTrainingCallback(args,
                                  parameters=model.generator.parameters(),
                                  criterion=WGANGeneratorLoss()))
    trainer.bind_loader('train', train_loader)

    # Custom logging configuration so it knows to log our images
    logger = TensorboardLogger(log_scalars_every=(1, 'iteration'),
                               log_images_every=(args.log_image_frequency, 'iteration'))
    trainer.build_logger(logger, log_directory=args.save_directory)
    logger.observe_state('generated_images')
    logger.observe_state('real_images')
    logger._trainer_states_being_observed_while_training.remove('training_inputs')

    if args.cuda:
        trainer.cuda()

    # Go!
    trainer.fit()

    # Generate video from saved images
    if not args.no_ffmpeg:
        generate_video(args.save_directory)
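
# A hedged example of wiring the MNIST run() above to a command line entry point.
# The parser is an assumption mirroring the fields run() reads; only --no-ffmpeg
# is additional here, matching the generate_video check at the end of run().
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='MNIST WGAN training (sketch)')
    parser.add_argument('--penalty-weight', type=float, default=10.0)
    parser.add_argument('--discriminator-lr', type=float, default=1e-4)
    parser.add_argument('--save-directory', type=str, default='output/mnist_wgan')
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--log-image-frequency', type=int, default=100)
    parser.add_argument('--cuda', action='store_true')
    parser.add_argument('--no-ffmpeg', action='store_true',
                        help='skip assembling saved images into a video')
    run(parser.parse_args())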