示例#1
0
文件: run.py 项目: lkampoli/SimulAI
def train_wrapper(model):
    """Run the training loop for ``args.max_iterations`` iterations.

    If ``args.pretrained_model`` is set (truthy), its weights are loaded into
    ``model`` first. Each iteration prints the iteration number, fetches a
    batch (reshuffling when the handle is exhausted), reshapes it into
    patches, updates the scheduled-sampling state and calls ``trainer.train``.
    Every ``args.snapshot_interval`` iterations the model is saved under the
    iteration number. Every ``args.test_interval`` iterations
    ``trainer.test`` is called on the validation handle.

    Args:
        model: object exposing ``load(path)`` and ``save(tag)``.
    """
    if args.pretrained_model:
        model.load(args.pretrained_model)

    # The provider returns a (train, validation) pair when is_training=True.
    train_handle, valid_handle = datasets_factory.data_provider(
        args.dataset_name,
        args.train_data_paths,
        args.valid_data_paths,
        args.batch_size,
        args.img_width,
        seq_length=args.total_length,
        is_training=True)

    eta = args.sampling_start_value

    for step in range(1, args.max_iterations + 1):
        print("Iter number:", step)

        # Restart (with shuffling) once the current pass is exhausted.
        if train_handle.no_batch_left():
            train_handle.begin(do_shuffle=True)
        frames = preprocess.reshape_patch(train_handle.get_batch(),
                                          args.patch_size)

        eta, real_input_flag = schedule_sampling(eta, step)
        trainer.train(model, frames, real_input_flag, args, step)

        if step % args.snapshot_interval == 0:
            model.save(step)
        if step % args.test_interval == 0:
            trainer.test(model, valid_handle, args, step)

        train_handle.next()
示例#2
0
文件: run.py 项目: lkampoli/SimulAI
def test_wrapper(model):
    """Load ``args.pretrained_model`` into ``model`` and run ``trainer.test``.

    The data provider is built with ``is_training=False``, and its single
    return value is passed to ``trainer.test``. The string ``'test_result'``
    goes in the slot where ``train_wrapper`` passes the iteration number.

    Args:
        model: object exposing ``load(path)``.
    """
    model.load(args.pretrained_model)

    positional = (
        args.dataset_name,
        args.train_data_paths,
        args.valid_data_paths,
        args.batch_size,
        args.img_width,
    )
    test_input_handle = datasets_factory.data_provider(
        *positional, seq_length=args.total_length, is_training=False)

    trainer.test(model, test_input_handle, args, 'test_result')
示例#3
0
def train_wrapper(model):
    """Train ``model`` and checkpoint the best validation scores.

    If ``args.pretrained_model`` is set (truthy), its weights are loaded
    first. The loop then runs ``args.max_iterations`` training iterations.
    Checkpoints are saved as follows:

    * every iteration: under the iteration number when it is a multiple of
      ``args.snapshot_interval``, otherwise under ``"latest"``;
    * every ``args.test_interval`` iterations: ``trainer.test`` is called and
      returns ``(val_loss, ssim, psnr)``. A separate checkpoint
      (``"bestssim"``, ``"bestpsnr"``, ``"bestvalloss"``) is written for
      *each* metric that beat its best value so far. Higher SSIM/PSNR and
      lower validation loss count as improvements.

    Args:
        model: object exposing ``load(path)`` and ``save(tag)``.
    """
    if args.pretrained_model:
        model.load(args.pretrained_model)
    # Build the (train, validation) data handles.
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        args.dataset_name,
        args.train_data_paths,
        args.valid_data_paths,
        args.batch_size,
        args.img_width,
        seq_length=args.total_length,
        is_training=True,
    )

    eta = args.sampling_start_value
    # Start from the worst possible value so the first evaluation is always
    # recorded. A finite sentinel such as -1 would never be beaten by a first
    # score equal to it, because the comparisons below are strict.
    best_val_loss = float("inf")
    best_ssim = float("-inf")
    best_psnr = float("-inf")

    for itr in range(1, args.max_iterations + 1):
        # Restart (with shuffling) once the current pass is exhausted.
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        ims = train_input_handle.get_batch()
        ims = preprocess.reshape_patch(ims, args.patch_size)

        if args.reverse_scheduled_sampling == 1:
            real_input_flag = reserve_schedule_sampling_exp(itr)
        else:
            eta, real_input_flag = schedule_sampling(eta, itr)

        trainer.train(model, ims, real_input_flag, args, itr)

        if itr % args.snapshot_interval == 0:
            model.save(itr)
        else:
            # NOTE(review): "latest" is not refreshed on snapshot iterations;
            # those weights live only in the numbered snapshot. Confirm this
            # is intended.
            model.save("latest")

        if itr % args.test_interval == 0:
            val_loss, ssim, psnr = trainer.test(model, test_input_handle, args,
                                                itr)
            # Check each metric independently so an improvement in one does
            # not hide improvements in the others.
            if ssim > best_ssim:
                best_ssim = ssim
                model.save("bestssim")
                print("Best SSIM found: {}".format(best_ssim))
            if psnr > best_psnr:
                best_psnr = psnr
                model.save("bestpsnr")
                print("Best PSNR found: {}".format(best_psnr))
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                model.save("bestvalloss")
                print("Best ValLossMSE found: {}".format(best_val_loss))
        train_input_handle.next()