Exemplo n.º 1
0
def main():
    """Train IPCGANs on the CACD dataset.

    Relies on module-level ``args`` (training options), ``logger`` and
    ``writer`` (exposing ``add_histogram`` / ``add_scalars``) defined
    elsewhere in the script.

    For every training batch this runs ``args.d_iter`` discriminator updates
    followed by ``args.g_iter`` generator updates, logs parameter histograms
    and the latest losses, and:

    * every ``args.save_interval`` steps writes a checkpoint to
      ``args.saved_model_folder`` and to ``model_last.pth`` (the rolling
      checkpoint the generator is resumed from at start-up);
    * every ``args.val_interval`` steps runs the test set through the
      generator for each of ``args.age_groups`` age groups and saves the
      images under ``args.saved_validation_folder``.
    """
    logger.info("Start to train:\n arguments: %s", args)

    # step3: define transforms
    transforms = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),
         Img_to_zero_center()])
    label_transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
    ])

    # step4: define train/test dataloaders
    train_dataset = CACD("train", transforms, label_transforms)
    test_dataset = CACD("test", transforms, label_transforms)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False)

    # step5: define model and optimizers
    model = IPCGANs(lr=args.learning_rate,
                    age_classifier_path=args.age_classifier_path,
                    gan_loss_weight=args.gan_loss_weight,
                    feature_loss_weight=args.feature_loss_weight,
                    age_loss_weight=args.age_loss_weight)
    # Resume the generator from the rolling checkpoint written during
    # training. On a fresh run that file does not exist yet, so train from
    # scratch instead of crashing before the first step.
    last_checkpoint = "model_last.pth"
    if os.path.exists(last_checkpoint):
        model.load_generator_state_dict_custom(model_path=last_checkpoint)
    else:
        logger.warning("%s not found; training the generator from scratch",
                       last_checkpoint)
    d_optim = model.d_optim
    g_optim = model.g_optim

    steps_per_epoch = len(train_loader)
    format_str = 'step %d/%d, g_loss = %.3f, d_loss = %.3f'

    for epoch in range(args.max_epoches):
        # A new progress bar / enumerate iterator is required every epoch:
        # the previous one is exhausted after a single pass over the loader,
        # and reusing it would silently skip every epoch after the first.
        samples_tqdm = tqdm(enumerate(train_loader, 1),
                            total=steps_per_epoch,
                            position=0,
                            leave=True)
        for idx, (source_img_227, source_img_128, true_label_img,
                  true_label_128, true_label_64, fake_label_64,
                  true_label) in samples_tqdm:
            running_d_loss = None
            running_g_loss = None
            n_iter = epoch * steps_per_epoch + idx

            # Move the batch to the GPU once; both D and G steps reuse it.
            step_inputs = {
                "source_img_227": source_img_227.cuda(),
                "source_img_128": source_img_128.cuda(),
                "true_label_img": true_label_img.cuda(),
                "true_label_128": true_label_128.cuda(),
                "true_label_64": true_label_64.cuda(),
                "fake_label_64": fake_label_64.cuda(),
                "age_label": true_label.cuda(),
            }

            # train discriminator
            for _ in range(args.d_iter):
                d_optim.zero_grad()
                model.train(**step_inputs)
                d_loss = model.d_loss
                # .item() keeps only the number for logging, not the graph.
                running_d_loss = d_loss.item()
                d_loss.backward()
                d_optim.step()

            # visualize discriminator params
            for name, param in model.discriminator.named_parameters():
                writer.add_histogram("discriminator:%s" % name,
                                     param.detach().cpu().numpy(),
                                     n_iter)

            # train generator
            for _ in range(args.g_iter):
                g_optim.zero_grad()
                model.train(**step_inputs)
                g_loss = model.g_loss
                running_g_loss = g_loss.item()
                g_loss.backward()
                g_optim.step()

            # visualize generator params
            for name, param in model.generator.named_parameters():
                writer.add_histogram("generator:%s" % name,
                                     param.detach().cpu().numpy(),
                                     n_iter)

            samples_tqdm.set_description(
                format_str %
                (idx, steps_per_epoch, running_g_loss, running_d_loss))

            writer.add_scalars('data/loss', {
                'G_loss': running_g_loss,
                'D_loss': running_d_loss
            }, n_iter)

            # save the parameters at the end of each save interval
            if idx % args.save_interval == 0:
                model.save_model(dir=args.saved_model_folder,
                                 filename='epoch_%d_iter_%d.pth' %
                                 (epoch, idx))
                model.save_model(dir="", filename=last_checkpoint)
                logger.info('checkpoint has been created!')

            # val step
            if idx % args.val_interval == 0:
                save_dir = os.path.join(args.saved_validation_folder,
                                        "epoch_%d" % epoch, "idx_%d" % idx)
                check_dir(save_dir)
                # Inference only: no autograd graph is needed here.
                with torch.no_grad():
                    for val_idx, (val_source_img, val_label_128) in enumerate(
                            tqdm(test_loader)):
                        # The path is passed positionally because the keyword
                        # name differs between torchvision releases
                        # (`filename` in old ones, `fp` in newer ones).
                        save_image(
                            Reverse_zero_center()(val_source_img),
                            os.path.join(save_dir,
                                         "batch_%d_source.jpg" % val_idx))
                        for age in range(args.age_groups):
                            img = model.test_generate(val_source_img,
                                                      val_label_128[age])
                            save_image(
                                Reverse_zero_center()(img),
                                os.path.join(
                                    save_dir, "batch_%d_age_group_%d.jpg" %
                                    (val_idx, age)))
                logger.info('validation image has been created!')
