def main():
    """Train the inpainting GAN (SA-Net generator + SN discriminator).

    Reads all settings from the module-level ``config`` object, builds the
    train/val datasets and loaders, optionally restores pretrained weights,
    then alternates train/validate for a fixed number of epochs,
    checkpointing after every epoch into ``log_dir``.
    """
    logger_init()
    dataset_type = config.DATASET
    batch_size = config.BATCH_SIZE

    # Dataset setting
    logger.info("Initialize the dataset...")
    train_dataset = InpaintDataset(
        config.DATA_FLIST[dataset_type][0],
        {mask_type: config.DATA_FLIST[config.MASKDATASET][mask_type][0]
         for mask_type in config.MASK_TYPES},
        resize_shape=tuple(config.IMG_SHAPES),
        random_bbox_shape=config.RANDOM_BBOX_SHAPE,
        random_bbox_margin=config.RANDOM_BBOX_MARGIN,
        random_ff_setting=config.RANDOM_FF_SETTING)
    train_loader = train_dataset.loader(batch_size=batch_size,
                                        shuffle=True,
                                        num_workers=16,
                                        pin_memory=True)

    val_dataset = InpaintDataset(
        config.DATA_FLIST[dataset_type][1],
        {mask_type: config.DATA_FLIST[config.MASKDATASET][mask_type][1]
         for mask_type in ('val',)},
        resize_shape=tuple(config.IMG_SHAPES),
        random_bbox_shape=config.RANDOM_BBOX_SHAPE,
        random_bbox_margin=config.RANDOM_BBOX_MARGIN,
        random_ff_setting=config.RANDOM_FF_SETTING)
    val_loader = val_dataset.loader(batch_size=1, shuffle=False, num_workers=1)

    # Collect a fixed set of validation samples ("static view") so every
    # epoch is evaluated against the same images.  Only 3-channel (RGB)
    # samples are kept; others are skipped.
    val_datas = []
    for data in val_loader:
        if len(val_datas) >= config.STATIC_VIEW_SIZE:
            break
        imgs = data[0]
        if imgs.size(1) == 3:
            val_datas.append(data)

    logger.info("Finish the dataset initialization.")

    # Define the Network Structure
    logger.info("Define the Network Structure and Losses")
    netG = InpaintSANet()
    netD = InpaintSADirciminator()

    if config.MODEL_RESTORE != '':
        # Restore pretrained generator/discriminator weights.
        whole_model_path = 'model_logs/{}'.format(config.MODEL_RESTORE)
        nets = torch.load(whole_model_path)
        netG.load_state_dict(nets['netG_state_dict'])
        netD.load_state_dict(nets['netD_state_dict'])
        logger.info("Loading pretrained models from {} ...".format(
            config.MODEL_RESTORE))

    # Define losses and optimizers.  The discriminator runs at a 4x larger
    # learning rate (two-timescale update style).
    recon_loss = ReconLoss(*(config.L1_LOSS_ALPHA))
    gan_loss = SNGenLoss(config.GAN_LOSS_ALPHA)
    dis_loss = SNDisLoss()
    lr, decay = config.LEARNING_RATE, config.WEIGHT_DECAY
    optG = torch.optim.Adam(netG.parameters(), lr=lr, weight_decay=decay)
    optD = torch.optim.Adam(netD.parameters(), lr=4 * lr, weight_decay=decay)

    logger.info("Finish Define the Network Structure and Losses")

    # Start Training
    logger.info("Start Training...")
    epoch = 50

    for i in range(epoch):
        # train data
        train(netG,
              netD,
              gan_loss,
              recon_loss,
              dis_loss,
              optG,
              optD,
              train_loader,
              i,
              device=cuda0,
              val_datas=val_datas)

        # validate on the fixed static-view samples
        validate(netG,
                 netD,
                 gan_loss,
                 recon_loss,
                 dis_loss,
                 optG,
                 optD,
                 val_datas,
                 i,
                 device=cuda0)

        # BUG FIX: the original used netG.to(cpu0).state_dict(), which moves
        # the live models to the CPU *in place* and leaves them there for the
        # next training epoch.  Copy each state-dict tensor to CPU instead,
        # leaving the models on their training device untouched.
        saved_model = {
            'epoch': i + 1,
            'netG_state_dict': {k: v.cpu() for k, v in netG.state_dict().items()},
            'netD_state_dict': {k: v.cpu() for k, v in netD.state_dict().items()},
        }
        torch.save(saved_model,
                   '{}/epoch_{}_ckpt.pth.tar'.format(log_dir, i + 1))
        # Also maintain a rolling "latest" checkpoint.
        torch.save(saved_model, '{}/latest_ckpt.pth.tar'.format(log_dir))
