Example #1
0
    def __init__(self, hparams):
        """Build the HandGAN model: generators, discriminators and criterions."""
        super().__init__()

        print("Initializing our HandGAN model...")

        # When restored from a checkpoint the hyperparameters can arrive as a
        # plain dict (save_hyperparameters() not working — see
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/3998);
        # normalize them back into a Namespace.
        if isinstance(hparams, dict):
            hparams = Namespace(**hparams)

        #self.save_hyperparameters()
        self.hparams = hparams

        # Weight initializer shared by all four networks.
        initializer = mutils.Initializer(init_type=hparams.init_type,
                                         init_gain=hparams.init_gain)

        # Generators: g_ab translates domain A -> B, g_ba translates B -> A.
        self.g_ab = initializer(generator(hparams))
        self.g_ba = initializer(generator(hparams))

        # One discriminator per domain.
        self.d_a = initializer(discriminator(hparams))
        self.d_b = initializer(discriminator(hparams))

        # The perceptual discriminator needs a VGG16 feature extractor.
        if hparams.netD == 'perceptual':
            self.vgg_net = Vgg16().eval()

        # Validation computes the FID metric, which needs an Inception network.
        if hparams.valid_interval > 0.0:
            self.inception_net = InceptionV3(
                [InceptionV3.BLOCK_INDEX_BY_DIM[2048]]).eval()

        # GANerated-hands mode: a frozen pretrained SilNet provides the
        # geometric-consistency signal.
        if hparams.ganerated:
            silnet_model = SilNet.load_from_checkpoint(
                "./weights/silnet_pretrained.ckpt")
            mutils.set_requires_grad(silnet_model, requires_grad=False)
            self.silnet = silnet_model.unet.eval()

        # Pools of previously generated images, sampled when updating the
        # discriminators.
        self.fake_a_pool = mutils.ImagePool(hparams.pool_size)
        self.fake_b_pool = mutils.ImagePool(hparams.pool_size)

        # Loss criterions.
        self.crit_cycle = torch.nn.L1Loss()
        self.crit_discr = DiscriminatorLoss('lsgan')

        if hparams.lambda_idt > 0.0:
            self.crit_idt = torch.nn.L1Loss()

        if hparams.ganerated:
            self.crit_geom = torch.nn.BCEWithLogitsLoss()
    def __init__(self, vis_screen, save_path, l1_coef, l2_coef,
                 pre_trained_gen, pre_trained_disc, batch_size, num_workers,
                 epochs, inference, softmax_coef, image_size, lr_D, lr_G,
                 audio_seconds):
        """Assemble the trainer: networks, data loader, optimizers, logger."""
        # Both networks are built directly on the GPU. The generator consumes
        # audio_seconds * 16000 samples (16 kHz sampling assumed).
        self.generator = generator(image_size, audio_seconds * 16000).cuda()
        self.discriminator = discriminator(image_size).cuda()

        # A truthy checkpoint path restores learned parameters; otherwise the
        # weights are randomly initialized.
        if pre_trained_disc:
            self.discriminator.load_state_dict(torch.load(pre_trained_disc))
        else:
            self.discriminator.apply(Utils.weights_init)

        if pre_trained_gen:
            self.generator.load_state_dict(torch.load(pre_trained_gen))
        else:
            self.generator.apply(Utils.weights_init)

        # Plain hyper-parameter bookkeeping.
        self.inference = inference
        self.image_size = image_size
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.beta1 = 0.5  # Adam beta1, the usual GAN setting
        self.num_epochs = epochs
        self.l1_coef = l1_coef
        self.l2_coef = l2_coef
        self.softmax_coef = softmax_coef
        self.lr_D = lr_D
        self.lr_G = lr_G

        # Dataset and its loader.
        self.dataset = dataset_builder(transform=Rescale(int(self.image_size)),
                                       inference=self.inference,
                                       audio_seconds=audio_seconds)
        self.data_loader = DataLoader(self.dataset,
                                      batch_size=self.batch_size,
                                      shuffle=True,
                                      num_workers=self.num_workers)

        # Optimizers over the trainable (requires_grad) parameters only.
        self.optimD = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.discriminator.parameters()),
            lr=self.lr_D,
            betas=(self.beta1, 0.999))
        self.optimG = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.generator.parameters()),
            lr=self.lr_G,
            betas=(self.beta1, 0.999))

        # Logging and checkpoint locations.
        self.logger = Logger(vis_screen, save_path)
        self.checkpoints_path = 'checkpoints'
        self.save_path = save_path
