Example no. 1
def create_model(image_size,
                 g_conv_dim=64,
                 d_conv_dim=64,
                 n_res_blocks=6,
                 device='cpu'):
    """Builds the generators and discriminators."""

    # Instantiate generators
    G_XtoY = CycleGenerator(image_size, g_conv_dim, n_res_blocks)
    G_YtoX = CycleGenerator(image_size, g_conv_dim, n_res_blocks)

    # Instantiate discriminators
    D_X = Discriminator(d_conv_dim)
    D_Y = Discriminator(d_conv_dim)

    # move models to GPU, if specified
    G_XtoY.to(device)
    G_YtoX.to(device)
    D_X.to(device)
    D_Y.to(device)

    # move models to GPU, if available
    #if torch.cuda.is_available():
    #    device = torch.device("cuda:0")
    #    G_XtoY.to(device)
    #    G_YtoX.to(device)
    #    D_X.to(device)
    #    D_Y.to(device)
    #    print('Models moved to GPU.')
    #else:
    #    print('Only CPU available.')

    return G_XtoY, G_YtoX, D_X, D_Y
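A minimal usage sketch for the helper above; the image size and device are illustrative placeholders, and CycleGenerator / Discriminator are assumed to come from the surrounding project.
# Hypothetical usage of create_model; the values below are placeholders, not from the source
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
G_XtoY, G_YtoX, D_X, D_Y = create_model(image_size=128,
                                        g_conv_dim=64,
                                        d_conv_dim=64,
                                        n_res_blocks=6,
                                        device=device)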
Example no. 2
def main():
    if CONFIG["CUDA"]:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cpu")

    weight_name = CONFIG["model"]["pretrained_weight"]
    model_dict = torch.load(weight_name)
    
    source_net = PoseEstimationWithMobileNet()
    target_net = PoseEstimationWithMobileNet()

    load_state(source_net, model_dict)
    load_state(target_net, model_dict)

    discriminator = Discriminator()
    criterion = nn.BCELoss()

    source_net = source_net.cuda(CONFIG["GPU"]["source_net"])
    target_net = target_net.cuda(CONFIG["GPU"]["target_net"])
    discriminator = discriminator.to(device)
    criterion = criterion.to(device)

    optimizer_tg = torch.optim.Adam(target_net.parameters(),
                                   lr=CONFIG["training_setting"]["t_lr"])
    optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                  lr=CONFIG["training_setting"]["d_lr"])

    dataset = ADDADataset()
    dataloader = DataLoader(dataset, CONFIG["dataset"]["batch_size"], shuffle=True, num_workers=0)

    trainer = Trainer(source_net, target_net, discriminator, 
                     dataloader, optimizer_tg, optimizer_d, criterion, device)
    trainer.train()
Example no. 3
def load_cyclegan_alignment(name, device):
    """
    Load trained models for entity alignment with a cycle GAN architecture.
    :param name: the name of the model directory
    :param device: the torch device to which the saved models (possibly trained on a different device) are transferred
    :return: the generator and discriminator models (subclasses of torch.nn.Module) and the training configuration
    for entity alignment with a cycle GAN architecture
    """
    path = Path(MODEL_PATH) / name
    with open(path / "config.json", "r") as file:
        config = json.load(file)
    # Load Generator B->A
    with open(path / "generator_a_config.json", "r") as file:
        generator_a_config = json.load(file)
    generator_a = Generator(generator_a_config, device)
    generator_a.load_state_dict(
        load(path / "generator_a.pt", map_location=device))
    generator_a.to(device)
    # Load Generator A->B
    with open(path / "generator_b_config.json", "r") as file:
        generator_b_config = json.load(file)
    generator_b = Generator(generator_b_config, device)
    generator_b.load_state_dict(
        load(path / "generator_b.pt", map_location=device))
    generator_b.to(device)
    # Load Discriminator A
    with open(path / "discriminator_a_config.json", "r") as file:
        discriminator_a_config = json.load(file)
    discriminator_a = Discriminator(discriminator_a_config, device)
    discriminator_a.load_state_dict(
        load(path / "discriminator_a.pt", map_location=device))
    discriminator_a.to(device)
    # Load Discriminator B
    with open(path / "discriminator_b_config.json", "r") as file:
        discriminator_b_config = json.load(file)
    discriminator_b = Discriminator(discriminator_b_config, device)
    discriminator_b.load_state_dict(
        load(path / "discriminator_b.pt", map_location=device))
    discriminator_b.to(device)
    return generator_a, generator_b, discriminator_a, discriminator_b, config
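A brief usage sketch for the loader above; the directory name is a placeholder, and the models are switched to eval mode before inference.
# Hypothetical usage of load_cyclegan_alignment; "my_cyclegan_run" is a placeholder directory name
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
generator_a, generator_b, discriminator_a, discriminator_b, config = \
    load_cyclegan_alignment("my_cyclegan_run", device)
for module in (generator_a, generator_b, discriminator_a, discriminator_b):
    module.eval()  # switch to inference mode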
Example no. 4
def main(args):
    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    vocab_size = len(vocab)
    print('vocab_size:', vocab_size)

    dataloader = get_loader(image_dir, caption_path, vocab, 
                            args.batch_size,
                            crop_size,
                            shuffle=True, num_workers=num_workers)

    generator = Generator(attention_dim, embedding_size, lstm_size, vocab_size, load_path=args.g_path, noise=args.noise)
    generator = generator.to(device)
    generator = generator.train()

    discriminator = Discriminator(vocab_size, embedding_size, lstm_size, attention_dim, load_path=args.d_path)
    discriminator = discriminator.to(device)
    discriminator = discriminator.train()

    if args.train_mode == 'gd':
        for _ in range(5):
            for i in range(4):
                generator.pre_train(dataloader, vocab)
            for i in range(1):
                discriminator.fit(generator, dataloader, vocab)
    elif args.train_mode == 'dg':
        discriminator.fit(generator, dataloader, vocab)
        generator.pre_train(dataloader, vocab)
    elif args.train_mode == 'd':
        discriminator.fit(generator, dataloader, vocab)
    elif args.train_mode == 'g':
        generator.pre_train(dataloader, vocab)
    elif args.train_mode == 'ad':
        for i in range(5):
            generator.ad_train(dataloader, discriminator, vocab, gamma=args.gamma, update_every=args.update_every, alpha_c=1.0, num_rollouts=args.num_rollouts)
Example no. 5
def main(pretrain_dataset, rl_dataset, args):
    ##############################################################################
    # Setup
    ##############################################################################
    # set random seeds
    random.seed(const.SEED)
    np.random.seed(const.SEED)

    # load datasets
    pt_train_loader, pt_valid_loader = SplitDataLoader(
        pretrain_dataset, batch_size=const.BATCH_SIZE, drop_last=True).split()

    # Define Networks
    generator = Generator(const.VOCAB_SIZE, const.GEN_EMBED_DIM,
                          const.GEN_HIDDEN_DIM, device, args.cuda)
    discriminator = Discriminator(const.VOCAB_SIZE, const.DSCR_EMBED_DIM,
                                  const.DSCR_FILTER_LENGTHS,
                                  const.DSCR_NUM_FILTERS,
                                  const.DSCR_NUM_CLASSES, const.DSCR_DROPOUT)

    # if torch.cuda.device_count() > 1:
    # print("Using", torch.cuda.device_count(), "GPUs.")
    # generator = nn.DataParallel(generator)
    # discriminator = nn.DataParallel(discriminator)
    generator.to(device)
    discriminator.to(device)

    # set CUDA
    if args.cuda and torch.cuda.is_available():
        generator = generator.cuda()
        discriminator = discriminator.cuda()
    ##############################################################################

    ##############################################################################
    # Pre-Training
    ##############################################################################
    # Pretrain and save the Generator using MLE. If a pretrained generator already
    # exists, load it and display its training stats.
    print('#' * 80)
    print('Generator Pretraining')
    print('#' * 80)
    if (not (args.force_pretrain
             or args.force_pretrain_gen)) and op.exists(GEN_MODEL_CACHE):
        print('Loading Pretrained Generator ...')
        checkpoint = torch.load(GEN_MODEL_CACHE)
        generator.load_state_dict(checkpoint['state_dict'])
        print('::INFO:: DateTime - %s.' % checkpoint['datetime'])
        print('::INFO:: Model was trained for %d epochs.' %
              checkpoint['epochs'])
        print('::INFO:: Final Training Loss - %.5f' % checkpoint['train_loss'])
        print('::INFO:: Final Validation Loss - %.5f' %
              checkpoint['valid_loss'])
    else:
        try:
            print('Pretraining Generator with MLE ...')
            GeneratorPretrainer(generator, pt_train_loader, pt_valid_loader,
                                PT_CACHE_DIR, device, args).train()
        except KeyboardInterrupt:
            print('Stopped Generator Pretraining Early.')

    # Pretrain Discriminator on real data and data from the pretrained generator. If a pretrained Discriminator
    # already exists, load it and display its stats
    print('#' * 80)
    print('Discriminator Pretraining')
    print('#' * 80)
    if (not (args.force_pretrain
             or args.force_pretrain_dscr)) and op.exists(DSCR_MODEL_CACHE):
        print("Loading Pretrained Discriminator ...")
        checkpoint = torch.load(DSCR_MODEL_CACHE)
        discriminator.load_state_dict(checkpoint['state_dict'])
        print('::INFO:: DateTime - %s.' % checkpoint['datetime'])
        print('::INFO:: Model was trained on %d data generations.' %
              checkpoint['data_gens'])
        print('::INFO:: Model was trained for %d epochs per data generation.' %
              checkpoint['epochs_per_gen'])
        print('::INFO:: Final Loss - %.5f' % checkpoint['loss'])
    else:
        print('Pretraining Discriminator ...')
        try:
            DiscriminatorPretrainer(discriminator, rl_dataset, PT_CACHE_DIR,
                                    TEMP_DATA_DIR, device,
                                    args).train(generator)
        except KeyboardInterrupt:
            print('Stopped Discriminator Pretraining Early.')
    ##############################################################################

    ##############################################################################
    # Adversarial Training
    ##############################################################################
    print('#' * 80)
    print('Adversarial Training')
    print('#' * 80)
    AdversarialRLTrainer(generator, discriminator, rl_dataset, TEMP_DATA_DIR,
                         pt_valid_loader, device, args).train()
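For reference, the cached generator checkpoint loaded above must contain at least the keys that the loading branch reads ('state_dict', 'datetime', 'epochs', 'train_loss', 'valid_loss'). A hedged sketch of how such a file could be written; the actual saving code lives inside GeneratorPretrainer and is not shown here.
# Sketch only: writes a checkpoint with the keys read above; numeric values are placeholders
# `generator` and GEN_MODEL_CACHE are the names used in the surrounding snippet
import datetime
import torch

torch.save({'state_dict': generator.state_dict(),
            'datetime': str(datetime.datetime.now()),
            'epochs': 50,          # placeholder
            'train_loss': 0.0,     # placeholder
            'valid_loss': 0.0},    # placeholder
           GEN_MODEL_CACHE)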
Example no. 6
train_loader = torch.utils.data.DataLoader(lsp_train_dataset, batch_size=args.batch_size, shuffle=True)
val_save_loader = torch.utils.data.DataLoader(lsp_val_dataset, batch_size=args.val_batch_size)
val_eval_loader = torch.utils.data.DataLoader(lsp_val_dataset, batch_size=args.val_batch_size, shuffle=True)
#train_eval = torch.utils.data.DataLoader(lsp_train_dataset, batch_size=args.val_batch_size, shuffle=True)


pck = metrics.PCK(metrics.Options(256, 8))

# Loading on GPU, if available
if (args.use_gpu):
    generator_model = generator_model.to(fast_device)
#    discriminator_model_conf = discriminator_model_conf.to(fast_device)
    discriminator_model_pose = discriminator_model_pose.to(fast_device)

# Cross entropy loss
#criterion = nn.CrossEntropyLoss()

# Setting the optimizer
if (args.optimizer_type == 'SGD'):
    optim_gen = optim.SGD(generator_model.parameters(), lr=args.lr, momentum=args.momentum)
    optim_disc = optim.SGD(discriminator_model.parameters(), lr=args.lr, momentum=args.momentum)

elif (args.optimizer_type == 'Adam'):
    optim_gen = optim.Adam(generator_model.parameters(), lr=args.lr) ## added the betas .originally not there 
#    optim_disc_conf = optim.Adam(discriminator_model_conf.parameters(), lr=args.lr) ##added the betas.originally not there 
    optim_disc_pose = optim.Adam(discriminator_model_pose.parameters(), lr=args.lr) ##added the betas.originally not there 
    
#---- code added here implementing RMSprop as an option for optimization --------#
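# A hedged sketch of the RMSprop option the comment above refers to, mirroring the
# Adam branch; the exact flag value 'RMSprop' is an assumption, not from the source.
elif (args.optimizer_type == 'RMSprop'):
    optim_gen = optim.RMSprop(generator_model.parameters(), lr=args.lr, momentum=args.momentum)
    optim_disc_pose = optim.RMSprop(discriminator_model_pose.parameters(), lr=args.lr, momentum=args.momentum)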
Example no. 7
def main(args):
    # log hyperparameter
    print(args)

    # select device
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda:0" if args.cuda else "cpu")

    # set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # data loader
    transform = transforms.Compose([
        utils.Normalize(),
        utils.ToTensor()
    ])
    train_dataset = TVDataset(
        root=args.root,
        sub_size=args.block_size,
        volume_list=args.volume_train_list,
        max_k=args.training_step,
        train=True,
        transform=transform
    )
    test_dataset = TVDataset(
        root=args.root,
        sub_size=args.block_size,
        volume_list=args.volume_test_list,
        max_k=args.training_step,
        train=False,
        transform=transform
    )

    kwargs = {"num_workers": 4, "pin_memory": True} if args.cuda else {}
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              shuffle=True, **kwargs)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                             shuffle=False, **kwargs)

    # model
    def generator_weights_init(m):
        if isinstance(m, nn.Conv3d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def discriminator_weights_init(m):
        if isinstance(m, nn.Conv3d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    g_model = Generator(args.upsample_mode, args.forward, args.backward, args.gen_sn, args.residual)
    g_model.apply(generator_weights_init)
    if args.data_parallel and torch.cuda.device_count() > 1:
        g_model = nn.DataParallel(g_model)
    g_model.to(device)

    if args.gan_loss != "none":
        d_model = Discriminator(args.dis_sn)
        d_model.apply(discriminator_weights_init)
        # if args.dis_sn:
        #     d_model = add_sn(d_model)
        if args.data_parallel and torch.cuda.device_count() > 1:
            d_model = nn.DataParallel(d_model)
        d_model.to(device)

    mse_loss = nn.MSELoss()
    adversarial_loss = nn.MSELoss()
    train_losses, test_losses = [], []
    d_losses, g_losses = [], []

    # optimizer
    g_optimizer = optim.Adam(g_model.parameters(), lr=args.lr,
                             betas=(args.beta1, args.beta2))
    if args.gan_loss != "none":
        d_optimizer = optim.Adam(d_model.parameters(), lr=args.d_lr,
                                 betas=(args.beta1, args.beta2))

    Tensor = torch.cuda.FloatTensor if args.cuda else torch.FloatTensor

    # load checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint {}".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            g_model.load_state_dict(checkpoint["g_model_state_dict"])
            # g_optimizer.load_state_dict(checkpoint["g_optimizer_state_dict"])
            if args.gan_loss != "none":
                d_model.load_state_dict(checkpoint["d_model_state_dict"])
                # d_optimizer.load_state_dict(checkpoint["d_optimizer_state_dict"])
                d_losses = checkpoint["d_losses"]
                g_losses = checkpoint["g_losses"]
            train_losses = checkpoint["train_losses"]
            test_losses = checkpoint["test_losses"]
            print("=> load chekcpoint {} (epoch {})"
                  .format(args.resume, checkpoint["epoch"]))

    # main loop
    for epoch in tqdm(range(args.start_epoch, args.epochs)):
        # training..
        g_model.train()
        if args.gan_loss != "none":
            d_model.train()
        train_loss = 0.
        volume_loss_part = np.zeros(args.training_step)
        for i, sample in enumerate(train_loader):
            params = list(g_model.named_parameters())
            # pdb.set_trace()
            # params[0][1].register_hook(lambda g: print("{}.grad: {}".format(params[0][0], g)))
            # adversarial ground truths
            real_label = Variable(Tensor(sample["v_i"].shape[0], sample["v_i"].shape[1], 1, 1, 1, 1).fill_(1.0), requires_grad=False)
            fake_label = Variable(Tensor(sample["v_i"].shape[0], sample["v_i"].shape[1], 1, 1, 1, 1).fill_(0.0), requires_grad=False)

            v_f = sample["v_f"].to(device)
            v_b = sample["v_b"].to(device)
            v_i = sample["v_i"].to(device)
            g_optimizer.zero_grad()
            fake_volumes = g_model(v_f, v_b, args.training_step, args.wo_ori_volume, args.norm)

            # adversarial loss
            # update discriminator
            if args.gan_loss != "none":
                avg_d_loss = 0.
                avg_d_loss_real = 0.
                avg_d_loss_fake = 0.
                for k in range(args.n_d):
                    d_optimizer.zero_grad()
                    decisions = d_model(v_i)
                    d_loss_real = adversarial_loss(decisions, real_label)
                    fake_decisions = d_model(fake_volumes.detach())

                    d_loss_fake = adversarial_loss(fake_decisions, fake_label)
                    d_loss = d_loss_real + d_loss_fake
                    d_loss.backward()
                    avg_d_loss += d_loss.item() / args.n_d
                    avg_d_loss_real += d_loss_real.item() / args.n_d
                    avg_d_loss_fake += d_loss_fake.item() / args.n_d

                    d_optimizer.step()

            # update generator
            if args.gan_loss != "none":
                avg_g_loss = 0.
            avg_loss = 0.
            for k in range(args.n_g):
                loss = 0.
                g_optimizer.zero_grad()

                # adversarial loss
                if args.gan_loss != "none":
                    fake_decisions = d_model(fake_volumes)
                    g_loss = args.gan_loss_weight * adversarial_loss(fake_decisions, real_label)
                    loss += g_loss
                    avg_g_loss += g_loss.item() / args.n_g

                # volume loss
                if args.volume_loss:
                    volume_loss = args.volume_loss_weight * mse_loss(v_i, fake_volumes)
                    for j in range(v_i.shape[1]):
                        volume_loss_part[j] += mse_loss(v_i[:, j, :], fake_volumes[:, j, :]).item() / args.n_g / args.log_every
                    loss += volume_loss

                # feature loss
                if args.feature_loss:
                    feat_real = d_model.extract_features(v_i)
                    feat_fake = d_model.extract_features(fake_volumes)
                    for m in range(len(feat_real)):
                        loss += args.feature_loss_weight / len(feat_real) * mse_loss(feat_real[m], feat_fake[m])

                avg_loss += loss / args.n_g
                loss.backward()
                g_optimizer.step()

            train_loss += avg_loss

            # log training status
            subEpoch = (i + 1) // args.log_every
            if (i+1) % args.log_every == 0:
                print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch, (i+1) * args.batch_size, len(train_loader.dataset), 100. * (i+1) / len(train_loader),
                    avg_loss
                ))
                print("Volume Loss: ")
                for j in range(volume_loss_part.shape[0]):
                    print("\tintermediate {}: {:.6f}".format(
                        j+1, volume_loss_part[j]
                    ))

                if args.gan_loss != "none":
                    print("DLossReal: {:.6f} DLossFake: {:.6f} DLoss: {:.6f}, GLoss: {:.6f}".format(
                        avg_d_loss_real, avg_d_loss_fake, avg_d_loss, avg_g_loss
                    ))
                    d_losses.append(avg_d_loss)
                    g_losses.append(avg_g_loss)
                # train_losses.append(avg_loss)
                train_losses.append(train_loss.item() / args.log_every)
                print("====> SubEpoch: {} Average loss: {:.6f} Time {}".format(
                    subEpoch, train_loss.item() / args.log_every, time.asctime(time.localtime(time.time()))
                ))
                train_loss = 0.
                volume_loss_part = np.zeros(args.training_step)

            # testing...
            if (i + 1) % args.test_every == 0:
                g_model.eval()
                if args.gan_loss != "none":
                    d_model.eval()
                test_loss = 0.
                with torch.no_grad():
                    # use a separate loop variable so the outer batch index `i` is not overwritten
                    for test_sample in test_loader:
                        v_f = test_sample["v_f"].to(device)
                        v_b = test_sample["v_b"].to(device)
                        v_i = test_sample["v_i"].to(device)
                        fake_volumes = g_model(v_f, v_b, args.training_step, args.wo_ori_volume, args.norm)
                        test_loss += args.volume_loss_weight * mse_loss(v_i, fake_volumes).item()

                test_losses.append(test_loss * args.batch_size / len(test_loader.dataset))
                print("====> SubEpoch: {} Test set loss {:.4f} Time {}".format(
                    subEpoch, test_losses[-1], time.asctime(time.localtime(time.time()))
                ))
                # return to training mode after the evaluation pass
                g_model.train()
                if args.gan_loss != "none":
                    d_model.train()

            # saving...
            if (i+1) % args.check_every == 0:
                print("=> saving checkpoint at epoch {}".format(epoch))
                if args.gan_loss != "none":
                    torch.save({"epoch": epoch + 1,
                                "g_model_state_dict": g_model.state_dict(),
                                "g_optimizer_state_dict":  g_optimizer.state_dict(),
                                "d_model_state_dict": d_model.state_dict(),
                                "d_optimizer_state_dict": d_optimizer.state_dict(),
                                "d_losses": d_losses,
                                "g_losses": g_losses,
                                "train_losses": train_losses,
                                "test_losses": test_losses},
                               os.path.join(args.save_dir, "model_" + str(epoch) + "_" + str(subEpoch) + ".pth.tar")
                               )
                else:
                    torch.save({"epoch": epoch + 1,
                                "g_model_state_dict": g_model.state_dict(),
                                "g_optimizer_state_dict": g_optimizer.state_dict(),
                                "train_losses": train_losses,
                                "test_losses": test_losses},
                               os.path.join(args.save_dir, "model_" + str(epoch) + "_" + str(subEpoch) + ".pth.tar")
                               )
                torch.save(g_model.state_dict(),
                           os.path.join(args.save_dir, "model_" + str(epoch) + "_" + str(subEpoch) + ".pth"))

        num_subEpoch = len(train_loader) // args.log_every
        print("====> Epoch: {} Average loss: {:.6f} Time {}".format(
            epoch, np.array(train_losses[-num_subEpoch:]).mean(), time.asctime(time.localtime(time.time()))
        ))
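Note that the adversarial objective above is the least-squares GAN formulation: MSELoss against all-ones targets for real decisions and all-zeros targets for fake ones. A minimal self-contained illustration of that pattern; shapes and values are placeholders.
# Stand-alone LSGAN-style loss illustration; the random tensors stand in for discriminator outputs
import torch
import torch.nn as nn

adv = nn.MSELoss()
real_decisions = torch.rand(4, 1)   # stand-in for d_model(v_i)
fake_decisions = torch.rand(4, 1)   # stand-in for d_model(fake_volumes.detach())
d_loss = adv(real_decisions, torch.ones_like(real_decisions)) + \
         adv(fake_decisions, torch.zeros_like(fake_decisions))
g_loss = adv(fake_decisions, torch.ones_like(fake_decisions))  # generator is rewarded for looking "real"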
Example no. 8
# PRETRAIN DISCRIMINATOR
# randomly choose 1000 batches to train the discriminator

print('\nStarting Discriminator Training...')
dis_optimizer = optim.Adagrad(dis.parameters())

optimizer = optim.Adam(gen.parameters(), lr=1e-4)

#optimizer = optim.Adagrad(dis.parameters())

#train_discriminator(dis, dis_optimizer, train_gen_batch, train_tar_batch, BATCH_SIZE, 3)
#torch.save(dis.state_dict(), './model/pretrain_discriminator/model.pt')

pretrained_dis_path = './model/pretrain_discriminator/model.pt'
dis.load_state_dict(torch.load(pretrained_dis_path))
dis.to(device)

#h = dis.init_hidden(10)
#dis(train_src_batch[0].transpose(0, 1).to(device), h)

# ADVERSARIAL TRAINING
print('\nStarting Adversarial Training...')

ADV_TRAIN_EPOCHS = 0
for epoch in range(ADV_TRAIN_EPOCHS):
    print('\n--------\nEPOCH %d\n--------' % (epoch + 1))
    # TRAIN GENERATOR
    print('\nAdversarial Training Generator : ', end='')
    sys.stdout.flush()
    gen.train()
    train_generator_PG(gen, optimizer, dis, train_src_batch, train_src_lens,
Example no. 9
# Dataset and the Dataloader
train_loader = torch.utils.data.DataLoader(lsp_train_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=True)
val_save_loader = torch.utils.data.DataLoader(lsp_val_dataset,
                                              batch_size=args.val_batch_size)
val_eval_loader = torch.utils.data.DataLoader(lsp_val_dataset,
                                              batch_size=args.val_batch_size,
                                              shuffle=True)

pck = metrics.PCK(metrics.Options(256, config['generator']['num_stacks']))

# Loading on GPU, if available
if (args.use_gpu):
    generator_model = generator_model.to(fast_device)
    discriminator_model = discriminator_model.to(fast_device)

# Cross entropy loss
criterion = nn.CrossEntropyLoss()

# Setting the optimizer
if (args.optimizer_type == 'SGD'):
    optim_gen = optim.SGD(generator_model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum)
    optim_disc = optim.SGD(discriminator_model.parameters(),
                           lr=args.lr,
                           momentum=args.momentum)

elif (args.optimizer_type == 'Adam'):
    optim_gen = optim.Adam(generator_model.parameters(), lr=args.lr)
Example no. 10
def main():
    # Prepare the dataset
    make_data()
    train_img_list = make_datapath_list()
    mean, std = (0.5, ), (0.5, )
    train_dataset = GAN_Img_Dataset(train_img_list, ImageTransform(mean, std))
    batch_size = 64
    train_dataloader = data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True)

    # Define the models and initialize their weights
    G = Generator(z_dim=20, image_size=64)
    D = Discriminator(image_size=64)

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(m.weight.data, 0, 0.02)
            nn.init.constant_(m.bias.data, 0)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    G.apply(weights_init)
    D.apply(weights_init)

    # decide device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("device:\t", device)
    G.to(device)
    D.to(device)

    # define optimizer
    g_lr, d_lr = 0.0001, 0.0004
    g_optimizer = torch.optim.Adam(G.parameters(), g_lr, [0, 0.9])
    d_optimizer = torch.optim.Adam(D.parameters(), d_lr, [0, 0.9])

    # Parameters
    z_dim = 20  # dimensionality of the random noise

    G.train()
    D.train()
    torch.backends.cudnn.benchmark = True

    num_train_imgs = len(train_dataloader.dataset)
    iteration = 1
    logs = []

    # Training (300 epochs)
    for epoch in range(300):
        t_epoch_start = time.time()
        epoch_g_loss = 0.0
        epoch_d_loss = 0.0

        for batch in train_dataloader:
            # Check the batch size (skip batches of size 1)
            if batch.size()[0] == 1:
                continue

            # Prepare the labels
            batch = batch.to(device)
            batch_num = batch.size()[0]
            label_real = torch.full((batch_num, ), 1).to(device)
            label_fake = torch.full((batch_num, ), 0).to(device)

            # --- Train the Discriminator --- #
            # Evaluate the real images
            d_out_real, _, _ = D(batch)

            # Generate and evaluate fake images
            input_z = torch.randn(batch_num, z_dim).to(device)
            input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)
            fake_images, _, _ = G(input_z)
            d_out_fake, _, _ = D(fake_images)

            # Compute the loss and update the parameters
            d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()
            d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()
            d_loss = d_loss_real + d_loss_fake

            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # --- Train the Generator --- #
            input_z = torch.randn(batch_num, z_dim).to(device)
            input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)
            fake_images, _, _ = G(input_z)
            d_out_fake, _, _ = D(fake_images)

            # Compute the loss and update the parameters
            g_loss = -d_out_fake.mean()
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            epoch_d_loss += d_loss.item()
            epoch_g_loss += g_loss.item()
            iteration += 1

        t_epoch_finish = time.time()
        print(
            'epoch {:3d}/300 || D_Loss: {:.4f} || G_Loss: {:.4f} || time: {:.4f} sec.'
            .format(epoch, epoch_d_loss / batch_size,
                    epoch_g_loss / batch_size, t_epoch_finish - t_epoch_start))

    # --- Generate and visualize images --- #
    test_size = 5  # number of images to visualize
    input_z = torch.randn(test_size, z_dim)
    input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)
    G.eval()
    fake_images, at_map1, at_map2 = G(input_z.to(device))

    fig = plt.figure(figsize=(15, 6))
    for i in range(0, 5):
        # top: fake image
        plt.subplot(2, 5, i + 1)
        plt.imshow(fake_images[i][0].cpu().detach().numpy(), 'gray')

        # middle: atmap1
        plt.subplot(2, 5, i + 6)
        am = at_map1[i].view(16, 16, 16, 16)
        am = am[7][7]
        plt.imshow(am.cpu().detach().numpy(), 'Reds')

    plt.savefig('visualization.png')
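The discriminator and generator updates above implement the hinge GAN objective; restated in isolation, with random tensors standing in for discriminator outputs:
# Stand-alone restatement of the hinge losses used above; inputs are placeholders
import torch
import torch.nn.functional as F

d_out_real = torch.randn(8)   # stand-in for D(real_batch)
d_out_fake = torch.randn(8)   # stand-in for D(G(z))
d_loss = F.relu(1.0 - d_out_real).mean() + F.relu(1.0 + d_out_fake).mean()
g_loss = -d_out_fake.mean()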
Example no. 11
IMAGE_SIZE = 64
LEARNING_RATE = 0.0002
UPDATE_INTERVAL = 3
LOG_INTERVAL = 50

data_A, data_B = getCeleb('Male', -1, DOMAIN_A, DOMAIN_B)
test_A, test_B = getCeleb('Male', -1, DOMAIN_A, DOMAIN_B, True)

generator_A = Generator()
generator_B = Generator()
discriminator_A = Discriminator()
discriminator_B = Discriminator()

generator_A = generator_A.to(device)
generator_B = generator_B.to(device)
discriminator_A = discriminator_A.to(device)
discriminator_B = discriminator_B.to(device)

if device == 'cuda':
    generator_A = torch.nn.DataParallel(generator_A)
    generator_B = torch.nn.DataParallel(generator_B)
    discriminator_A = torch.nn.DataParallel(discriminator_A)
    discriminator_B = torch.nn.DataParallel(discriminator_B)

chained_gen_params = chain(generator_A.parameters(), generator_B.parameters())
chained_dis_params = chain(discriminator_A.parameters(),
                           discriminator_B.parameters())

optim_gen = torch.optim.Adam(chained_gen_params,
                             lr=LEARNING_RATE,
                             betas=(0.5, 0.999),
Example no. 12
lr = 0.000001
momentum = 0.7

# Input normalization
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

data = datasets.MNIST('data', train=True, download=True, transform=transform)
train = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)

D = Discriminator()
G = Generator()

D.to(device)
G.to(device)

if device == 'cuda':
    D = torch.nn.DataParallel(D)
    G = torch.nn.DataParallel(G)

d_optimizer = torch.optim.Adam(D.parameters(), lr, betas=(0.5, 0.99))
g_optimizer = torch.optim.Adam(G.parameters(), lr, betas=(0.5, 0.99))
loss = nn.BCELoss()

for epoch in range(EPOCH):
    for x, y in train:

        batch = x.size()[0]
Example no. 13
args = parser.parse_args()


# Load data.
print("==> Loading data...")
trainloader, testloader = data_loader_and_transformer(args.batch_size)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load model.
print("==> Initializing model...")
model = Discriminator()
if args.load_checkpoint:
    print("==> Loading checkpoint...")
    model = torch.load('cifar10.model')
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)


# --------------------------
# Single train and test step
# --------------------------

def train(epoch):
    # Set to train mode.
    model.train()

    # Avoid the potential overflow error from Adam.
    if epoch > 10:
Example no. 14
class CycleGAN(AlignmentModel):
    """This class implements the alignment model for GAN networks with two generators and two discriminators
    (cycle GAN). For description of the implemented functions, refer to the alignment model."""
    def __init__(self,
                 device,
                 config,
                 generator_a=None,
                 generator_b=None,
                 discriminator_a=None,
                 discriminator_b=None):
        """Initialize two new generators and two discriminators from the config or use pre-trained ones and create Adam
        optimizers for all models."""
        super().__init__(device, config)
        self.epoch_losses = [0., 0., 0., 0.]

        if generator_a is None:
            generator_a_conf = dict(
                dim_1=config['dim_b'],
                dim_2=config['dim_a'],
                layer_number=config['generator_layers'],
                layer_expansion=config['generator_expansion'],
                initialize_generator=config['initialize_generator'],
                norm=config['gen_norm'],
                batch_norm=config['gen_batch_norm'],
                activation=config['gen_activation'],
                dropout=config['gen_dropout'])
            self.generator_a = Generator(generator_a_conf, device)
            self.generator_a.to(device)
        else:
            self.generator_a = generator_a
        if 'optimizer' in config:
            self.optimizer_g_a = OPTIMIZERS[config['optimizer']](
                self.generator_a.parameters(), config['learning_rate'])
        elif 'optimizer_default' in config:
            if config['optimizer_default'] == 'sgd':
                self.optimizer_g_a = OPTIMIZERS[config['optimizer_default']](
                    self.generator_a.parameters(), config['learning_rate'])
            else:
                self.optimizer_g_a = OPTIMIZERS[config['optimizer_default']](
                    self.generator_a.parameters())
        else:
            self.optimizer_g_a = torch.optim.Adam(
                self.generator_a.parameters(), config['learning_rate'])

        if generator_b is None:
            generator_b_conf = dict(
                dim_1=config['dim_a'],
                dim_2=config['dim_b'],
                layer_number=config['generator_layers'],
                layer_expansion=config['generator_expansion'],
                initialize_generator=config['initialize_generator'],
                norm=config['gen_norm'],
                batch_norm=config['gen_batch_norm'],
                activation=config['gen_activation'],
                dropout=config['gen_dropout'])
            self.generator_b = Generator(generator_b_conf, device)
            self.generator_b.to(device)
        else:
            self.generator_b = generator_b
        if 'optimizer' in config:
            self.optimizer_g_b = OPTIMIZERS[config['optimizer']](
                self.generator_b.parameters(), config['learning_rate'])
        elif 'optimizer_default' in config:
            if config['optimizer_default'] == 'sgd':
                self.optimizer_g_b = OPTIMIZERS[config['optimizer_default']](
                    self.generator_b.parameters(), config['learning_rate'])
            else:
                self.optimizer_g_b = OPTIMIZERS[config['optimizer_default']](
                    self.generator_b.parameters())
        else:
            self.optimizer_g_b = torch.optim.Adam(
                self.generator_b.parameters(), config['learning_rate'])

        if discriminator_a is None:
            discriminator_a_conf = dict(
                dim=config['dim_a'],
                layer_number=config['discriminator_layers'],
                layer_expansion=config['discriminator_expansion'],
                batch_norm=config['disc_batch_norm'],
                activation=config['disc_activation'],
                dropout=config['disc_dropout'])
            self.discriminator_a = Discriminator(discriminator_a_conf, device)
            self.discriminator_a.to(device)
        else:
            self.discriminator_a = discriminator_a
        if 'optimizer' in config:
            self.optimizer_d_a = OPTIMIZERS[config['optimizer']](
                self.discriminator_a.parameters(), config['learning_rate'])
        elif 'optimizer_default' in config:
            if config['optimizer_default'] == 'sgd':
                self.optimizer_d_a = OPTIMIZERS[config['optimizer_default']](
                    self.discriminator_a.parameters(), config['learning_rate'])
            else:
                self.optimizer_d_a = OPTIMIZERS[config['optimizer_default']](
                    self.discriminator_a.parameters())
        else:
            self.optimizer_d_a = torch.optim.Adam(
                self.discriminator_a.parameters(), config['learning_rate'])

        if discriminator_b is None:
            discriminator_b_conf = dict(
                dim=config['dim_b'],
                layer_number=config['discriminator_layers'],
                layer_expansion=config['discriminator_expansion'],
                batch_norm=config['disc_batch_norm'],
                activation=config['disc_activation'],
                dropout=config['disc_dropout'])
            self.discriminator_b = Discriminator(discriminator_b_conf, device)
            self.discriminator_b.to(device)
        else:
            self.discriminator_b = discriminator_b
        if 'optimizer' in config:
            self.optimizer_d_b = OPTIMIZERS[config['optimizer']](
                self.discriminator_b.parameters(), config['learning_rate'])
        elif 'optimizer_default' in config:
            if config['optimizer_default'] == 'sgd':
                self.optimizer_d_b = OPTIMIZERS[config['optimizer_default']](
                    self.discriminator_b.parameters(), config['learning_rate'])
            else:
                self.optimizer_d_b = OPTIMIZERS[config['optimizer_default']](
                    self.discriminator_b.parameters())
        else:
            self.optimizer_d_b = torch.optim.Adam(
                self.discriminator_b.parameters(), config['learning_rate'])

    def train(self):
        self.generator_a.train()
        self.generator_b.train()
        self.discriminator_a.train()
        self.discriminator_b.train()

    def eval(self):
        self.generator_a.eval()
        self.generator_b.eval()
        self.discriminator_a.eval()
        self.discriminator_b.eval()

    def zero_grad(self):
        self.optimizer_g_a.zero_grad()
        self.optimizer_g_b.zero_grad()
        self.optimizer_d_a.zero_grad()
        self.optimizer_d_b.zero_grad()

    def optimize_all(self):
        self.optimizer_g_a.step()
        self.optimizer_g_b.step()
        self.optimizer_d_a.step()
        self.optimizer_d_b.step()

    def optimize_generator(self):
        """Do the optimization step only for generators (e.g. when training generators and discriminators separately or
        in turns)."""
        self.optimizer_g_a.step()
        self.optimizer_g_b.step()

    def optimize_discriminator(self):
        """Do the optimization step only for discriminators (e.g. when training generators and discriminators separately
        or in turns)."""
        self.optimizer_d_a.step()
        self.optimizer_d_b.step()

    def change_lr(self, factor):
        self.current_lr = self.current_lr * factor
        for param_group in self.optimizer_g_a.param_groups:
            param_group['lr'] = self.current_lr
        for param_group in self.optimizer_g_b.param_groups:
            param_group['lr'] = self.current_lr

    def update_losses_batch(self, *losses):
        loss_g_a, loss_g_b, loss_d_a, loss_d_b = losses
        self.epoch_losses[0] += loss_g_a
        self.epoch_losses[1] += loss_g_b
        self.epoch_losses[2] += loss_d_a
        self.epoch_losses[3] += loss_d_b

    def complete_epoch(self, epoch_metrics):
        self.metrics.append(epoch_metrics + [sum(self.epoch_losses)])
        self.losses.append(self.epoch_losses)
        self.epoch_losses = [0., 0., 0., 0.]

    def print_epoch_info(self):
        print(
            f"{len(self.metrics)} ### {self.losses[-1][0]:.2f} - {self.losses[-1][1]:.2f} "
            f"- {self.losses[-1][2]:.2f} - {self.losses[-1][3]:.2f} ### {self.metrics[-1]}"
        )

    def copy_model(self):
        self.model_copy = deepcopy(self.generator_a.state_dict()), deepcopy(self.generator_b.state_dict()),\
                          deepcopy(self.discriminator_a.state_dict()), deepcopy(self.discriminator_b.state_dict())

    def restore_model(self):
        self.generator_a.load_state_dict(self.model_copy[0])
        self.generator_b.load_state_dict(self.model_copy[1])
        self.discriminator_a.load_state_dict(self.model_copy[2])
        self.discriminator_b.load_state_dict(self.model_copy[3])

    def export_model(self, test_results, description=None):
        if description is None:
            description = f"CycleGAN_{self.config['evaluation']}_{self.config['subset']}"
        export_cyclegan_alignment(description, self.config, self.generator_a,
                                  self.generator_b, self.discriminator_a,
                                  self.discriminator_b, self.metrics)
        save_alignment_test_results(test_results, description)
        print(f"Saved model to directory {description}.")

    @classmethod
    def load_model(cls, name, device):
        generator_a, generator_b, discriminator_a, discriminator_b, config = load_cyclegan_alignment(
            name, device)
        model = cls(device, config, generator_a, generator_b, discriminator_a,
                    discriminator_b)
        return model
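A brief usage sketch for the class above; the directory name is a placeholder and the loss computation is elided.
# Hypothetical usage of CycleGAN.load_model; "my_alignment_run" is a placeholder directory name
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = CycleGAN.load_model("my_alignment_run", device)
model.train()          # put all four networks into training mode
model.zero_grad()      # clear gradients on all four optimizers
# ... compute and backpropagate the cycle GAN losses here ...
model.optimize_all()   # step all four optimizers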
Example no. 15
class Model:
    def __init__(self, latent_vector_size, generator_feature_map_size,
                 discriminator_feature_map_size, target_channels):
        # creating models
        self.latent_vector_size = latent_vector_size
        self.generator_feature_map_size = generator_feature_map_size
        self.discriminator_feature_map_size = discriminator_feature_map_size
        self.target_channels = target_channels

        self.netG = Generator(
            latent_vector_size=latent_vector_size,
            generator_feature_map_size=generator_feature_map_size,
            output_channel=target_channels)

        self.netD = Discriminator(
            discriminator_feature_map_size=discriminator_feature_map_size,
            input_channel=target_channels)

        # Establish convention for real and fake labels during training

    def train(self,
              train_data,
              optimizerG,
              optimizerD,
              fixed_noise,
              epochs=5,
              criterion=torch.nn.BCELoss(),
              device="cuda",
              dali=False):
        # Lists to keep track of progress
        G_losses = []
        D_losses = []
        iters = 0
        real_label = 1.0
        fake_label = 0.0
        checkpoint = len(train_data) // 4

        for epoch in range(epochs):
            # For each batch in the dataloader
            for i, data in enumerate(train_data):
                start = time.time()
                if dali:
                    d = data[0]
                    train_images, train_labels = d["data"], d["label"]
                    train_images = train_images.to(device)
                    train_images = train_images.permute(0, 3, 1, 2)
                else:
                    train_images, train_labels = data
                    train_images = train_images.to(device)
                self.netD.to(device)
                self.netG.to(device)

                ############################
                # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                ###########################
                ## Train with all-real batch
                self.netD.zero_grad()
                # Format batch
                b_size = train_images.size(0)
                label = torch.full((b_size, ),
                                   real_label,
                                   dtype=torch.float,
                                   device=device)
                # Forward pass real batch through D
                output = self.netD(train_images).view(-1)
                # Calculate loss on all-real batch
                errD_real = criterion(output, label)
                # Calculate gradients for D in backward pass
                errD_real.backward()
                D_x = output.mean().item()

                ## Train with all-fake batch
                # Generate batch of latent vectors
                noise = torch.randn(b_size,
                                    self.latent_vector_size,
                                    1,
                                    1,
                                    device=device)
                # Generate fake image batch with G
                fake = self.netG(noise)
                label.fill_(fake_label)
                # Classify all fake batch with D
                output = self.netD(fake.detach()).view(-1)
                # Calculate D's loss on the all-fake batch
                errD_fake = criterion(output, label)
                # Calculate the gradients for this batch
                errD_fake.backward()
                D_G_z1 = output.mean().item()
                # Add the gradients from the all-real and all-fake batches
                errD = errD_real + errD_fake
                # Update D
                optimizerD.step()

                ############################
                # (2) Update G network: maximize log(D(G(z)))
                ###########################
                self.netG.zero_grad()
                label.fill_(
                    real_label)  # fake labels are real for generator cost
                # Since we just updated D, perform another forward pass of all-fake batch through D
                output = self.netD(fake).view(-1)
                # Calculate G's loss based on this output
                errG = criterion(output, label)
                # Calculate gradients for G
                errG.backward()
                D_G_z2 = output.mean().item()
                # Update G
                optimizerG.step()

                # Output training stats
                if i % checkpoint == 0:
                    print(
                        f"[{epoch}/{epochs}] [{i}/{len(train_data)}] \tLoss_D: {errD.item()} \tLoss_G: {errG.item()} \tD(x): {D_x} \tD(G(z)): {D_G_z1}/{D_G_z2}\ttime:{time.time()-start}"
                    )

                # Save Losses for plotting later
                G_losses.append(errG.item())
                D_losses.append(errD.item())

                # Check how the generator is doing by saving G's output on fixed_noise
                if (iters % checkpoint == 0) or ((epoch == epochs - 1) and
                                                 (i == len(train_data) - 1)):
                    with torch.no_grad():
                        fake = self.netG(fixed_noise).detach().cpu()
                    fig = plt.figure(figsize=(8, 8))
                    plt.axis("off")
                    plt.imshow(
                        np.transpose(
                            vutils.make_grid(fake, padding=2, normalize=True),
                            (1, 2, 0)))
                    plt.savefig(f"./progress/{epoch}_{iters}.png")
                    torch.save(self.netD, f"./models/{epoch}_{iters}.pkl")
                    np.save("./models/G_losses.npy", np.array(G_losses))
                    np.save("./models/D_losses.npy", np.array(D_losses))
                    plt.close(fig)

                iters += 1
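A hedged sketch of how the Model class above might be wired up; every hyperparameter value below is a placeholder, and the data loader is assumed to yield (images, labels) batches.
# Hypothetical wiring of the Model class above; values are illustrative only
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Model(latent_vector_size=100,
              generator_feature_map_size=64,
              discriminator_feature_map_size=64,
              target_channels=3)
optimizerG = torch.optim.Adam(model.netG.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizerD = torch.optim.Adam(model.netD.parameters(), lr=2e-4, betas=(0.5, 0.999))
fixed_noise = torch.randn(64, 100, 1, 1, device=device)  # matches latent_vector_size
# train_loader is assumed to be a DataLoader of (images, labels) pairs:
# model.train(train_loader, optimizerG, optimizerD, fixed_noise, epochs=5, device=device)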
Example no. 16
def _main():
    print_gpu_details()
    device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
    train_root = args.train_path

    image_size = 256
    cropped_image_size = 256
    print("set image folder")
    train_set = dset.ImageFolder(root=train_root,
                                 transform=transforms.Compose([
                                     transforms.Resize(image_size),
                                     transforms.CenterCrop(cropped_image_size),
                                     transforms.ToTensor()
                                 ]))

    normalizer_clf = transforms.Compose([
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    normalizer_discriminator = transforms.Compose([
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    print('set data loader')
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

    # Network creation
    classifier = torch.load(args.classifier_path)
    classifier.eval()
    generator = Generator(gen_type=args.gen_type)
    discriminator = Discriminator(args.discriminator_norm, dis_type=args.gen_type)
    # init weights
    if args.generator_path is not None:
        generator.load_state_dict(torch.load(args.generator_path))
    else:
        generator.init_weights()
    if args.discriminator_path is not None:
        discriminator.load_state_dict(torch.load(args.discriminator_path))
    else:
        discriminator.init_weights()

    classifier.to(device)
    generator.to(device)
    discriminator.to(device)

    # losses + optimizers
    criterion_discriminator, criterion_generator = get_wgan_losses_fn()
    criterion_features = nn.L1Loss()
    criterion_diversity_n = nn.L1Loss()
    criterion_diversity_d = nn.L1Loss()
    generator_optimizer = optim.Adam(generator.parameters(), lr=args.lr, betas=(0.5, 0.999))
    discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.5, 0.999))

    num_of_epochs = args.epochs

    starting_time = time.time()
    iterations = 0
    # creating dirs for keeping models checkpoint, temp created images, and loss summary
    outputs_dir = os.path.join('wgan-gp_models', args.model_name)
    if not os.path.isdir(outputs_dir):
        os.makedirs(outputs_dir, exist_ok=True)
    temp_results_dir = os.path.join(outputs_dir, 'temp_results')
    if not os.path.isdir(temp_results_dir):
        os.mkdir(temp_results_dir)
    models_dir = os.path.join(outputs_dir, 'models_checkpoint')
    if not os.path.isdir(models_dir):
        os.mkdir(models_dir)
    writer = tensorboardX.SummaryWriter(os.path.join(outputs_dir, 'summaries'))

    z = torch.randn(args.batch_size, 128, 1, 1).to(device)  # a fixed noise for sampling
    z2 = torch.randn(args.batch_size, 128, 1, 1).to(device)  # a fixed noise for diversity sampling
    fixed_features = 0
    fixed_masks = 0
    fixed_features_diversity = 0
    first_iter = True
    print("Starting Training Loop...")
    for epoch in range(num_of_epochs):
        for data in train_loader:
            train_type = random.choices([1, 2], [args.train1_prob, 1-args.train1_prob])[0]  # choose train type (random.choices returns a list)
            iterations += 1
            if iterations % 30 == 1:
                print('epoch:', epoch, ', iter', iterations, 'start, time =', time.time() - starting_time, 'seconds')
                starting_time = time.time()
            images, _ = data
            images = images.to(device)  # change to gpu tensor
            images_discriminator = normalizer_discriminator(images)
            images_clf = normalizer_clf(images)
            _, features = classifier(images_clf)
            if first_iter:  # save a batch of images to keep track of the model's progress
                first_iter = False
                fixed_features = [torch.clone(features[x]) for x in range(len(features))]
                fixed_masks = [torch.ones(features[x].shape, device=device) for x in range(len(features))]
                fixed_features_diversity = [torch.clone(features[x]) for x in range(len(features))]
                for i in range(len(features)):
                    for j in range(fixed_features_diversity[i].shape[0]):
                        fixed_features_diversity[i][j] = fixed_features_diversity[i][j % 8]
                grid = vutils.make_grid(images_discriminator, padding=2, normalize=True, nrow=8)
                vutils.save_image(grid, os.path.join(temp_results_dir, 'original_images.jpg'))
                orig_images_diversity = torch.clone(images_discriminator)
                for i in range(orig_images_diversity.shape[0]):
                    orig_images_diversity[i] = orig_images_diversity[i % 8]
                grid = vutils.make_grid(orig_images_diversity, padding=2, normalize=True, nrow=8)
                vutils.save_image(grid, os.path.join(temp_results_dir, 'original_images_diversity.jpg'))
            # Select a features layer to train on
            features_to_train = random.randint(1, len(features) - 2) if args.fixed_layer is None else args.fixed_layer
            # Set masks
            masks = [features[i].clone() for i in range(len(features))]
            setMasksPart1(masks, device, features_to_train) if train_type == 1 else setMasksPart2(masks, device, features_to_train)
            discriminator_loss_dict = train_discriminator(generator, discriminator, criterion_discriminator, discriminator_optimizer, images_discriminator, features, masks)
            for k, v in discriminator_loss_dict.items():
                writer.add_scalar('D/%s' % k, v.data.cpu().numpy(), global_step=iterations)
                if iterations % 30 == 1:
                    print('{}: {:.6f}'.format(k, v))
            if iterations % args.discriminator_steps == 1:
                generator_loss_dict = train_generator(generator, discriminator, criterion_generator, generator_optimizer, images.shape[0], features,
                                                      criterion_features, features_to_train, classifier, normalizer_clf, criterion_diversity_n,
                                                      criterion_diversity_d, masks, train_type)

                for k, v in generator_loss_dict.items():
                    writer.add_scalar('G/%s' % k, v.data.cpu().numpy(), global_step=iterations//5 + 1)
                    if iterations % 30 == 1:
                        print('{}: {:.6f}'.format(k, v))

            # Save generator and discriminator weights every 1000 iterations
            if iterations % 1000 == 1:
                torch.save(generator.state_dict(), models_dir + '/' + args.model_name + 'G')
                torch.save(discriminator.state_dict(), models_dir + '/' + args.model_name + 'D')
            # Save temp results
            if args.keep_temp_results:
                if iterations < 10000 and iterations % 1000 == 1 or iterations % 2000 == 1:
                    # regular sampling (batch of different images)
                    first_features = True
                    fake_images = None
                    fake_images_diversity = None
                    for i in range(1, 5):
                        one_layer_mask = isolate_layer(fixed_masks, i, device)
                        if first_features:
                            first_features = False
                            fake_images = sample(generator, z, fixed_features, one_layer_mask)
                            fake_images_diversity = sample(generator, z, fixed_features_diversity, one_layer_mask)
                        else:
                            tmp_fake_images = sample(generator, z, fixed_features, one_layer_mask)
                            fake_images = torch.vstack((fake_images, tmp_fake_images))
                            tmp_fake_images = sample(generator, z2, fixed_features_diversity, one_layer_mask)
                            fake_images_diversity = torch.vstack((fake_images_diversity, tmp_fake_images))
                    grid = vutils.make_grid(fake_images, padding=2, normalize=True, nrow=8)
                    vutils.save_image(grid, os.path.join(temp_results_dir, 'res_iter_{}.jpg'.format(iterations // 1000)))
                    # diversity sampling (8 different images each with few different noises)
                    grid = vutils.make_grid(fake_images_diversity, padding=2, normalize=True, nrow=8)
                    vutils.save_image(grid, os.path.join(temp_results_dir, 'div_iter_{}.jpg'.format(iterations // 1000)))

                if iterations % 20000 == 1:
                    torch.save(generator.state_dict(), models_dir + '/' + args.model_name + 'G_' + str(iterations // 15000))
                    torch.save(discriminator.state_dict(), models_dir + '/' + args.model_name + 'D_' + str(iterations // 15000))
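get_wgan_losses_fn is referenced above but not shown in this snippet. For orientation only, a sketch of a generic WGAN critic/generator loss pair; this is not the project's implementation, and the gradient-penalty term implied by the 'wgan-gp' naming is omitted.
# Generic WGAN loss sketch; NOT the project's get_wgan_losses_fn
def sketch_wgan_losses_fn():
    def d_loss_fn(real_logits, fake_logits):
        # the critic pushes real scores up and fake scores down
        return fake_logits.mean() - real_logits.mean()

    def g_loss_fn(fake_logits):
        # the generator pushes fake scores up
        return -fake_logits.mean()

    return d_loss_fn, g_loss_fn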
Example no. 17
def main():
    ################################
    ## Module 1: data preparation
    data_ = data.Data(args.data_dir, args.vocab_size)

    # Process the ICD tree
    parient_children, level2_parients, leafNodes, adj, node2id, hier_dicts = utils.build_tree(
        os.path.join(args.data_dir, 'note_labeled_v2.csv'))
    graph = utils.generate_graph(parient_children, node2id)
    args.node2id = node2id
    args.id2node = {id: node for node, id in node2id.items()}
    args.adj = torch.Tensor(adj).long().to(args.device)
    # args.leafNodes=leafNodes
    args.hier_dicts = hier_dicts
    # args.level2_parients=level2_parients
    #print('836:',args.id2node.get(836),args.id2node.get(0))

    # TODO: details of the GenBatcher object
    g_batcher = GenBatcher(data_, args)

    #################################
    ## Module 2: create the G model and pre-train it
    # TODO: details of the Generator object
    gen_model_eval = Generator(args, data_, graph, level2_parients)
    gen_model_target = Generator(args, data_, graph, level2_parients)
    gen_model_target.eval()
    print(gen_model_eval)

    # for name,param in gen_model_eval.named_parameters():
    #     print(name,param.size(),type(param))
    buffer = ReplayBuffer(capacity=100000)
    gen_model_eval.to(args.device)
    gen_model_target.to(args.device)

    # TODO: details of the `generated` object

    # Pre-train the G model
    #pre_train_generator(gen_model,g_batcher,10)

    #####################################
    ## Module 3: create the D model and pre-train it
    d_model = Discriminator(args)
    d_model.to(args.device)

    # Pre-train the D model
    #pre_train_discriminator(d_model,d_batcher,25)

    ########################################
    ## Module 4: alternately train the G and D models

    # Write the evaluation results to a file
    f = open('valid_result.csv', 'w')
    writer = csv.writer(f)
    writer.writerow([
        'avg_micro_p', 'avg_macro_p', 'avg_micro_r', 'avg_macro_r',
        'avg_micro_f1', 'avg_macro_f1', 'avg_micro_auc_roc',
        'avg_macro_auc_roc'
    ])
    epoch_f = []
    for epoch in range(args.num_epochs):
        batches = g_batcher.get_batches(mode='train')
        print('number of batches:', len(batches))
        for step in range(len(batches)):
            #print('step:',step)
            current_batch = batches[step]
            ehrs = [example.ehr for example in current_batch]
            ehrs = torch.Tensor(ehrs).long().to(args.device)

            hier_labels = [example.hier_labels for example in current_batch]

            true_labels = []

            # Pad each hierarchical label path to a fixed depth of 4
            for i in range(len(hier_labels)):  # i indexes samples
                for j in range(len(hier_labels[i])):  # j indexes the paths of sample i
                    if len(hier_labels[i][j]) < 4:
                        hier_labels[i][j] = hier_labels[i][j] + [0] * (
                            4 - len(hier_labels[i][j]))
                # if len(hier_labels[i]) < args.k:
                #     for time in range(args.k - len(hier_labels[i])):
                #         hier_labels[i].append([0] * args.hops)

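            # Keep the second node of each hierarchical path as the flat ground-truth label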
            for sample in hier_labels:
                #print('sample:',sample)
                true_labels.append([row[1] for row in sample])

            predHierLabels, batchStates_n, batchHiddens_n = generator.generated_negative_samples(
                gen_model_eval, d_model, ehrs, hier_labels, buffer)

            #true_labels = [example.labels for example in current_batch]

            _, _, avgJaccard = full_eval.process_labels(
                predHierLabels, true_labels, args)

            # G generates positive samples for training D
            batchStates_p, batchHiddens_p = generator.generated_positive_samples(
                gen_model_eval, ehrs, hier_labels, buffer)

            # Train the D network
            #d_loss=train_discriminator(d_model,batchStates_n,batchHiddens_n,batchStates_p,batchHiddens_p,mode=args.mode)

            # Train the G model
            #for g_epoch in range(10):
            g_loss = train_generator(gen_model_eval,
                                     gen_model_target,
                                     d_model,
                                     batchStates_n,
                                     batchHiddens_n,
                                     buffer,
                                     mode=args.mode)

            print('batch_number:{}, avgJaccard:{:.4f}, g_loss:{:.4f}'.format(
                step, avgJaccard, g_loss))

        # After each epoch, evaluate the G and D models on the validation set
        avg_micro_f1 = evaluate(g_batcher,
                                gen_model_eval,
                                d_model,
                                buffer,
                                writer,
                                flag='valid')
        epoch_f.append(avg_micro_f1)

    # Plot results
    window = max(1, args.num_epochs // 20)  # guard against a zero-width rolling window for short runs
    print('window:', window)
    fig, ((ax1), (ax2)) = plt.subplots(2, 1, sharey=True, figsize=[9, 9])
    rolling_mean = pd.Series(epoch_f).rolling(window).mean()
    std = pd.Series(epoch_f).rolling(window).std()
    ax1.plot(rolling_mean)
    ax1.fill_between(range(len(epoch_f)),
                     rolling_mean - std,
                     rolling_mean + std,
                     color='orange',
                     alpha=0.2)
    ax1.set_title(
        'Moving average of validation F1 ({}-epoch window)'.format(window))
    ax1.set_xlabel('Epoch Number')
    ax1.set_ylabel('F1')

    ax2.plot(epoch_f)
    ax2.set_title('Performance on valid set')
    ax2.set_xlabel('Epoch Number')
    ax2.set_ylabel('F1')

    fig.tight_layout(pad=2)
    plt.show()
    fig.savefig('results.png')

    f.close()
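To make the padding step in the training loop above concrete, here is the same logic applied to two hand-made toy samples (the ICD code values are invented for illustration only):

hier_labels = [[[3, 17, 42], [3, 19]],   # sample 0: two label paths of depth 3 and 2
               [[5, 23, 88, 101]]]       # sample 1: one full-depth path

# Right-pad every path with zeros to a fixed depth of 4
for i in range(len(hier_labels)):
    for j in range(len(hier_labels[i])):
        if len(hier_labels[i][j]) < 4:
            hier_labels[i][j] = hier_labels[i][j] + [0] * (4 - len(hier_labels[i][j]))

# Keep the second node of each path, as done for true_labels above
true_labels = [[path[1] for path in sample] for sample in hier_labels]
print(hier_labels)  # [[[3, 17, 42, 0], [3, 19, 0, 0]], [[5, 23, 88, 101]]]
print(true_labels)  # [[17, 19], [23]]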
Esempio n. 18
0
def semi_main(options):
    print('\nSemi-Supervised Learning!\n')

    # 1. Make sure the options are valid argparse CLI options indeed
    assert isinstance(options, argparse.Namespace)

    # 2. Set up the logger
    logging.basicConfig(level=str(options.loglevel).upper())

    # 3. Make sure the output dir `outf` exists
    _check_out_dir(options)

    # 4. Set the random state
    _set_random_state(options)

    # 5. Configure CUDA and Cudnn, set the global `device` for PyTorch
    device = _configure_cuda(options)

    # 6. Prepare the datasets and split it for semi-supervised learning
    if options.dataset != 'cifar10':
        raise NotImplementedError(
            'Semi-supervised learning only support CIFAR10 dataset at the moment!'
        )
    test_data_loader, semi_data_loader, train_data_loader = _prepare_semi_dataset(
        options)

    # 7. Set the parameters
    ngpu = int(options.ngpu)    # num of GPUs
    nz = int(options.nz)        # size of the latent vector, also the number of generators
    ngf = int(options.ngf)      # depth of feature maps through G
    ndf = int(options.ndf)      # depth of feature maps through D
    nc = int(options.nc)        # num of channels of the input images, 3 indicates color images
    M = int(options.mcmc)       # num of SGHMC chains run concurrently
    nd = int(options.nd)        # num of discriminators
    nsetz = int(options.nsetz)  # num of noise batches

    # 8. Special preparations for Bayesian GAN for Generators

    # In order to inject SGHMC into the training process, instead of pausing gradient descent at each
    # training step (which is easy to express with a static computation graph, e.g. in TensorFlow), in
    # PyTorch we move the generator sampling to the very beginning of the training process and use the
    # trick of explicitly initializing all of the generators for later use.
    Generator_chains = []
    for _ in range(nsetz):
        for __ in range(M):
            netG = Generator(ngpu, nz, ngf, nc).to(device)
            netG.apply(weights_init)
            Generator_chains.append(netG)
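
    # Generator_chains now holds nsetz * M independently initialized generators,
    # one for each (noise set, SGHMC chain) pair.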

    logging.info(
        f'Showing the first generator of the Generator chain: \n {Generator_chains[0]}\n'
    )

    # 9. Special preparations for Bayesian GAN for Discriminators
    assert options.dataset == 'cifar10', 'Semi-supervised learning only support CIFAR10 dataset at the moment!'

    num_class = 10 + 1
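    # Class index 0 is reserved for "fake"; the 10 real CIFAR10 classes are shifted to indices 1..10
    # (hence `target_sup + 1` below and `except_index=0` in ComplementCrossEntropyLoss).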

    # To simplify the implementation we only consider the situation of 1 discriminator
    # if nd <= 1:
    #     netD = Discriminator(ngpu, ndf, nc, num_class=num_class).to(device)
    #     netD.apply(weights_init)
    # else:
    # Discriminator_chains = []
    # for _ in range(nd):
    #     for __ in range(M):
    #         netD = Discriminator(ngpu, ndf, nc, num_class=num_class).to(device)
    #         netD.apply(weights_init)
    #         Discriminator_chains.append(netD)

    netD = Discriminator(ngpu, ndf, nc, num_class=num_class).to(device)
    netD.apply(weights_init)
    logging.info(f'Showing the Discriminator model: \n {netD}\n')

    # 10. Loss function
    criterion = nn.CrossEntropyLoss()
    all_criterion = ComplementCrossEntropyLoss(except_index=0, device=device)

    # 11. Set up optimizers
    optimizerG_chains = [
        optim.Adam(netG.parameters(),
                   lr=options.lr,
                   betas=(options.beta1, 0.999)) for netG in Generator_chains
    ]

    # optimizerD_chains = [
    #     optim.Adam(netD.parameters(), lr=options.lr, betas=(options.beta1, 0.999)) for netD in Discriminator_chains
    # ]
    optimizerD = optim.Adam(netD.parameters(),
                            lr=options.lr,
                            betas=(options.beta1, 0.999))
    # 12. Set up the losses for priors and noises
    import math
    gprior = PriorLoss(prior_std=1., total=500.)
    gnoise = NoiseLoss(params=Generator_chains[0].parameters(),
                       device=device,
                       scale=math.sqrt(2 * options.alpha / options.lr),
                       total=500.)
    dprior = PriorLoss(prior_std=1., total=50000.)
    dnoise = NoiseLoss(params=netD.parameters(),
                       device=device,
                       scale=math.sqrt(2 * options.alpha * options.lr),
                       total=50000.)

    gprior.to(device=device)
    gnoise.to(device=device)
    dprior.to(device=device)
    dnoise.to(device=device)
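    # PriorLoss imposes a Gaussian prior over the network weights and NoiseLoss injects Gaussian
    # gradient noise; together they turn the plain optimizer updates into (approximate) SGHMC-style
    # posterior sampling over the generator and discriminator weights.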

    # In order to let G condition on a specific noise, we attach the noise to a fixed Tensor
    fixed_noise = torch.FloatTensor(options.batchSize, options.nz, 1,
                                    1).normal_(0, 1).to(device=device)
    inputT = torch.FloatTensor(options.batchSize, 3, options.imageSize,
                               options.imageSize).to(device=device)
    noiseT = torch.FloatTensor(options.batchSize, options.nz, 1,
                               1).to(device=device)
    labelT = torch.FloatTensor(options.batchSize).to(device=device)
    real_label = 1
    fake_label = 0

    # 13. Transfer all the tensors and modules to GPU if applicable
    # for netD in Discriminator_chains:
    #     netD.to(device=device)
    netD.to(device=device)

    for netG in Generator_chains:
        netG.to(device=device)
    criterion.to(device=device)
    all_criterion.to(device=device)

    # ========================
    # === Training Process ===
    # ========================

    # Lists to keep track of progress
    img_list = []
    G_losses = []
    D_losses = []
    stats = []
    iters = 0

    try:
        print("\nStarting Training Loop...\n")
        for epoch in range(options.niter):
            top1 = Metrics()
            for i, data in enumerate(train_data_loader, 0):
                # ##################
                # Train with real
                # ##################
                netD.zero_grad()
                real_cpu = data[0].to(device)
                batch_size = real_cpu.size(0)
                # label = torch.full((batch_size,), real_label, device=device)

                inputT.resize_as_(real_cpu).copy_(real_cpu)
                labelT.resize_(batch_size).fill_(real_label)

                inputv = torch.autograd.Variable(inputT)
                labelv = torch.autograd.Variable(labelT)

                output = netD(inputv)
                errD_real = all_criterion(output)
                errD_real.backward()
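                # D(x) is reported as 1 - p(class 0 | x), i.e. the probability mass assigned to the real classes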
                D_x = 1 - torch.nn.functional.softmax(
                    output, dim=1).data[:, 0].mean().item()

                # ##################
                # Train with fake
                # ##################
                fake_images = []
                for i_z in range(nsetz):
                    noiseT.resize_(batch_size, nz, 1, 1).normal_(
                        0, 1)  # prior, sample from N(0, 1) distribution
                    noisev = torch.autograd.Variable(noiseT)
                    for m in range(M):
                        idx = i_z * M + m
                        netG = Generator_chains[idx]
                        _fake = netG(noisev)
                        fake_images.append(_fake)
                # output = torch.stack(fake_images)
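                # Concatenate the fakes from all nsetz * M generators into one batch along the batch dimension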
                fake = torch.cat(fake_images)
                output = netD(fake.detach())

                labelv = torch.autograd.Variable(
                    torch.LongTensor(fake.data.shape[0]).to(
                        device=device).fill_(fake_label))
                errD_fake = criterion(output, labelv)
                errD_fake.backward()
                D_G_z1 = 1 - torch.nn.functional.softmax(
                    output, dim=1).data[:, 0].mean().item()

                # ##################
                # Semi-supervised learning
                # ##################
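                # Draw one labelled batch per discriminator step for the supervised cross-entropy term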
                input_sup, target_sup = next(iter(semi_data_loader))
                input_sup, target_sup = input_sup.to(device=device), target_sup.to(device=device)
                input_sup_v = input_sup.to(device=device)
                target_sup_v = (target_sup + 1).to(device=device)
                output_sup = netD(input_sup_v)
                err_sup = criterion(output_sup, target_sup_v)
                err_sup.backward()
                pred1 = accuracy(output_sup.data, target_sup + 1,
                                 topk=(1, ))[0]
                top1.update(value=pred1.item(), N=input_sup.size(0))

                errD_prior = dprior(netD.parameters())
                errD_prior.backward()
                errD_noise = dnoise(netD.parameters())
                errD_noise.backward()
                errD = errD_real + errD_fake + err_sup + errD_prior + errD_noise
                optimizerD.step()

                # ##################
                # Sample and construct generator(s)
                # ##################
                for netG in Generator_chains:
                    netG.zero_grad()
                labelv = torch.autograd.Variable(
                    torch.FloatTensor(fake.data.shape[0]).to(
                        device=device).fill_(real_label))
                output = netD(fake)
                errG = all_criterion(output)

                for netG in Generator_chains:
                    errG = errG + gprior(netG.parameters())
                    errG = errG + gnoise(netG.parameters())
                errG.backward()
                D_G_z2 = 1 - torch.nn.functional.softmax(
                    output, dim=1).data[:, 0].mean().item()
                for optimizerG in optimizerG_chains:
                    optimizerG.step()

                # ##################
                # Evaluate testing accuracy
                # ##################
                # Pause and compute the test accuracy every 10 * notefreq iterations
                if iters % (10 * int(options.notefreq)) == 0:
                    # get test accuracy on train and test
                    netD.eval()
                    compute_test_accuracy(discriminator=netD,
                                          testing_data_loader=test_data_loader,
                                          device=device)
                    netD.train()

                # ##################
                # Note down
                # ##################
                # Report status for the current iteration
                training_status = f"[{epoch}/{options.niter}][{i}/{len(train_data_loader)}] Loss_D: {errD.item():.4f} " \
                                  f"Loss_G: " \
                                  f"{errG.item():.4f} D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f}/{D_G_z2:.4f}" \
                                  f" | Acc {top1.value:.1f} / {top1.mean:.1f}"
                print(training_status)

                # Save samples to disk
                if i % int(options.notefreq) == 0:
                    vutils.save_image(
                        real_cpu,
                        f"{options.outf}/real_samples_epoch_{epoch:03d}_{i}.png",
                        normalize=True)
                    for _iz in range(nsetz):
                        for _m in range(M):
                            gidx = _iz * M + _m
                            netG = Generator_chains[gidx]
                            fake = netG(fixed_noise)
                            vutils.save_image(
                                fake.detach(),
                                f"{options.outf}/fake_samples_epoch_{epoch:03d}_{i}_z{_iz}_m{_m}.png",
                                normalize=True)

                    # Save Losses statistics for post-mortem
                    G_losses.append(errG.item())
                    D_losses.append(errD.item())
                    stats.append(training_status)

                    # # Check how the generator is doing by saving G's output on fixed_noise
                    # if (iters % 500 == 0) or ((epoch == options.niter - 1) and (i == len(data_loader) - 1)):
                    #     with torch.no_grad():
                    #         fake = netG(fixed_noise).detach().cpu()
                    #     img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

                iters += 1
            # TODO: find an elegant way to support saving checkpoints in Bayesian GAN context
    except Exception as e:
        print(e)

        # save training stats no matter what kind of errors occur in the processes
        _save_stats(statistic=G_losses, save_name='G_losses', options=options)
        _save_stats(statistic=D_losses, save_name='D_losses', options=options)
        _save_stats(statistic=stats,
                    save_name='Training_stats',
                    options=options)
Esempio n. 19
0
def save_image_grid(imgs, name):  # NOTE: the original header was cut off in the source; name and signature are reconstructed
    # Assumes `imgs` holds up to 100 square images of side SZ, tiled 10 per row (numpy assumed imported as np)
    m = np.zeros((imgs[0].shape[0], 10 * SZ, 10 * SZ), dtype=imgs[0].dtype)
    for i in range(len(imgs)):
        x = i // 10
        y = i % 10
        m[:, x * SZ:(x + 1) * SZ, y * SZ:(y + 1) * SZ] = imgs[i]
    m = m.transpose((1, 2, 0))
    plt.imsave(name, m)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = loader.load()
dataloader = DataLoader(dataset, batch_size, shuffle=True, drop_last=True)

generator = Generator(latent_dim, img_dim)
discriminator = Discriminator(img_dim)
generator.to(device)
discriminator.to(device)
g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

num_epochs = 20
rd = torch.distributions.Normal(0, 1)
for epoch in range(num_epochs):
    generator.train()
    for it, data in enumerate(dataloader):
        img0 = data
        # print(img0.shape)

        latent = rd.sample((img0.shape[0], latent_dim)).to(device)
        img1 = generator(latent)
        img0 = img0.to(device)
        output0 = discriminator(img0)