Ejemplo n.º 1
0
    def __init__(self, context: DeepSpeedTrialContext) -> None:
        """Set up both GAN networks as DeepSpeed engines plus trial state.

        Args:
            context: Trial context that supplies the hyperparameters and data
                configuration and is used to wrap the engines, place tensors
                on the device and toggle automatic gradient accumulation.
        """
        self.context = context
        self.hparams = AttrDict(self.context.get_hparams())
        self.data_config = AttrDict(self.context.get_data_config())
        self.logger = TorchWriter()

        # Both networks are sized from the channel count of the chosen dataset.
        channel_count = data.CHANNELS_BY_DATASET[self.data_config.dataset]

        raw_generator = Generator(
            self.hparams.generator_width_base,
            channel_count,
            self.hparams.noise_length,
        )
        raw_generator.apply(weights_init)

        raw_discriminator = Discriminator(
            self.hparams.discriminator_width_base,
            channel_count,
        )
        raw_discriminator.apply(weights_init)

        # Only parameters that require gradients are handed to DeepSpeed.
        trainable_gen = (
            param for param in raw_generator.parameters() if param.requires_grad
        )
        trainable_disc = (
            param for param in raw_discriminator.parameters() if param.requires_grad
        )

        # One DeepSpeed config (with optional overrides) is shared by both engines.
        base_config = self.hparams.deepspeed_config
        config_overrides = self.hparams.get("overwrite_deepspeed_args", {})
        ds_config = overwrite_deepspeed_config(base_config, config_overrides)

        gen_engine, _, _, _ = deepspeed.initialize(
            model=raw_generator,
            model_parameters=trainable_gen,
            config=ds_config,
        )
        disc_engine, _, _, _ = deepspeed.initialize(
            model=raw_discriminator,
            model_parameters=trainable_disc,
            config=ds_config,
        )

        self.generator = self.context.wrap_model_engine(gen_engine)
        self.discriminator = self.context.wrap_model_engine(disc_engine)

        # Noise batch drawn once here and stored on the trial.
        noise_shape = (
            self.context.train_micro_batch_size_per_gpu,
            self.hparams.noise_length,
            1,
            1,
        )
        self.fixed_noise = self.context.to_device(torch.randn(*noise_shape))
        self.criterion = nn.BCELoss()

        # TODO: Test fp16
        self.fp16 = gen_engine.fp16_enabled()
        self.gradient_accumulation_steps = gen_engine.gradient_accumulation_steps()

        # Gradient accumulation is performed manually when more than one step
        # is configured, so the context's automatic handling is switched off.
        if self.gradient_accumulation_steps > 1:
            logging.info("Disabling automatic gradient accumulation.")
            self.context.disable_auto_grad_accumulation()