Example #3
0
    def __init__(self, vis_screen, save_path, l1_coef, l2_coef,
                 pre_trained_gen, pre_trained_disc, batch_size, num_workers,
                 epochs, inference, softmax_coef, image_size, lr_D, lr_G,
                 audio_seconds):
        """Build the GAN trainer: networks, dataset, optimizers and logger.

        Args:
            vis_screen: visualization screen name forwarded to the Logger.
            save_path: directory forwarded to the Logger and kept for saving.
            l1_coef, l2_coef, softmax_coef: loss-term coefficients (stored only).
            pre_trained_gen / pre_trained_disc: checkpoint paths; falsy values
                mean "initialize the weights randomly instead".
            batch_size, num_workers: DataLoader settings.
            epochs: number of training epochs (stored as num_epochs).
            inference: flag forwarded to the dataset builder.
            image_size: image side length for both networks and Rescale.
            lr_D, lr_G: discriminator / generator learning rates.
            audio_seconds: audio clip length; the generator input size is
                audio_seconds * 16000 samples (presumably 16 kHz — TODO confirm).
        """
        # Networks are built directly on the GPU.
        self.generator = generator(image_size, audio_seconds * 16000).cuda()
        self.discriminator = discriminator(image_size).cuda()

        # Restore pretrained weights when a checkpoint path is given,
        # otherwise apply random initialization.
        if pre_trained_disc:
            self.discriminator.load_state_dict(torch.load(pre_trained_disc))
        else:
            self.discriminator.apply(Utils.weights_init)

        if pre_trained_gen:
            self.generator.load_state_dict(torch.load(pre_trained_gen))
        else:
            self.generator.apply(Utils.weights_init)

        # Plain hyper-parameter bookkeeping.
        self.inference = inference
        self.image_size = image_size
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.beta1 = 0.5  # Adam beta1
        self.num_epochs = epochs
        self.l1_coef = l1_coef
        self.l2_coef = l2_coef
        self.softmax_coef = softmax_coef
        self.lr_D = lr_D
        self.lr_G = lr_G

        # Dataset and loader.
        self.dataset = dataset_builder(transform=Rescale(int(self.image_size)),
                                       inference=self.inference,
                                       audio_seconds=audio_seconds)
        self.data_loader = DataLoader(self.dataset,
                                      batch_size=self.batch_size,
                                      shuffle=True,
                                      num_workers=self.num_workers)

        # Optimizers over the trainable (requires_grad) parameters only.
        self.optimD = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              self.discriminator.parameters()),
                                       lr=self.lr_D,
                                       betas=(self.beta1, 0.999))
        self.optimG = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              self.generator.parameters()),
                                       lr=self.lr_G,
                                       betas=(self.beta1, 0.999))

        # Logging and checkpoint locations.
        self.logger = Logger(vis_screen, save_path)
        self.checkpoints_path = 'checkpoints'
        self.save_path = save_path
