def __init__(self, context: DeepSpeedTrialContext) -> None:
    self.context = context
    self.hparams = AttrDict(self.context.get_hparams())
    self.data_config = AttrDict(self.context.get_data_config())
    self.logger = TorchWriter()

    num_channels = data.CHANNELS_BY_DATASET[self.data_config.dataset]

    # Build the generator and discriminator and apply the DCGAN weight init.
    gen_net = Generator(
        self.hparams.generator_width_base, num_channels, self.hparams.noise_length
    )
    gen_net.apply(weights_init)
    disc_net = Discriminator(self.hparams.discriminator_width_base, num_channels)
    disc_net.apply(weights_init)

    gen_parameters = filter(lambda p: p.requires_grad, gen_net.parameters())
    disc_parameters = filter(lambda p: p.requires_grad, disc_net.parameters())

    # Merge any per-experiment overrides into the base DeepSpeed config.
    ds_config = overwrite_deepspeed_config(
        self.hparams.deepspeed_config, self.hparams.get("overwrite_deepspeed_args", {})
    )

    # Each network gets its own DeepSpeed model engine; both share one config.
    generator, _, _, _ = deepspeed.initialize(
        model=gen_net, model_parameters=gen_parameters, config=ds_config
    )
    discriminator, _, _, _ = deepspeed.initialize(
        model=disc_net, model_parameters=disc_parameters, config=ds_config
    )

    self.generator = self.context.wrap_model_engine(generator)
    self.discriminator = self.context.wrap_model_engine(discriminator)

    # A fixed noise batch lets us render comparable sample images across epochs.
    self.fixed_noise = self.context.to_device(
        torch.randn(
            self.context.train_micro_batch_size_per_gpu, self.hparams.noise_length, 1, 1
        )
    )
    self.criterion = nn.BCELoss()
    # TODO: Test fp16
    self.fp16 = generator.fp16_enabled()

    self.gradient_accumulation_steps = generator.gradient_accumulation_steps()
    # We accumulate gradients manually, so turn off automatic accumulation.
    if self.gradient_accumulation_steps > 1:
        logging.info("Disabling automatic gradient accumulation.")
        self.context.disable_auto_grad_accumulation()
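# A minimal sketch (not from the original repo) of the kind of DeepSpeed config
# the trial above consumes. `train_micro_batch_size_per_gpu`,
# `gradient_accumulation_steps`, and `fp16.enabled` are the fields it reads via
# the context and engine; the concrete values here are illustrative assumptions.
EXAMPLE_DS_CONFIG = {
    "train_batch_size": 64,
    "train_micro_batch_size_per_gpu": 16,
    "gradient_accumulation_steps": 4,  # 16 * 4 * (1 GPU) == 64
    "optimizer": {"type": "Adam", "params": {"lr": 0.0002, "betas": [0.5, 0.999]}},
    "fp16": {"enabled": False},  # checked above via generator.fp16_enabled()
}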
def train(args):
    writer = SummaryWriter(log_dir=args.tensorboard_path)
    create_folder(args.outf)
    set_seed(args.manualSeed)
    cudnn.benchmark = True

    dataset, nc = get_dataset(args)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batchSize,
        shuffle=True,
        num_workers=int(args.workers),
    )

    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)

    ngpu = 0  # DataParallel is unused here; DeepSpeed manages device placement.
    nz = int(args.nz)
    ngf = int(args.ngf)
    ndf = int(args.ndf)

    netG = Generator(ngpu, ngf, nc, nz).to(device)
    netG.apply(weights_init)
    if args.netG != '':
        netG.load_state_dict(torch.load(args.netG))

    netD = Discriminator(ngpu, ndf, nc).to(device)
    netD.apply(weights_init)
    if args.netD != '':
        netD.load_state_dict(torch.load(args.netD))

    criterion = nn.BCELoss()

    fixed_noise = torch.randn(args.batchSize, nz, 1, 1, device=device)
    real_label = 1
    fake_label = 0

    # Set up one optimizer, and then one DeepSpeed engine, per network.
    optimizerD = torch.optim.Adam(netD.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
    optimizerG = torch.optim.Adam(netG.parameters(), lr=args.lr, betas=(args.beta1, 0.999))

    model_engineD, optimizerD, _, _ = deepspeed.initialize(
        args=args, model=netD, model_parameters=netD.parameters(), optimizer=optimizerD
    )
    model_engineG, optimizerG, _, _ = deepspeed.initialize(
        args=args, model=netG, model_parameters=netG.parameters(), optimizer=optimizerG
    )

    torch.cuda.synchronize()
    start = time()

    for epoch in range(args.epochs):
        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ############################
            # train with real
            netD.zero_grad()
            real = data[0].to(device)
            batch_size = real.size(0)
            label = torch.full((batch_size,), real_label, dtype=real.dtype, device=device)

            output = netD(real)
            errD_real = criterion(output, label)
            model_engineD.backward(errD_real)
            D_x = output.mean().item()

            # train with fake
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            fake = netG(noise)
            label.fill_(fake_label)
            output = netD(fake.detach())
            errD_fake = criterion(output, label)
            model_engineD.backward(errD_fake)
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            # optimizerD.step()  # alternative (equivalent) step
            model_engineD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ############################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            output = netD(fake)
            errG = criterion(output, label)
            model_engineG.backward(errG)
            D_G_z2 = output.mean().item()
            # optimizerG.step()  # alternative (equivalent) step
            model_engineG.step()

            print(
                '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                % (epoch, args.epochs, i, len(dataloader),
                   errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
            writer.add_scalar("Loss_D", errD.item(), epoch * len(dataloader) + i)
            writer.add_scalar("Loss_G", errG.item(), epoch * len(dataloader) + i)

            if i % 100 == 0:
                vutils.save_image(real, '%s/real_samples.png' % args.outf, normalize=True)
                fake = netG(fixed_noise)
                vutils.save_image(
                    fake.detach(),
                    '%s/fake_samples_epoch_%03d.png' % (args.outf, epoch),
                    normalize=True,
                )

        # do checkpointing
        # torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (args.outf, epoch))
        # torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (args.outf, epoch))

    torch.cuda.synchronize()
    stop = time()
    print(f"total wall clock time for {args.epochs} epochs is {stop - start} secs")
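# A hedged usage sketch, not part of the original script: train() above expects
# to be started by the `deepspeed` launcher, which spawns one process per GPU,
# supplies --local_rank, and parses --deepspeed_config. The script and config
# file names below are assumptions.
#
#   deepspeed --num_gpus=1 gan_train.py \
#       --deepspeed_config ds_config.json \
#       --outf ./out --tensorboard_path ./runs --manualSeed 42
#
# Because the optimizers are passed explicitly to deepspeed.initialize(), the
# JSON config must not define its own "optimizer" block (DeepSpeed rejects
# specifying both); a minimal ds_config.json only needs the batch size, e.g.
# {"train_batch_size": 64, "steps_per_print": 100}.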