Ejemplo n.º 2
0
def main(args):
    """Pre-train and adversarially train a caption generator and discriminator.

    When ``int(args.pretraining) == 1`` the previously saved weights are
    loaded, the generator is trained with cross-entropy against the packed
    ground-truth captions and the discriminator is trained on real versus
    mismatched ("wrong") captions. Otherwise only the saved weights are
    loaded. In both cases the GAN phase follows: discriminator outputs on
    partially generated captions are used as rewards for the generator, and
    the discriminator is trained on real, sampled and wrong captions.
    Checkpoints are written below ``args.model_path``.

    Args:
        args: Namespace-like object. Attributes read here: ``model_path``,
            ``figure_path``, ``crop_size``, ``vocab_path``, ``image_dir``,
            ``caption_path``, ``batch_size``, ``num_workers``, ``embed_size``,
            ``hidden_size``, ``num_layers``, ``pretraining``,
            ``pretrained_gen_path``, ``pretrained_disc_path``,
            ``gen_pretrain_num_epochs``, ``disc_pretrain_num_epochs``,
            ``log_step`` and ``num_epochs``.
    """
    # Create the output directories. ``exist_ok=True`` replaces the racy
    # exists()-then-makedirs() sequence.
    os.makedirs(args.model_path, exist_ok=True)
    os.makedirs(args.figure_path, exist_ok=True)

    # Image preprocessing
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Load vocabulary wrapper.
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # args.vocab_path at files from a trusted source.
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build data loader
    data_loader = get_loader(args.image_dir,
                             args.caption_path,
                             vocab,
                             transform,
                             args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers)

    # Build the models (Gen)
    generator = Generator(args.embed_size, args.hidden_size, len(vocab),
                          args.num_layers)

    # Build the models (Disc)
    discriminator = Discriminator(args.embed_size, args.hidden_size,
                                  len(vocab), args.num_layers)

    if torch.cuda.is_available():
        generator.cuda()
        discriminator.cuda()

    # Loss and Optimizer (Gen)
    mle_criterion = nn.CrossEntropyLoss()
    params_gen = list(generator.parameters())
    optimizer_gen = torch.optim.Adam(params_gen)

    # Loss and Optimizer (Disc)
    params_disc = list(discriminator.parameters())
    optimizer_disc = torch.optim.Adam(params_disc)

    if int(args.pretraining) == 1:
        # Pre-training: train generator with MLE and discriminator with
        # 2 losses (real + wrong); the fake loss is only used in the GAN phase.
        total_steps = len(data_loader)
        print(total_steps)
        disc_losses = []
        gen_losses = []
        print('pre-training')
        generator.load_state_dict(torch.load(args.pretrained_gen_path))
        discriminator.load_state_dict(torch.load(args.pretrained_disc_path))

        # The log line below prints both losses, but each one is only computed
        # while its own epoch budget lasts. Start from NaN so logging never
        # hits an unbound local when one of the two branches has not run.
        loss_gen = float('nan')
        loss_disc = float('nan')

        pretrain_epochs = max(int(args.gen_pretrain_num_epochs),
                              int(args.disc_pretrain_num_epochs))
        for epoch in range(pretrain_epochs):
            # NOTE(review): the first five epochs are skipped unconditionally;
            # presumably this resumes from the checkpoints loaded above --
            # confirm before relying on it.
            if epoch < 5:
                continue
            for i, (images, captions, lengths, wrong_captions,
                    wrong_lengths) in enumerate(data_loader):
                images = to_var(images, volatile=True)
                captions = to_var(captions)
                wrong_captions = to_var(wrong_captions)
                targets = pack_padded_sequence(captions,
                                               lengths,
                                               batch_first=True)[0]

                if epoch < int(args.gen_pretrain_num_epochs):
                    generator.zero_grad()
                    outputs, _ = generator(images, captions, lengths)
                    loss_gen = mle_criterion(outputs, targets)
                    # gen_losses.append(loss_gen.cpu().data.numpy()[0])
                    loss_gen.backward()
                    optimizer_gen.step()

                if epoch < int(args.disc_pretrain_num_epochs):
                    discriminator.zero_grad()
                    rewards_real = discriminator(images, captions, lengths)
                    # rewards_fake = discriminator(images, sampled_captions, sampled_lengths)
                    rewards_wrong = discriminator(images, wrong_captions,
                                                  wrong_lengths)
                    real_loss = -torch.mean(torch.log(rewards_real))
                    # fake_loss = -torch.mean(torch.clamp(torch.log(1 - rewards_fake), min=-1000))
                    wrong_loss = -torch.mean(
                        torch.clamp(torch.log(1 - rewards_wrong), min=-1000))
                    loss_disc = real_loss + wrong_loss  # + fake_loss, no fake_loss because this is pretraining

                    # disc_losses.append(loss_disc.cpu().data.numpy()[0])
                    loss_disc.backward()
                    optimizer_disc.step()
                if (i + 1) % args.log_step == 0:
                    print(
                        'Epoch [%d], Step [%d], Disc Loss: %.4f, Gen Loss: %.4f'
                        % (epoch + 1, i + 1, loss_disc, loss_gen))
                # Intermediate checkpoints every 500 steps.
                if (i + 1) % 500 == 0:
                    torch.save(
                        discriminator.state_dict(),
                        os.path.join(
                            args.model_path,
                            'pretrained-discriminator-%d-%d.pkl' %
                            (int(epoch) + 1, i + 1)))
                    torch.save(
                        generator.state_dict(),
                        os.path.join(
                            args.model_path, 'pretrained-generator-%d-%d.pkl' %
                            (int(epoch) + 1, i + 1)))

        # Save pretrained models
        torch.save(
            discriminator.state_dict(),
            os.path.join(
                args.model_path, 'pretrained-discriminator-%d.pkl' %
                int(args.disc_pretrain_num_epochs)))
        torch.save(
            generator.state_dict(),
            os.path.join(
                args.model_path, 'pretrained-generator-%d.pkl' %
                int(args.gen_pretrain_num_epochs)))

        # Plot pretraining figures
        # plt.plot(disc_losses, label='pretraining_disc_loss')
        # plt.savefig(args.figure_path + 'pretraining_disc_losses.png')
        # plt.clf()
        #
        # plt.plot(gen_losses, label='pretraining_gen_loss')
        # plt.savefig(args.figure_path + 'pretraining_gen_losses.png')
        # plt.clf()

    else:
        generator.load_state_dict(torch.load(args.pretrained_gen_path))
        discriminator.load_state_dict(torch.load(args.pretrained_disc_path))

    # # Skip the rest for now
    # return

    # Train the Models
    total_step = len(data_loader)
    disc_gan_losses = []
    gen_gan_losses = []
    for epoch in range(args.num_epochs):
        for i, (images, captions, lengths, wrong_captions,
                wrong_lengths) in enumerate(data_loader):

            # Set mini-batch dataset
            images = to_var(images, volatile=True)
            captions = to_var(captions)
            wrong_captions = to_var(wrong_captions)

            generator.zero_grad()
            outputs, packed_lengths = generator(images, captions, lengths)
            outputs = PackedSequence(outputs, packed_lengths)
            outputs = pad_packed_sequence(outputs,
                                          batch_first=True)  # (b, T, V)

            Tmax = outputs[0].size(1)
            if torch.cuda.is_available():
                rewards = torch.zeros_like(outputs[0]).type(
                    torch.cuda.FloatTensor)
            else:
                rewards = torch.zeros_like(outputs[0]).type(torch.FloatTensor)

            # getting rewards from disc: every second position t keeps the
            # real caption prefix, lets the generator pick token t and roll
            # out the remainder, then scores the result with the discriminator.
            # for t in tqdm(range(2, Tmax, 4)):
            for t in range(2, Tmax, 2):
                # for t in range(2, 4):
                if t >= min(
                        lengths
                ):  # TODO this makes things easier, but min(lengths) could be too short
                    break

                gen_samples = to_var(torch.zeros(
                    (captions.size(0), Tmax)).type(torch.FloatTensor),
                                     volatile=True)
                # part 1: taken from real caption
                gen_samples[:, :t] = captions[:, :t].data

                predicted_ids, saved_states = generator.pre_compute(
                    gen_samples, t)
                # for v in range(predicted_ids.size(1)):
                v = predicted_ids
                # part 2: taken from all possible vocabs
                # gen_samples[:,t] = predicted_ids[:,v]
                gen_samples[:, t] = v
                # part 3: taken from rollouts
                gen_samples[:, t:] = generator.rollout(gen_samples, t,
                                                       saved_states)

                sampled_lengths = []
                # finding sampled_lengths: position of the first token id 2
                # (<end>) plus one, or Tmax when no <end> was produced
                for batch in range(int(captions.size(0))):
                    for b_t in range(Tmax):
                        if gen_samples[batch,
                                       b_t].cpu().data.numpy() == 2:  # <end>
                            sampled_lengths.append(b_t + 1)
                            break
                        elif b_t == Tmax - 1:
                            sampled_lengths.append(Tmax)

                # sort sampled_lengths (descending, in place via reversed view)
                # NOTE(review): the lengths are sorted on their own while the
                # rows of gen_samples keep their order, so lengths may no
                # longer match their captions -- confirm this is intended.
                sampled_lengths = np.array(sampled_lengths)
                sampled_lengths[::-1].sort()
                sampled_lengths = sampled_lengths.tolist()

                # get rewards from disc
                rewards[:, t, v] = discriminator(images, gen_samples.detach(),
                                                 sampled_lengths)

            # Rewards are treated as constants for the generator update.
            # rewards = rewards.detach()
            rewards_detached = rewards.data
            rewards_detached = to_var(rewards_detached)

            # NOTE(review): torch.dot only accepts 1-D tensors on current
            # PyTorch releases; this call presumably relies on an older
            # version -- confirm before upgrading.
            loss_gen = torch.dot(outputs[0], -rewards_detached)
            # gen_gan_losses.append(loss_gen.cpu().data.numpy()[0])

            loss_gen.backward()
            optimizer_gen.step()

            # TODO get sampled_captions
            sampled_ids = generator.sample(images)
            # sampled_captions = torch.zeros_like(sampled_ids).type(torch.LongTensor)
            sampled_lengths = []
            # finding sampled_lengths (only the first 20 positions are inspected)
            for batch in range(int(captions.size(0))):
                for b_t in range(20):
                    #sampled_captions[batch, b_t].data = sampled_ids[batch, b_t].cpu().data.numpy()[0]
                    if sampled_ids[batch,
                                   b_t].cpu().data.numpy() == 2:  # <end>
                        sampled_lengths.append(b_t + 1)
                        break
                    elif b_t == 20 - 1:
                        sampled_lengths.append(20)
            # sort sampled_lengths (descending)
            sampled_lengths = np.array(sampled_lengths)
            sampled_lengths[::-1].sort()
            sampled_lengths = sampled_lengths.tolist()

            # Train discriminator
            discriminator.zero_grad()
            images.volatile = False
            captions.volatile = False
            wrong_captions.volatile = False
            rewards_real = discriminator(images, captions, lengths)
            rewards_fake = discriminator(images, sampled_ids, sampled_lengths)
            rewards_wrong = discriminator(images, wrong_captions,
                                          wrong_lengths)
            real_loss = -torch.mean(torch.log(rewards_real))
            # The clamp keeps log(0) from producing -inf in the loss.
            fake_loss = -torch.mean(
                torch.clamp(torch.log(1 - rewards_fake), min=-1000))
            wrong_loss = -torch.mean(
                torch.clamp(torch.log(1 - rewards_wrong), min=-1000))
            loss_disc = real_loss + fake_loss + wrong_loss

            # disc_gan_losses.append(loss_disc.cpu().data.numpy()[0])
            loss_disc.backward()
            optimizer_disc.step()

            # Print log info
            if i % args.log_step == 0:
                print(
                    'Epoch [%d/%d], Step [%d/%d], Disc Loss: %.4f, Gen Loss: %.4f'
                    % (epoch, args.num_epochs, i, total_step, loss_disc,
                       loss_gen))

            # Save the models (every args.log_step steps)
            # if (i+1) % args.save_step == 0:
            if (i + 1) % args.log_step == 0:
                torch.save(
                    generator.state_dict(),
                    os.path.join(
                        args.model_path,
                        'generator-gan-%d-%d.pkl' % (epoch + 1, i + 1)))
                torch.save(
                    discriminator.state_dict(),
                    os.path.join(
                        args.model_path,
                        'discriminator-gan-%d-%d.pkl' % (epoch + 1, i + 1)))