def dis_trainer(query_seqs,
                response_seqs,
                labels,
                discriminator,
                dis_optim,
                USE_CUDA=True,
                loss_fn=None):
    """Run one discriminator training step and return (loss, num_corrects).

    Args:
        query_seqs:    (batch_size, max_seq_len>=dis_opts.conv_padding_len)
        response_seqs: (batch_size, max_seq_len>=dis_opts.conv_padding_len)
        labels:        (batch,)
        discriminator: model mapping (query_seqs, response_seqs) to
                       per-class scores of shape (batch_size, num_classes)
        dis_optim:     optimizer over the discriminator's parameters
        USE_CUDA:      move the batch to the GPU before the forward pass
        loss_fn:       optional criterion; when None, falls back to the
                       module-level ``criterion`` (backward compatible)

    Returns:
        (loss_value, num_corrects): Python float of the batch loss and the
        number of correctly classified examples.
    """
    batch_size = query_seqs.size(0)
    assert batch_size == response_seqs.size(0)
    assert batch_size == labels.size(0)

    # Backward-compatible default: the module-level criterion, as before.
    if loss_fn is None:
        loss_fn = criterion

    # Training mode (enable dropout)
    discriminator.train()

    if USE_CUDA:
        query_seqs = query_seqs.cuda()
        response_seqs = response_seqs.cuda()
        labels = labels.cuda()

    # pred: (batch_size, num_classes) scores
    pred = discriminator(query_seqs, response_seqs)

    loss = loss_fn(pred, labels)
    # argmax over the class dimension vs. the gold labels
    num_corrects = (pred.max(1)[1] == labels).sum().item()

    dis_optim.zero_grad()
    loss.backward()
    dis_optim.step()

    # Drop references to the batch and release cached GPU memory
    # (a no-op when CUDA was never initialized).
    del query_seqs, response_seqs, labels, pred
    torch.cuda.empty_cache()

    return loss.item(), num_corrects
# Optional mean image used for input normalization; args.mean_img is the
# path of a .npy file, or None to skip mean subtraction.
if args.mean_img is None:
    mean_img = None
else:
    mean_img = np.load(args.mean_img)
train_info = get_dataset(args.train_path, norm=args.norm, mean_img=mean_img)
test_info = get_dataset(args.test_path, norm=args.norm, mean_img=mean_img)
# Materialize the complete datasets by slicing every item at stride 1.
train = train_info.__getitem__(slice(0, train_info.__len__(), 1))
test = test_info.__getitem__(slice(0, test_info.__len__(), 1))

# Hyper-parameters (e.g. 'dr', 'bn', 'lr', 'rate') and the network
# architecture are read from two JSON files.
with open(args.param_path) as fp:
    param = json.load(fp)
with open(args.arch_path) as fp:
    config = json.load(fp)

# Discriminator network wrapped into the classifier model to be trained.
predictor = discriminator(config, dr=param['dr'], bn=param['bn'])
model = custom_classifier(predictor=predictor)

# The literal string 'None' from the CLI means "no model name".
if args.mname == 'None':
    args.mname = None

