    # __init__ of a Determined AI DeepSpeedTrial subclass: it builds the GAN,
    # creates two DeepSpeed model engines, and registers them with the trial context.
    def __init__(self, context: DeepSpeedTrialContext) -> None:
        self.context = context
        self.hparams = AttrDict(self.context.get_hparams())
        self.data_config = AttrDict(self.context.get_data_config())
        self.logger = TorchWriter()
        num_channels = data.CHANNELS_BY_DATASET[self.data_config.dataset]
        gen_net = Generator(
            self.hparams.generator_width_base, num_channels, self.hparams.noise_length
        )
        gen_net.apply(weights_init)
        disc_net = Discriminator(self.hparams.discriminator_width_base, num_channels)
        disc_net.apply(weights_init)
        gen_parameters = filter(lambda p: p.requires_grad, gen_net.parameters())
        disc_parameters = filter(lambda p: p.requires_grad, disc_net.parameters())
        # Merge any experiment-level overrides into the base DeepSpeed config.
        ds_config = overwrite_deepspeed_config(
            self.hparams.deepspeed_config, self.hparams.get("overwrite_deepspeed_args", {})
        )
        generator, _, _, _ = deepspeed.initialize(
            model=gen_net, model_parameters=gen_parameters, config=ds_config
        )
        discriminator, _, _, _ = deepspeed.initialize(
            model=disc_net, model_parameters=disc_parameters, config=ds_config
        )

        # Register both engines with the trial context so Determined can manage
        # devices and checkpointing for them.
        self.generator = self.context.wrap_model_engine(generator)
        self.discriminator = self.context.wrap_model_engine(discriminator)
        # Fixed noise vector, reused so generated samples stay comparable
        # across training steps.
        self.fixed_noise = self.context.to_device(
            torch.randn(
                self.context.train_micro_batch_size_per_gpu, self.hparams.noise_length, 1, 1
            )
        )
        self.criterion = nn.BCELoss()
        # TODO: Test fp16
        self.fp16 = generator.fp16_enabled()
        self.gradient_accumulation_steps = generator.gradient_accumulation_steps()
        # Manually perform gradient accumulation.
        if self.gradient_accumulation_steps > 1:
            logging.info("Disabling automatic gradient accumulation.")
            self.context.disable_auto_grad_accumulation()
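
    # Because automatic gradient accumulation is disabled above, the trial's
    # train_batch must drive the micro-batch loop itself. A minimal sketch,
    # assuming the `generator` engine built in __init__; compute_generator_loss
    # is a hypothetical helper, not part of the original code.
    def train_batch(self, dataloader_iter, epoch_idx, batch_idx):
        for _ in range(self.gradient_accumulation_steps):
            batch = self.context.to_device(next(dataloader_iter))
            loss = self.compute_generator_loss(batch)  # hypothetical helper
            # DeepSpeed scales the loss by 1/gradient_accumulation_steps and
            # applies the optimizer update only on the boundary micro-batch.
            self.generator.backward(loss)
            self.generator.step()
        return {"loss": loss.item()}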
Example #2
# Imports needed by this example; Generator, Discriminator, weights_init,
# get_dataset, create_folder, and set_seed come from the example's own modules.
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.utils as vutils
import deepspeed
from time import time
from torch.utils.tensorboard import SummaryWriter


def train(args):
    writer = SummaryWriter(log_dir=args.tensorboard_path)
    create_folder(args.outf)
    set_seed(args.manualSeed)
    cudnn.benchmark = True
    dataset, nc = get_dataset(args)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batchSize,
                                             shuffle=True,
                                             num_workers=int(args.workers))
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    ngpu = 0
    nz = int(args.nz)
    ngf = int(args.ngf)
    ndf = int(args.ndf)

    netG = Generator(ngpu, ngf, nc, nz).to(device)
    netG.apply(weights_init)
    if args.netG != '':
        netG.load_state_dict(torch.load(args.netG, map_location=device))

    netD = Discriminator(ngpu, ndf, nc).to(device)
    netD.apply(weights_init)
    if args.netD != '':
        netD.load_state_dict(torch.load(args.netD, map_location=device))

    criterion = nn.BCELoss()

    fixed_noise = torch.randn(args.batchSize, nz, 1, 1, device=device)
    real_label = 1
    fake_label = 0

    # setup optimizer
    optimizerD = torch.optim.Adam(netD.parameters(),
                                  lr=args.lr,
                                  betas=(args.beta1, 0.999))
    optimizerG = torch.optim.Adam(netG.parameters(),
                                  lr=args.lr,
                                  betas=(args.beta1, 0.999))

    model_engineD, optimizerD, _, _ = deepspeed.initialize(
        args=args,
        model=netD,
        model_parameters=netD.parameters(),
        optimizer=optimizerD)
    model_engineG, optimizerG, _, _ = deepspeed.initialize(
        args=args,
        model=netG,
        model_parameters=netG.parameters(),
        optimizer=optimizerG)
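
    # For reference, the JSON file passed via --deepspeed_config might look
    # like this (illustrative values, not from the original example); since
    # optimizer instances are passed to deepspeed.initialize above, the config
    # should not define its own optimizer:
    #   {"train_batch_size": 64, "fp16": {"enabled": false}}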

    torch.cuda.synchronize()
    start = time()
    for epoch in range(args.epochs):
        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real
            netD.zero_grad()
            real = data[0].to(device)
            batch_size = real.size(0)
            label = torch.full((batch_size, ),
                               real_label,
                               dtype=real.dtype,
                               device=device)
            # Forward through the engine so fp16/ZeRO settings are honored.
            output = model_engineD(real)
            errD_real = criterion(output, label)
            model_engineD.backward(errD_real)
            D_x = output.mean().item()

            # train with fake
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            fake = model_engineG(noise)
            label.fill_(fake_label)
            output = model_engineD(fake.detach())
            errD_fake = criterion(output, label)
            model_engineD.backward(errD_fake)
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            # Step through the engine rather than optimizerD directly, so
            # DeepSpeed can apply loss scaling and accumulation boundaries.
            model_engineD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            output = model_engineD(fake)
            errG = criterion(output, label)
            model_engineG.backward(errG)
            D_G_z2 = output.mean().item()
            # As above, step through the engine rather than optimizerG directly.
            model_engineG.step()

            print(
                '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                % (epoch, args.epochs, i, len(dataloader), errD.item(),
                   errG.item(), D_x, D_G_z1, D_G_z2))
            writer.add_scalar("Loss_D", errD.item(),
                              epoch * len(dataloader) + i)
            writer.add_scalar("Loss_G", errG.item(),
                              epoch * len(dataloader) + i)
            if i % 100 == 0:
                vutils.save_image(real,
                                  '%s/real_samples.png' % args.outf,
                                  normalize=True)
                with torch.no_grad():  # sampling only; no autograd graph needed
                    fake = model_engineG(fixed_noise)
                vutils.save_image(fake,
                                  '%s/fake_samples_epoch_%03d.png' %
                                  (args.outf, epoch),
                                  normalize=True)

        # do checkpointing
        #torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (args.outf, epoch))
        #torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (args.outf, epoch))
    torch.cuda.synchronize()
    stop = time()
    print(f"total wall clock time for {args.epochs} epochs is {stop - start:.2f} secs")