Ejemplo n.º 3
0
def train(args):
    """Train a generator/discriminator pair with DeepSpeed and log to TensorBoard.

    The summary writer is opened here and closed in a ``finally`` clause so
    that buffered events are flushed and the file handle is released even
    when training raises.

    Args:
        args: Namespace-like object. Attributes read: ``tensorboard_path``,
            ``outf``, ``manualSeed``, ``batchSize``, ``workers``,
            ``local_rank``, ``nz``, ``ngf``, ``ndf``, ``netG``, ``netD``,
            ``lr``, ``beta1`` and ``epochs``; the whole object is also passed
            to ``get_dataset`` and ``deepspeed.initialize``.
    """
    writer = SummaryWriter(log_dir=args.tensorboard_path)
    try:
        _run_training(args, writer)
    finally:
        writer.close()


def _run_training(args, writer):
    """Run the training loop of ``train``, writing per-step losses to ``writer``."""
    create_folder(args.outf)
    set_seed(args.manualSeed)
    cudnn.benchmark = True
    dataset, nc = get_dataset(args)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batchSize,
                                             shuffle=True,
                                             num_workers=int(args.workers))
    torch.cuda.set_device(args.local_rank)
    device = torch.device(
        "cuda",
        args.local_rank)  #torch.device("cuda:0" if args.cuda else "cpu")
    ngpu = 0
    nz = int(args.nz)
    ngf = int(args.ngf)
    ndf = int(args.ndf)

    netG = Generator(ngpu, ngf, nc, nz).to(device)
    netG.apply(weights_init)
    # Optionally resume from saved weights.
    if args.netG != '':
        netG.load_state_dict(torch.load(args.netG))

    netD = Discriminator(ngpu, ndf, nc).to(device)
    netD.apply(weights_init)
    if args.netD != '':
        netD.load_state_dict(torch.load(args.netD))

    criterion = nn.BCELoss()

    # Noise batch reused for the periodically saved generator samples.
    fixed_noise = torch.randn(args.batchSize, nz, 1, 1, device=device)
    real_label = 1
    fake_label = 0

    # setup optimizer
    optimizerD = torch.optim.Adam(netD.parameters(),
                                  lr=args.lr,
                                  betas=(args.beta1, 0.999))
    optimizerG = torch.optim.Adam(netG.parameters(),
                                  lr=args.lr,
                                  betas=(args.beta1, 0.999))

    model_engineD, optimizerD, _, _ = deepspeed.initialize(
        args=args,
        model=netD,
        model_parameters=netD.parameters(),
        optimizer=optimizerD)
    model_engineG, optimizerG, _, _ = deepspeed.initialize(
        args=args,
        model=netG,
        model_parameters=netG.parameters(),
        optimizer=optimizerG)

    steps_per_epoch = len(dataloader)

    # Synchronize so the wall-clock measurement excludes pending GPU work.
    torch.cuda.synchronize()
    start = time()
    for epoch in range(args.epochs):
        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real
            netD.zero_grad()
            real = data[0].to(device)
            batch_size = real.size(0)
            label = torch.full((batch_size, ),
                               real_label,
                               dtype=real.dtype,
                               device=device)
            output = netD(real)
            errD_real = criterion(output, label)
            model_engineD.backward(errD_real)
            D_x = output.mean().item()

            # train with fake
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            fake = netG(noise)
            label.fill_(fake_label)
            # detach() keeps the discriminator update from backpropagating
            # into the generator.
            output = netD(fake.detach())
            errD_fake = criterion(output, label)
            model_engineD.backward(errD_fake)
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            #optimizerD.step() # alternative (equivalent) step
            model_engineD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            output = netD(fake)
            errG = criterion(output, label)
            model_engineG.backward(errG)
            D_G_z2 = output.mean().item()
            #optimizerG.step() # alternative (equivalent) step
            model_engineG.step()

            print(
                '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                % (epoch, args.epochs, i, steps_per_epoch, errD.item(),
                   errG.item(), D_x, D_G_z1, D_G_z2))
            global_step = epoch * steps_per_epoch + i
            writer.add_scalar("Loss_D", errD.item(), global_step)
            writer.add_scalar("Loss_G", errG.item(), global_step)
            # Every 100 iterations save the current real batch and samples
            # generated from the fixed noise.
            if i % 100 == 0:
                vutils.save_image(real,
                                  '%s/real_samples.png' % args.outf,
                                  normalize=True)
                fake = netG(fixed_noise)
                vutils.save_image(fake.detach(),
                                  '%s/fake_samples_epoch_%03d.png' %
                                  (args.outf, epoch),
                                  normalize=True)

        # do checkpointing
        #torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (args.outf, epoch))
        #torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (args.outf, epoch))
    torch.cuda.synchronize()
    stop = time()
    print(
        f"total wall clock time for {args.epochs} epochs is {stop-start} secs")