trainer = train_loop()
best_score = trainer(model,
                     train,
                     test,
                     args.out_dir,
                     optname=args.optname,
                     lr=param['lr'],
                     rate=param['rate'],
                     lr_attr=args.lr_attr,
                     gpu=args.gpu,
Example #6
0
def train(params):
    """Train the GAN described by ``params``.

    Alternates one discriminator step (real + fake batches) with one
    generator step per epoch, optionally plotting losses and sample images
    to visdom, and periodically saving both networks' state dicts.

    Args:
        params: configuration object with ``visdom``, ``train``, ``network``,
            ``restore``, ``dataset`` and ``save`` sections (see usages below).
    """
    if params.visdom.use_visdom:
        import visdom
        vis = visdom.Visdom(server=params.visdom.server, port=params.visdom.port)
        # Image windows for generated and ground-truth samples.
        generated = vis.images(
            np.ones((params.train.batch_size, 3, params.visdom.image_size, params.visdom.image_size)),
            opts=dict(title='Generated')
        )

        gt = vis.images(
            np.ones((params.train.batch_size, 3, params.visdom.image_size, params.visdom.image_size)),
            opts=dict(title='Original')
        )

        # initialize visdom loss plots
        g_loss = vis.line(
            X=torch.zeros((1,)).cpu(),
            Y=torch.zeros((1,)).cpu(),
            opts=dict(
                xlabel='Iteration',
                ylabel='Loss',
                title='G Training Loss',
            )
        )
        d_loss = vis.line(
            X=torch.zeros((1,)).cpu(),
            Y=torch.zeros((1,)).cpu(),
            opts=dict(
                xlabel='Iteration',
                ylabel='Loss',
                title='D Training Loss',
            )
        )

    G = Generator(params.network.generator)
    D = discriminator()
    G.apply(weights_init)
    D.apply(weights_init)

    # Optionally resume from checkpoints.
    if params.restore.G:
        G.load_state_dict(torch.load(params.restore.G))

    if params.restore.D:
        D.load_state_dict(torch.load(params.restore.D))

    d_steps = 1
    g_steps = 1

    criterion = torch.nn.BCELoss()
    d_optimizer = torch.optim.Adam(D.parameters(), lr=params.train.d_learning_rate)
    g_optimizer = torch.optim.Adam(G.parameters(), lr=params.train.g_learning_rate)

    dataset = CatsDataset(params.dataset)
    train_loader = data_utils.DataLoader(
        dataset, batch_size=params.train.batch_size, shuffle=True,
        num_workers=0)

    # Constant target labels for the BCE criterion.
    t_ones = Variable(torch.ones(params.train.batch_size))
    t_zeros = Variable(torch.zeros(params.train.batch_size))

    if params.train.use_cuda:
        t_ones = t_ones.cuda()
        t_zeros = t_zeros.cuda()
        G = G.cuda()
        D = D.cuda()

    for epoch in range(params.train.epochs):
        # Unfreeze D for its update phase.
        for p in D.parameters():
            p.requires_grad = True

        for d_index in range(d_steps):
            # 1. Train D on real+fake
            D.zero_grad()

            #  1A: Train D on real
            d_real_data = Variable(next(iter(train_loader)))

            if params.train.use_cuda:
                d_real_data = d_real_data.cuda()

            d_real_decision = D(d_real_data)
            d_real_error = criterion(d_real_decision, t_ones)  # ones = true
            # Accumulate gradients only; the optimizer steps ONCE below,
            # after the fake batch. (The original stepped here as well,
            # which applied the real-batch gradients twice because
            # zero_grad() is not called in between.)
            d_real_error.backward()

            #  1B: Train D on fake
            d_gen_input = \
                Variable(torch.FloatTensor(
                    params.train.batch_size, params.network.generator.z_size,
                    1, 1
                ).normal_(0, 1))

            if params.train.use_cuda:
                d_gen_input = d_gen_input.cuda()

            d_fake_input = G(d_gen_input)

            d_fake_decision = D(d_fake_input)
            d_fake_error = criterion(d_fake_decision, t_zeros)  # zeros = fake
            d_fake_error.backward()

            loss = d_fake_error + d_real_error
            np_loss = loss.cpu().data.numpy()

            # Single optimizer step using the accumulated gradients from
            # both the real and the fake batch.
            d_optimizer.step()

        # Freeze D while training G so only G's parameters get gradients.
        for p in D.parameters():
            p.requires_grad = False

        for g_index in range(g_steps):
            # 2. Train G on D's response (but DO NOT train D on these labels)
            G.zero_grad()

            gen_input = \
                Variable(torch.FloatTensor(
                    params.train.batch_size, params.network.generator.z_size,
                    1, 1
                ).normal_(0, 1))

            if params.train.use_cuda:
                gen_input = gen_input.cuda()

            dg_fake_decision = D(G(gen_input))
            # we want to fool, so pretend it's all genuine
            g_error = criterion(dg_fake_decision, t_ones)

            g_error.backward()
            g_optimizer.step()  # Only optimizes G's parameters

        if epoch % 10 == 0 and params.visdom.use_visdom:
            vis.line(
                X=np.array([epoch]),
                Y=np.array([np_loss]),
                win=d_loss,
                update='append'
            )

            # De-normalize from [-1, 1] back to [0, 255] for display.
            vis.images(
                255 * (d_fake_input.cpu().data.numpy() / 2.0 + 0.5),
                win=generated
            )

            vis.images(
                255 * (d_real_data.cpu().data.numpy() / 2.0 + 0.5),
                win=gt
            )

            vis.line(
                X=np.array([epoch]),
                Y=np.array([g_error.cpu().data.numpy()]),
                win=g_loss,
                update='append'
            )

        # .item() replaces the removed 0-dim indexing `.data[0]`, which
        # raises an IndexError on PyTorch >= 0.5.
        print('Epoch: {0}, D: {1}, G: {2}'.format(
            epoch, d_fake_error.item(), g_error.item()))

        if epoch % params.save.every == 0:
            torch.save(D.state_dict(), params.save.D)
            torch.save(G.state_dict(), params.save.G)