Exemplo n.º 2
0
def _notify_webhook(content, logger, timeout=10):
    """Best-effort POST of ``{"text": content}`` as JSON to ``webhook_url``.

    ``webhook_url`` is a module-level setting defined elsewhere in the
    script. Network errors, timeouts and HTTP error statuses are logged as
    warnings and swallowed, so a notification problem can neither crash nor
    hang a training run.

    Args:
        content: Message text to send.
        logger: Logger used to report a failed notification.
        timeout: Seconds to wait for the webhook before giving up.
    """
    payload = {"text": content}
    try:
        response = requests.post(webhook_url,
                                 data=json.dumps(payload),
                                 headers={'Content-Type': 'application/json'},
                                 timeout=timeout)
        response.raise_for_status()
    except requests.RequestException as err:
        logger.warning("webhook notification failed: %s", err)


def main(args, logger, writer):
    """Train IPCGANs on the CACD dataset.

    For every training batch this runs ``args.d_iter`` discriminator updates
    followed by ``args.g_iter`` generator updates, logs parameter histograms
    and the latest losses, and:

    * every ``args.save_interval`` steps writes a checkpoint to
      ``args.saved_model_folder``;
    * every ``args.val_interval`` steps runs the test set through the
      generator for each of ``args.age_groups`` age groups and saves the
      images under ``args.saved_validation_folder``.

    A webhook notification is sent at start-up and at the end of every epoch
    with that epoch's mean D/G losses.

    Args:
        args: Training options (batch size, learning rate, loss weights,
            checkpoint paths, intervals, output folders, ...).
        logger: Logger for progress messages.
        writer: Summary writer exposing ``add_histogram`` / ``add_scalars``.
    """
    logger.info("Start to train:\n arguments: %s", args)
    _notify_webhook("[202.*.*.150.] [INFO] Start to train - IPCGANs ", logger)

    # step3: define transforms
    transforms = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),
         Img_to_zero_center()])
    label_transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
    ])

    # step4: define train/test dataloaders
    train_dataset = CACD("train", transforms, label_transforms)
    test_dataset = CACD("test", transforms, label_transforms)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False)

    # step5: define model and optimizers
    model = IPCGANs(lr=args.learning_rate,
                    age_classifier_path=args.age_classifier_path,
                    gan_loss_weight=args.gan_loss_weight,
                    feature_loss_weight=args.feature_loss_weight,
                    age_loss_weight=args.age_loss_weight,
                    generator_path=args.g_checkpoint,
                    discriminator_path=args.d_checkpoint)
    d_optim = model.d_optim
    g_optim = model.g_optim

    steps_per_epoch = len(train_loader)
    format_str = 'step %d/%d, g_loss = %.3f, d_loss = %.3f'

    for epoch in range(args.max_epoches):
        # Plain Python floats: summing loss tensors would keep every step's
        # autograd history alive for the whole epoch.
        sum_d_loss = 0.0
        sum_g_loss = 0.0
        count = 0
        for idx, (source_img_227, source_img_128, true_label_img,
                  true_label_128, true_label_64, fake_label_64,
                  true_label) in enumerate(train_loader, 1):
            running_d_loss = None
            running_g_loss = None
            n_iter = epoch * steps_per_epoch + idx

            # Move the batch to the GPU once; both D and G steps reuse it.
            step_inputs = {
                "source_img_227": source_img_227.cuda(),
                "source_img_128": source_img_128.cuda(),
                "true_label_img": true_label_img.cuda(),
                "true_label_128": true_label_128.cuda(),
                "true_label_64": true_label_64.cuda(),
                "fake_label_64": fake_label_64.cuda(),
                "age_label": true_label.cuda(),
            }

            # train discriminator
            for _ in range(args.d_iter):
                d_optim.zero_grad()
                model.train(**step_inputs)
                d_loss = model.d_loss
                running_d_loss = d_loss.item()
                d_loss.backward()
                d_optim.step()

            # visualize discriminator params
            for name, param in model.discriminator.named_parameters():
                writer.add_histogram("discriminator:%s" % name,
                                     param.detach().cpu().numpy(),
                                     n_iter)

            # train generator
            for _ in range(args.g_iter):
                g_optim.zero_grad()
                model.train(**step_inputs)
                g_loss = model.g_loss
                running_g_loss = g_loss.item()
                g_loss.backward()
                g_optim.step()

            # visualize generator params
            for name, param in model.generator.named_parameters():
                writer.add_histogram("generator:%s" % name,
                                     param.detach().cpu().numpy(),
                                     n_iter)

            logger.info(
                format_str %
                (idx, steps_per_epoch, running_g_loss, running_d_loss))

            writer.add_scalars('data/loss', {
                'G_loss': running_g_loss,
                'D_loss': running_d_loss
            }, n_iter)
            sum_g_loss += running_g_loss
            sum_d_loss += running_d_loss
            count += 1

            # save the parameters at the end of each save interval
            if idx % args.save_interval == 0:
                model.save_model(dir=args.saved_model_folder,
                                 filename='epoch_%d_iter_%d.pth' %
                                 (epoch, idx))
                logger.info('checkpoint has been created!')

            # val step
            if idx % args.val_interval == 0:
                save_dir = os.path.join(args.saved_validation_folder,
                                        "epoch_%d" % epoch, "idx_%d" % idx)
                check_dir(save_dir)
                # Inference only: no autograd graph is needed here.
                with torch.no_grad():
                    for val_idx, (val_source_img, val_label_128) in enumerate(
                            tqdm(test_loader)):
                        # The path is passed positionally because the keyword
                        # name differs between torchvision releases
                        # (`filename` in old ones, `fp` in newer ones).
                        save_image(
                            Reverse_zero_center()(val_source_img),
                            os.path.join(save_dir,
                                         "batch_%d_source.jpg" % val_idx))
                        for age in range(args.age_groups):
                            img = model.test_generate(val_source_img,
                                                      val_label_128[age])
                            save_image(
                                Reverse_zero_center()(img),
                                os.path.join(
                                    save_dir, "batch_%d_age_group_%d.jpg" %
                                    (val_idx, age)))
                logger.info('validation image has been created!')

        # max(count, 1) guards against an empty loader (division by zero).
        avr_d_loss = sum_d_loss / max(count, 1)
        avr_g_loss = sum_g_loss / max(count, 1)

        content = "[202.*.*.150.]  [INFO] Epoch End : " + str(
            epoch) + ", d_loss : " + str(avr_d_loss) + ", g_loss : " + str(
                avr_g_loss)
        _notify_webhook(content, logger)