Ejemplo n.º 4
0
def train_gan():
    """Train the generator and discriminator and checkpoint after every epoch.

    The discriminator loss is ``mean(D(fake)) - mean(D(real))`` plus
    ``lambduh`` times the value returned by ``gradient_penalty``; the
    generator loss is ``-mean(D(G(noise)))``. After each epoch the model and
    optimizer states and the summed losses are saved to
    ``upsample/checkpoint_<epoch>.pth`` and the per-batch average losses are
    printed.
    """
    import os  # function-scope import: only needed for the checkpoint directory

    batch_size = 64
    epochs = 100
    disc_update = 1  # the discriminator is updated every `disc_update` batches
    gen_update = 5  # the generator is updated every `gen_update` batches
    latent_dimension = 100
    lambduh = 10  # weight of the gradient-penalty term

    device = torch.device(
        'cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    # load data
    train_loader, valid_loader, test_loader = get_data_loader(
        'data', batch_size)

    disc_model = Discriminator().to(device)
    gen_model = Generator(latent_dimension).to(device)

    disc_optim = Adam(disc_model.parameters(), lr=1e-4, betas=(0.5, 0.9))
    gen_optim = Adam(gen_model.parameters(), lr=1e-4, betas=(0.5, 0.9))

    # torch.save does not create missing parent directories, so make sure the
    # checkpoint directory exists before an epoch of work is spent.
    os.makedirs("upsample", exist_ok=True)

    for e in range(epochs):
        # Running sums are graph-free 0-dim tensors: they stay tensors even
        # when no update happens in an epoch (so .item() below always works),
        # and they do not keep every iteration's autograd graph alive.
        disc_loss = torch.zeros((), device=device)
        gen_loss = torch.zeros((), device=device)
        for i, (images, _) in enumerate(train_loader):
            images = images.to(device)
            b_size = images.shape[0]
            step = i + 1
            if step % disc_update == 0:
                disc_model.zero_grad()
                # sample noise
                noise = torch.randn((b_size, latent_dimension), device=device)

                # loss on fake (detached: no gradient flows to the generator)
                inputs = gen_model(noise).detach()
                f_outputs = disc_model(inputs)
                loss = f_outputs.mean()

                # loss on real
                r_outputs = disc_model(images)
                loss -= r_outputs.mean()

                # add gradient penalty
                loss += lambduh * gradient_penalty(disc_model, images, inputs,
                                                   device)

                disc_loss = disc_loss + loss.detach()
                loss.backward()
                disc_optim.step()

            if step % gen_update == 0:
                gen_model.zero_grad()

                noise = torch.randn((b_size, latent_dimension)).to(device)
                inputs = gen_model(noise)
                outputs = disc_model(inputs)
                loss = -outputs.mean()

                gen_loss = gen_loss + loss.detach()
                loss.backward()
                gen_optim.step()

        torch.save(
            {
                'epoch': e,
                'disc_model': disc_model.state_dict(),
                'gen_model': gen_model.state_dict(),
                'disc_loss': disc_loss,
                'gen_loss': gen_loss,
                'disc_optim': disc_optim.state_dict(),
                'gen_optim': gen_optim.state_dict()
            }, "upsample/checkpoint_{}.pth".format(e))
        print("Epoch: {} Disc loss: {}".format(
            e + 1,
            disc_loss.item() / len(train_loader)))
        print("Epoch: {} Gen loss: {}".format(
            e + 1,
            gen_loss.item() / len(train_loader)))
Ejemplo n.º 5
0
# Build both networks; they are moved to the GPU only when `cuda` is set.
netG = Generator(d, latent_size)
netD = Discriminator(d)
if cuda:
    netG = netG.cuda()
    netD = netD.cuda()
#print(netG)
#summary(netG, (1, 128))
# print(netD)
# summary(netD, (3, 64, 64))
# exit()

criterion = nn.BCELoss()

optimizerG = optim.Adam(netG.parameters(), lr=0.002, betas=(0.5, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=0.002, betas=(0.5, 0.999))

# Fixed noise batch; follows the same `cuda` switch as the models instead of
# calling .cuda() unconditionally, so CPU-only runs do not crash here.
fix_noise = torch.randn(batch_size, latent_size, 1, 1)
if cuda:
    fix_noise = fix_noise.cuda()
lossG = []
lossD = []
Dx = []
DG1 = []
DG2 = []
for epoch in range(epoch_num):
    for i, data in enumerate(train_dataloader):
        # Update D network
        # train with real
        if cuda:
            data = data.cuda()
        netD.zero_grad()
        # Float fill value: an integer fill yields an integer tensor on
        # current PyTorch. The label is created on the batch's device rather
        # than being forced onto the GPU.
        label = torch.full((data.size(0), ), 1.0, device=data.device)
Ejemplo n.º 6
0
def main(args):
    """Train the caption encoder/decoder with cross-entropy plus a discriminator.

    Each batch first updates the decoder and the encoder's ``linear`` and
    ``bn`` parameters against the packed ground-truth captions, then updates
    the discriminator on real, sampled and mismatched ("wrong") captions.
    At the last step of every epoch the three state dicts are saved under
    ``args.model_path`` and the discriminator loss curve is written to
    ``disc_losses.png``.

    Args:
        args: Namespace-like object. Attributes read here: ``model_path``,
            ``crop_size``, ``vocab_path``, ``image_dir``, ``caption_path``,
            ``batch_size``, ``num_workers``, ``embed_size``, ``hidden_size``,
            ``num_layers``, ``learning_rate``, ``num_epochs`` and
            ``log_step``.
    """
    # Create model directory. ``exist_ok=True`` replaces the racy
    # exists()-then-makedirs() sequence.
    os.makedirs(args.model_path, exist_ok=True)

    # Image preprocessing
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Load vocabulary wrapper.
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # args.vocab_path at files from a trusted source.
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build data loader
    data_loader = get_loader(args.image_dir,
                             args.caption_path,
                             vocab,
                             transform,
                             args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers)

    # Build the models (Gen)
    # TODO: put these in generator
    encoder = EncoderCNN(args.embed_size)
    decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab),
                         args.num_layers)

    # Build the models (Disc)
    discriminator = Discriminator(args.embed_size, args.hidden_size,
                                  len(vocab), args.num_layers)

    if torch.cuda.is_available():
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()

    # Loss and Optimizer (Gen): only the decoder and the encoder's linear/bn
    # parameters are optimized.
    criterion = nn.CrossEntropyLoss()
    params = list(decoder.parameters()) + list(
        encoder.linear.parameters()) + list(encoder.bn.parameters())
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    # Loss and Optimizer (Disc)
    params_disc = list(discriminator.parameters())
    optimizer_disc = torch.optim.Adam(params_disc)

    # Train the Models
    total_step = len(data_loader)
    disc_losses = []
    for epoch in range(args.num_epochs):
        for i, (images, captions, lengths, wrong_captions,
                wrong_lengths) in enumerate(data_loader):

            # TODO: train disc before gen

            # Set mini-batch dataset
            images = to_var(images, volatile=True)
            captions = to_var(captions)
            wrong_captions = to_var(wrong_captions)
            targets = pack_padded_sequence(captions, lengths,
                                           batch_first=True)[0]

            # Forward, Backward and Optimize
            decoder.zero_grad()
            encoder.zero_grad()
            features = encoder(images)
            outputs = decoder(features, captions, lengths)

            sampled_captions = decoder.sample(features)
            # sampled_captions = torch.zeros_like(sampled_ids)
            sampled_lengths = []

            # Length of each sampled caption: index of the first '<end>'
            # token plus one, or the full width when no '<end>' was produced.
            for row in range(sampled_captions.size(0)):
                for index, word_id in enumerate(sampled_captions[row, :]):
                    word = vocab.idx2word[word_id.cpu().data.numpy()[0]]
                    # sampled_captions[row, index].data = word
                    if word == '<end>':
                        sampled_lengths.append(index + 1)
                        break
                    elif index == sampled_captions.size(1) - 1:
                        sampled_lengths.append(sampled_captions.size(1))
                        break
            # Sort descending (in place through the reversed view).
            # NOTE(review): the lengths are sorted on their own while the rows
            # of sampled_captions keep their order, so lengths may no longer
            # match their captions -- confirm this is intended.
            sampled_lengths = np.array(sampled_lengths)
            sampled_lengths[::-1].sort()
            sampled_lengths = sampled_lengths.tolist()
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Train discriminator
            discriminator.zero_grad()
            rewards_real = discriminator(images, captions, lengths)
            rewards_fake = discriminator(images, sampled_captions,
                                         sampled_lengths)
            rewards_wrong = discriminator(images, wrong_captions,
                                          wrong_lengths)
            real_loss = -torch.mean(torch.log(rewards_real))
            # The clamp keeps log(0) from producing -inf in the loss.
            fake_loss = -torch.mean(
                torch.clamp(torch.log(1 - rewards_fake), min=-1000))
            wrong_loss = -torch.mean(
                torch.clamp(torch.log(1 - rewards_wrong), min=-1000))
            loss_disc = real_loss + fake_loss + wrong_loss

            disc_losses.append(loss_disc.cpu().data.numpy()[0])
            loss_disc.backward()
            optimizer_disc.step()

            # print('iteration %i' % i)
            # Print log info
            if i % args.log_step == 0:
                print(
                    'Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f'
                    % (epoch, args.num_epochs, i, total_step, loss.data[0],
                       np.exp(loss.data[0])))

            # Save the models
            # if (i+1) % args.save_step == 0:
            if (i + 1) % total_step == 0:  # jm: saving at the last iteration instead
                torch.save(
                    decoder.state_dict(),
                    os.path.join(args.model_path,
                                 'decoder-%d-%d.pkl' % (epoch + 1, i + 1)))
                torch.save(
                    encoder.state_dict(),
                    os.path.join(args.model_path,
                                 'encoder-%d-%d.pkl' % (epoch + 1, i + 1)))
                torch.save(
                    discriminator.state_dict(),
                    os.path.join(
                        args.model_path,
                        'discriminator-%d-%d.pkl' % (epoch + 1, i + 1)))

                # plot at the end of every epoch
                plt.plot(disc_losses, label='disc loss')
                plt.savefig('disc_losses.png')
                plt.clf()