示例#1
0
# step1 (tail): command-line options for this run
parser.add_argument('--cuda_device', type=str, default='0',
                    help='which device to use')
parser.add_argument('--checkpoint', type=str,
                    default='./checkpoint/pretrain_alexnet',
                    help='logs and checkpoints directory')
parser.add_argument('--saved_model_folder', type=str,
                    default='./checkpoint/pretrain_alexnet/saved_parameters/',
                    help='the path of folder which stores the parameters file')
args = parser.parse_args()

# Restrict the GPUs visible to CUDA (only effective if CUDA has not been
# initialised yet in this process).
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_device

# check_dir comes from elsewhere in the project; presumably it creates the
# directory when missing -- confirm against its definition.
check_dir(args.checkpoint)
check_dir(args.saved_model_folder)

# step2: log to <checkpoint>/log.txt (truncated each run, mode "w") and to the
# console, both with the same timestamp/level/message layout.
file_handler = logging.FileHandler(join(args.checkpoint, 'log.txt'), "w")
file_handler.setFormatter(
    logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
stdout_handler = logging.StreamHandler()
stdout_handler.setFormatter(
    logging.Formatter('%(asctime)s %(levelname)s %(message)s'))

logger = logging.getLogger("Age classifer")
logger.setLevel(logging.INFO)
logger.addHandler(file_handler)
logger.addHandler(stdout_handler)

示例#2
0
def main():
    """Train IPCGANs on CACD with periodic checkpointing and validation.

    Reads the module-level ``args`` (parsed command-line options), ``logger``
    and ``writer`` (tensorboard writer) instead of taking parameters.
    Generator weights are resumed from ``model_last.pth`` before training,
    and that file is rewritten at every save interval.
    """
    logger.info("Start to train:\n arguments: %s" % str(args))
    # step3: define transform
    transforms = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),
         Img_to_zero_center()])
    label_transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
    ])
    # step4: define train/test dataloader
    train_dataset = CACD("train", transforms, label_transforms)
    test_dataset = CACD("test", transforms, label_transforms)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False)

    # step5: define model, optim
    model = IPCGANs(lr=args.learning_rate,
                    age_classifier_path=args.age_classifier_path,
                    gan_loss_weight=args.gan_loss_weight,
                    feature_loss_weight=args.feature_loss_weight,
                    age_loss_weight=args.age_loss_weight)
    # Resume the generator from the rolling checkpoint written further below.
    model.load_generator_state_dict_custom(model_path="model_last.pth")
    d_optim = model.d_optim
    g_optim = model.g_optim

    for epoch in range(args.max_epoches):
        # tqdm(enumerate(...)) is a one-shot iterator: it must be rebuilt every
        # epoch. Creating it once before this loop leaves it exhausted after
        # epoch 0, so every later epoch would silently train on nothing.
        samples_tqdm = tqdm(enumerate(train_loader, 1), position=0, leave=True)
        for idx, (source_img_227, source_img_128, true_label_img,
                  true_label_128, true_label_64, fake_label_64,
                  true_label) in samples_tqdm:

            running_d_loss = None
            running_g_loss = None
            # global step used as the tensorboard x-axis
            n_iter = epoch * len(train_loader) + idx

            # mv to gpu
            source_img_227 = source_img_227.cuda()
            source_img_128 = source_img_128.cuda()
            true_label_img = true_label_img.cuda()
            true_label_128 = true_label_128.cuda()
            true_label_64 = true_label_64.cuda()
            fake_label_64 = fake_label_64.cuda()
            true_label = true_label.cuda()

            # train discriminator
            for _ in range(args.d_iter):
                d_optim.zero_grad()
                model.train(source_img_227=source_img_227,
                            source_img_128=source_img_128,
                            true_label_img=true_label_img,
                            true_label_128=true_label_128,
                            true_label_64=true_label_64,
                            fake_label_64=fake_label_64,
                            age_label=true_label)
                d_loss = model.d_loss
                d_loss.backward()
                d_optim.step()
                # Keep a plain float for logging rather than the loss tensor
                # (which still references its autograd graph).
                running_d_loss = d_loss.item()

            # visualize params
            for name, param in model.discriminator.named_parameters():
                writer.add_histogram("discriminator:%s" % name,
                                     param.clone().cpu().detach().numpy(),
                                     n_iter)

            # train generator
            for _ in range(args.g_iter):
                g_optim.zero_grad()
                model.train(source_img_227=source_img_227,
                            source_img_128=source_img_128,
                            true_label_img=true_label_img,
                            true_label_128=true_label_128,
                            true_label_64=true_label_64,
                            fake_label_64=fake_label_64,
                            age_label=true_label)
                g_loss = model.g_loss
                g_loss.backward()
                g_optim.step()
                running_g_loss = g_loss.item()

            for name, param in model.generator.named_parameters():
                writer.add_histogram("generator:%s" % name,
                                     param.clone().cpu().detach().numpy(),
                                     n_iter)

            format_str = ('step %d/%d, g_loss = %.3f, d_loss = %.3f')
            samples_tqdm.set_description(
                format_str %
                (idx, len(train_loader), running_g_loss, running_d_loss))

            writer.add_scalars('data/loss', {
                'G_loss': running_g_loss,
                'D_loss': running_d_loss
            }, n_iter)

            # save the parameters at the end of each save interval
            if idx % args.save_interval == 0:
                model.save_model(dir=args.saved_model_folder,
                                 filename='epoch_%d_iter_%d.pth' %
                                 (epoch, idx))
                # rolling copy that the next run resumes from (see above)
                model.save_model(dir="", filename="model_last.pth")
                logger.info('checkpoint has been created!')

            # val step
            if idx % args.val_interval == 0:
                save_dir = os.path.join(args.saved_validation_folder,
                                        "epoch_%d" % epoch, "idx_%d" % idx)
                check_dir(save_dir)
                # Distinct names so the training-batch tensors are not
                # shadowed. Assumes the "test" split yields (source image,
                # per-age-group labels) pairs -- TODO confirm against CACD.
                for val_idx, (val_source_128, val_label_128) in enumerate(
                        tqdm(test_loader)):
                    save_image(Reverse_zero_center()(val_source_128),
                               fp=os.path.join(
                                   save_dir,
                                   "batch_%d_source.jpg" % val_idx))

                    for age in range(args.age_groups):
                        img = model.test_generate(val_source_128,
                                                  val_label_128[age])
                        save_image(Reverse_zero_center()(img),
                                   fp=os.path.join(
                                       save_dir, "batch_%d_age_group_%d.jpg" %
                                       (val_idx, age)))
                logger.info('validation image has been created!')