# --- Esempio n. 2 (Example no. 2) -------------------------------------------
# NOTE(review): the two lines here were page-scrape residue (an Italian
# "Example no. 2" heading and a stray "0") that made the file invalid Python;
# they are preserved as comments so the file parses.
def main(args):
    """Train the inpainting GAN from command-line ``args``.

    Builds train/val datasets from the image/mask file lists in ``args``,
    optionally loads pretrained weights, trains for ``args.epochs`` epochs,
    logs per-epoch losses and metrics to CSV files under ``args.logdir``,
    and checkpoints every epoch, keeping a separate copy of the
    best-scoring model.

    NOTE(review): this redefines ``main()`` from the first example above —
    the two examples were scraped into one file.
    """
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)

    # Dataset setting
    train_dataset = InpaintDataset(args.train_image_list,
                                   {'val': args.train_mask_list},
                                   mode='train', img_size=args.img_shape)
    train_loader = train_dataset.loader(batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=4,
                                        pin_memory=True)

    val_dataset = InpaintDataset(args.val_image_list,
                                 {'val': args.val_mask_list},
                                 mode='val', img_size=args.img_shape)
    val_loader = val_dataset.loader(batch_size=1, shuffle=False, num_workers=1)

    # Define the Network Structure
    netG = InpaintSANet()
    netD = InpaintSADirciminator()
    netG.cuda()
    netD.cuda()

    if args.load_weights != '':
        nets = torch.load(args.load_weights)
        # The generator is loaded leniently (only consistent keys) so a
        # checkpoint from a slightly different architecture still works.
        load_consistent_state_dict(nets['netG_state_dict'], netG)
        netD.load_state_dict(nets['netD_state_dict'])

    # Define losses and optimizers; the discriminator gets a 4x learning rate.
    recon_loss = ReconLoss(*([1.2, 1.2, 1.2, 1.2]))
    gan_loss = SNGenLoss(0.005)
    dis_loss = SNDisLoss()
    lr, decay = args.learning_rate, 0.0
    optG = torch.optim.Adam(netG.parameters(), lr=lr, weight_decay=decay)
    optD = torch.optim.Adam(netD.parameters(), lr=4 * lr, weight_decay=decay)

    best_score = 0

    # BUG FIX: the original passed bare open() handles to csv.writer and never
    # closed them (file-handle leak; buffered rows could be lost on an unclean
    # exit).  Manage both files with a context manager that spans the whole
    # training loop.  newline='' is the documented open() mode for csv files.
    with open(os.path.join(args.logdir, 'loss.csv'), 'w', newline='') as loss_file, \
         open(os.path.join(args.logdir, 'acc.csv'), 'w', newline='') as acc_file:
        loss_writer = csv.writer(loss_file, delimiter=',')
        acc_writer = csv.writer(acc_file, delimiter=',')

        # Start Training
        for i in range(args.epochs):
            # train data
            train(netG, netD, gan_loss, recon_loss, dis_loss, optG, optD,
                  train_loader, i + 1, args.img_shape, loss_writer)

            # validate, writing this epoch's result images to output_dir
            output_dir = os.path.join(args.result_dir, str(i + 1))
            mse, ssim = validate(netG, val_loader, args.img_shape, output_dir,
                                 args.gt_dir)
            # Combined score: lower MSE and higher SSIM both raise the score.
            score = 1 - mse / 100 + ssim
            print('MSE: ', mse, '     SSIM:', ssim, '     SCORE:', score)
            acc_writer.writerow([i + 1, mse, ssim, score])

            saved_model = {
                'epoch': i + 1,
                'netG_state_dict': netG.state_dict(),
                'netD_state_dict': netD.state_dict(),
            }
            torch.save(saved_model,
                       '{}/epoch_{}_ckpt.pth.tar'.format(args.logdir, i + 1))
            if score > best_score:
                # New best model so far: keep a dedicated copy.
                torch.save(saved_model,
                           '{}/best_ckpt.pth.tar'.format(args.logdir))
                best_score = score
                print('New best score at epoch', i + 1)