Example #1
    print('<Main> Epoch {}'.format(epoch))

    if param_updater.check(epoch):
        ## update the set of learnable parameters
        transfer = param_updater.refresh(epoch)
        print('Updating params: ', epoch, transfer)
        ## get a new trainer with different learnable parameters
        trainer = pipeline.ComposerTrainer(model,
                                           train_loader,
                                           args.lr,
                                           args.lights_mult,
                                           args.un_mult,
                                           args.lab_mult,
                                           transfer,
                                           epoch_size=args.epoch_size,
                                           iters=args.iters)

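    ## checkpoint the current model weights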
    if args.save_model:
        state = model.state_dict()
        torch.save(state, os.path.join(args.save_path, 'state.t7'))

    ## visualize intrinsic image predictions and reconstructions of the val set
    val_losses = pipeline.visualize_composer(model, val_loader, args.save_path,
                                             epoch)

    ## one sweep through the args.epoch_size images
    train_losses = trainer.train()

    ## save plots of the errors
    logger.update(train_losses, val_losses)

import argparse
import os

import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
# note: SummaryWriter may come from torch.utils.tensorboard instead, depending on the project
from tensorboardX import SummaryWriter

# project-local modules: U_Net defines the networks, pipeline the data and training helpers
import pipeline
import U_Net


def main():
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',
                        type=str,
                        default='F:\\BOLD\\',
                        help='base folder of datasets')
    parser.add_argument('--mode', type=str, nargs=2, default=['val', 'test'])
    parser.add_argument(
        '--save_path',
        type=str,
        default='logs\\composer\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate')
    parser.add_argument('--loader_threads',
                        type=int,
                        default=16,
                        help='number of parallel data-loading threads')
    parser.add_argument('--save_model',
                        type=bool,
                        default=True,
                        help='whether to save model or not')
    parser.add_argument('--train_set_size',
                        type=int,
                        default=10170,
                        help='number of images in an epoch')
    parser.add_argument('--num_epochs', type=int, default=60)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--checkpoint',
                        action='store_true',
                        help='resume from a saved state_dict')
    parser.add_argument('--state_dict', type=str, default='composer_state.t7')
    args = parser.parse_args()

    # pylint: disable=E1101
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # pylint: enable=E1101

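    # build the composer from pretrained shader and reflection sub-networks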
    shader = U_Net.Shader()
    shader.load_state_dict(torch.load('logs/shader/shader_state_59.t7'))
    reflection = U_Net.Reflection()
    reflection.load_state_dict(torch.load('reflection_state.t7'))
    composer = U_Net.Composer(reflection, shader).to(device)

    if args.checkpoint:
        composer.load_state_dict(torch.load(args.state_dict))
        print('loaded checkpoint successfully!')

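    # data loaders for the training and validation sets of the BOLD dataset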
    train_set = pipeline.BOLD_Dataset(args.data_path,
                                      size_per_dataset=args.train_set_size,
                                      mode=args.mode[0])
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.loader_threads,
                                               shuffle=True)

    val_set = pipeline.BOLD_Dataset(args.data_path,
                                    size_per_dataset=20,
                                    mode=args.mode[1])
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=1,
                                             num_workers=args.loader_threads,
                                             shuffle=False)

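    # TensorBoard writer, Adam optimizer, and step-wise LR decay at epochs 20 and 40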
    writer = SummaryWriter(log_dir=args.save_path)
    optimizer = optim.Adam(composer.parameters(), lr=args.lr)
    scheduler = MultiStepLR(optimizer, milestones=[20, 40])

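    # log the model graph to TensorBoard using a dummy batch of three 3x512x512 images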
    dummy_input = torch.rand(3, 3, 512, 512).to(device)
    writer.add_graph(composer, dummy_input)
    step = 0
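    # per-epoch loop: train, save a checkpoint, visualize the val set, and log losses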
    for epoch in range(args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))
        trainer = pipeline.ComposerTrainer(composer, train_loader, device,
                                           optimizer, writer, step)
        step = trainer.train()
        if args.save_model:
            state = composer.state_dict()
            torch.save(
                state,
                os.path.join(args.save_path,
                             'composer_state_{}.t7'.format(epoch)))

        loss = pipeline.visualize_composer(
            composer, val_loader, device,
            os.path.join(args.save_path, '{}.png'.format(epoch)))
        writer.add_scalar('test_recon_loss', loss[0], step)
        writer.add_scalar('test_refl_loss', loss[1], step)
        writer.add_scalar('test_sha_loss', loss[2], step)
        scheduler.step()
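

# entry point guard (assumed; the original listing may invoke main() elsewhere)
if __name__ == '__main__':
    main()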