示例#3
0
def main(args, logger, writer):
    """Train IPCGANs on CACD, with checkpoints, validation and webhook pings.

    Args:
        args: parsed options (batch size, loss weights, intervals, folders,
            ``g_checkpoint``/``d_checkpoint`` to resume from, ...).
        logger: logger used for per-step progress messages.
        writer: tensorboard writer for parameter histograms and losses.

    A "start" message and a per-epoch average-loss summary are posted to the
    module-level ``webhook_url``. Posting is best-effort: failures are logged
    and training continues.
    """

    def _notify(content):
        # Best-effort webhook notification: a slow or unreachable endpoint
        # must neither hang (hence the timeout) nor abort a long training run.
        try:
            requests.post(webhook_url,
                          data=json.dumps({"text": content}),
                          headers={'Content-Type': 'application/json'},
                          timeout=10)
        except requests.RequestException as err:
            logger.warning("webhook notification failed: %s", err)

    logger.info("Start to train:\n arguments: %s" % str(args))
    _notify("[202.*.*.150.] [INFO] Start to train - IPCGANs ")

    # step3: define transform
    transforms = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),
         Img_to_zero_center()])
    label_transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
    ])
    # step4: define train/test dataloader
    train_dataset = CACD("train", transforms, label_transforms)
    test_dataset = CACD("test", transforms, label_transforms)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False)

    # step5: define model, optim (optionally resuming G/D from checkpoints)
    model = IPCGANs(lr=args.learning_rate,
                    age_classifier_path=args.age_classifier_path,
                    gan_loss_weight=args.gan_loss_weight,
                    feature_loss_weight=args.feature_loss_weight,
                    age_loss_weight=args.age_loss_weight,
                    generator_path=args.g_checkpoint,
                    discriminator_path=args.d_checkpoint)
    d_optim = model.d_optim
    g_optim = model.g_optim

    for epoch in range(args.max_epoches):
        # Epoch-level sums of the per-step losses, kept as plain floats.
        avr_d_loss = 0.0
        avr_g_loss = 0.0
        count = 0
        for idx, (source_img_227, source_img_128, true_label_img,
                  true_label_128, true_label_64, fake_label_64,
                  true_label) in enumerate(train_loader, 1):

            running_d_loss = None
            running_g_loss = None
            # global step used as the tensorboard x-axis
            n_iter = epoch * len(train_loader) + idx

            # mv to gpu
            source_img_227 = source_img_227.cuda()
            source_img_128 = source_img_128.cuda()
            true_label_img = true_label_img.cuda()
            true_label_128 = true_label_128.cuda()
            true_label_64 = true_label_64.cuda()
            fake_label_64 = fake_label_64.cuda()
            true_label = true_label.cuda()

            # train discriminator
            for _ in range(args.d_iter):
                d_optim.zero_grad()
                model.train(source_img_227=source_img_227,
                            source_img_128=source_img_128,
                            true_label_img=true_label_img,
                            true_label_128=true_label_128,
                            true_label_64=true_label_64,
                            fake_label_64=fake_label_64,
                            age_label=true_label)
                d_loss = model.d_loss
                d_loss.backward()
                d_optim.step()
                # .item(): summing loss tensors across the epoch would chain
                # their autograd history together and keep it all in memory.
                running_d_loss = d_loss.item()

            # visualize params
            for name, param in model.discriminator.named_parameters():
                writer.add_histogram("discriminator:%s" % name,
                                     param.clone().cpu().detach().numpy(),
                                     n_iter)

            # train generator
            for _ in range(args.g_iter):
                g_optim.zero_grad()
                model.train(source_img_227=source_img_227,
                            source_img_128=source_img_128,
                            true_label_img=true_label_img,
                            true_label_128=true_label_128,
                            true_label_64=true_label_64,
                            fake_label_64=fake_label_64,
                            age_label=true_label)
                g_loss = model.g_loss
                g_loss.backward()
                g_optim.step()
                running_g_loss = g_loss.item()

            for name, param in model.generator.named_parameters():
                writer.add_histogram("generator:%s" % name,
                                     param.clone().cpu().detach().numpy(),
                                     n_iter)

            format_str = ('step %d/%d, g_loss = %.3f, d_loss = %.3f')
            logger.info(
                format_str %
                (idx, len(train_loader), running_g_loss, running_d_loss))

            writer.add_scalars('data/loss', {
                'G_loss': running_g_loss,
                'D_loss': running_d_loss
            }, n_iter)
            avr_g_loss += running_g_loss
            avr_d_loss += running_d_loss
            count += 1

            # save the parameters at the end of each save interval
            if idx % args.save_interval == 0:
                model.save_model(dir=args.saved_model_folder,
                                 filename='epoch_%d_iter_%d.pth' %
                                 (epoch, idx))
                logger.info('checkpoint has been created!')

            # val step
            if idx % args.val_interval == 0:
                save_dir = os.path.join(args.saved_validation_folder,
                                        "epoch_%d" % epoch, "idx_%d" % idx)
                check_dir(save_dir)
                # Distinct names so the training-batch tensors are not
                # shadowed. Assumes the "test" split yields (source image,
                # per-age-group labels) pairs -- TODO confirm against CACD.
                for val_idx, (val_source_128, val_label_128) in enumerate(
                        tqdm(test_loader)):
                    save_image(Reverse_zero_center()(val_source_128),
                               filename=os.path.join(
                                   save_dir,
                                   "batch_%d_source.jpg" % val_idx))

                    for age in range(args.age_groups):
                        img = model.test_generate(val_source_128,
                                                  val_label_128[age])
                        save_image(Reverse_zero_center()(img),
                                   filename=os.path.join(
                                       save_dir, "batch_%d_age_group_%d.jpg" %
                                       (val_idx, age)))
                logger.info('validation image has been created!')

        # Guard against an empty loader (count == 0) instead of dividing by 0.
        if count:
            avr_d_loss = avr_d_loss / count
            avr_g_loss = avr_g_loss / count

        _notify("[202.*.*.150.]  [INFO] Epoch End : " + str(epoch) +
                ", d_loss : " + str(avr_d_loss) + ", g_loss : " +
                str(avr_g_loss))
示例#4
0
                    type=str,
                    help='the path of folder which stores the val img',
                    default='./checkpoint/IPCGANS/%s/validation/' %
                    (TIMESTAMP))
parser.add_argument('--tensorboard_log_folder', type=str,
                    default='./checkpoint/IPCGANS/%s/tensorboard/' % (TIMESTAMP),
                    help='the path of folder which stores the tensorboard log')

args = parser.parse_args()

# define tensorboard: events go to <tensorboard_log_folder>/<TIMESTAMP>
# (with the default folder, TIMESTAMP therefore appears twice in the path).
writer = SummaryWriter(os.path.join(args.tensorboard_log_folder, TIMESTAMP))

# check_dir is defined elsewhere; presumably it creates missing folders.
check_dir(args.checkpoint)
check_dir(args.saved_model_folder)
check_dir(args.saved_validation_folder)

# step2: log to <checkpoint>/log.txt (truncated each run, mode "w") and to the
# console, both with the same timestamp/level/message layout.
file_handler = logging.FileHandler(join(args.checkpoint, 'log.txt'), "w")
file_handler.setFormatter(
    logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
stdout_handler = logging.StreamHandler()
stdout_handler.setFormatter(
    logging.Formatter('%(asctime)s %(levelname)s %(message)s'))

logger = logging.getLogger("IPCGANS Train")
logger.setLevel(logging.INFO)
logger.addHandler(file_handler)
logger.addHandler(stdout_handler)
示例#5
0
def mysetting(max_epochs,
              gan_loss_weight=75,
              learning_rate=1e-4,
              batch_size=32,
              dcheck_point=None,
              gcheck_point=None,
              checkpoint=None,
              feature_loss_weight=0.5e-4,
              age_loss_weight=30,
              argv=None):
    """Build the IPCGANs training options, tensorboard writer and logger.

    The keyword arguments only set the *defaults* of the command-line
    options; values given on the command line (or in ``argv``) take
    precedence.

    Args:
        max_epochs: default for ``--max_epoches``.
        gan_loss_weight, feature_loss_weight, age_loss_weight: defaults for
            the corresponding loss-weight options.
        learning_rate: default for ``--learning_rate`` / ``--lr``.
        batch_size: default for ``--batch_size`` / ``--bs``.
        dcheck_point, gcheck_point: defaults for ``--d_checkpoint`` /
            ``--g_checkpoint``.
        checkpoint: logs directory. When falsy, it defaults to
            ``./checkpoint/IPCGANS/<TIMESTAMP>/saved_parameters/`` (the same
            folder as the saved parameters), and the validation, parameter and
            tensorboard folders go under ``./checkpoint/IPCGANS/<TIMESTAMP>/``.
            Otherwise they go under ``<checkpoint>/<TIMESTAMP>/``.
        argv: argument strings to parse instead of ``sys.argv[1:]``. ``None``
            keeps the command-line behaviour; pass ``[]`` from a notebook,
            whose ``sys.argv`` holds kernel flags argparse would reject.

    Returns:
        ``(args, writer, logger)``: the parsed namespace, a ``SummaryWriter``
        and the ``"IPCGANS Train"`` logger (file + console handlers).
    """
    # step1: define argument
    parser = argparse.ArgumentParser(description='train IPCGANS')
    TIMESTAMP = "{0:%Y-%m-%d_%H-%M-%S}".format(datetime.now())

    if not checkpoint:
        checkpoint = './checkpoint/IPCGANS/%s/saved_parameters/' % (TIMESTAMP)
        val_dir = './checkpoint/IPCGANS/%s/validation/' % (TIMESTAMP)
        save_dir = './checkpoint/IPCGANS/%s/saved_parameters/' % (TIMESTAMP)
        tensorboard_dir = './checkpoint/IPCGANS/%s/tensorboard/' % (TIMESTAMP)
    else:
        val_dir = checkpoint + '/%s/validation/' % (TIMESTAMP)
        save_dir = checkpoint + '/%s/saved_parameters/' % (TIMESTAMP)
        tensorboard_dir = checkpoint + '/%s/tensorboard/' % (TIMESTAMP)

    # Optimizer
    parser.add_argument('--learning_rate',
                        '--lr',
                        type=float,
                        help='learning rate',
                        default=learning_rate)
    parser.add_argument('--batch_size',
                        '--bs',
                        type=int,
                        help='batch size',
                        default=batch_size)
    parser.add_argument('--max_epoches',
                        type=int,
                        help='Number of epoches to run',
                        default=max_epochs)
    parser.add_argument('--val_interval',
                        type=int,
                        help='Number of steps to validate',
                        default=1000)
    parser.add_argument('--save_interval',
                        type=int,
                        help='Number of batches to save model',
                        default=500)

    parser.add_argument('--d_iter',
                        type=int,
                        help='Number of steps for discriminator',
                        default=1)
    parser.add_argument('--g_iter',
                        type=int,
                        help='Number of steps for generator',
                        default=2)
    # Model
    parser.add_argument('--gan_loss_weight',
                        type=float,
                        help='gan_loss_weight',
                        default=gan_loss_weight)
    parser.add_argument('--feature_loss_weight',
                        type=float,
                        help='fea_loss_weight',
                        default=feature_loss_weight)
    parser.add_argument('--age_loss_weight',
                        type=float,
                        help='age_loss_weight',
                        default=age_loss_weight)
    parser.add_argument('--age_groups',
                        type=int,
                        help='the number of different age groups',
                        default=5)
    parser.add_argument(
        '--age_classifier_path',
        type=str,
        help='directory of age classification model',
        default=
        './checkpoint/pretrain_alexnet/saved_parameters/age_cls_epoch_47_iter_0.pth'
    )

    # Data and IO
    parser.add_argument('--d_checkpoint',
                        type=str,
                        help='pretrained checkpoints directory',
                        default=dcheck_point)
    parser.add_argument('--g_checkpoint',
                        type=str,
                        help='pretrained checkpoints directory',
                        default=gcheck_point)
    parser.add_argument('--checkpoint',
                        type=str,
                        help='logs and checkpoints directory',
                        default=checkpoint)
    parser.add_argument(
        '--saved_model_folder',
        type=str,
        help='the path of folder which stores the parameters file',
        default=save_dir)
    parser.add_argument('--saved_validation_folder',
                        type=str,
                        help='the path of folder which stores the val img',
                        default=val_dir)
    parser.add_argument(
        '--tensorboard_log_folder',
        type=str,
        help='the path of folder which stores the tensorboard log',
        default=tensorboard_dir)

    # argv=None makes argparse fall back to sys.argv[1:], as before.
    args = parser.parse_args(argv)

    # define tensorboard
    writer = SummaryWriter(os.path.join(args.tensorboard_log_folder,
                                        TIMESTAMP))

    check_dir(args.checkpoint)
    check_dir(args.saved_model_folder)
    check_dir(args.saved_validation_folder)

    # step2: define logging output
    logger = logging.getLogger("IPCGANS Train")
    # getLogger returns the same object on every call. Without this, handlers
    # attached by an earlier mysetting() call would stay active: every message
    # would be duplicated and the previous log file would stay open.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    file_handler = logging.FileHandler(join(args.checkpoint, 'log.txt'), "a")
    stdout_handler = logging.StreamHandler()
    logger.addHandler(file_handler)
    logger.addHandler(stdout_handler)
    stdout_handler.setFormatter(
        logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
    file_handler.setFormatter(
        logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
    logger.setLevel(logging.INFO)

    return args, writer, logger