# Example #1
# 0
def main():
    """Train an MIM model with PyTorch.

    Reads the command-line arguments, builds the model on ``args.device``
    (or on the CPU when CUDA is not available), optionally restores
    ``args.pretrained_model``, then runs ``args.max_iterations`` training
    steps. Checkpoints, evaluation and loss logging happen every
    ``args.snapshot_interval``, ``args.test_interval`` and
    ``args.display_interval`` iterations respectively.
    """
    args = parse_args()

    # The configured device is only honoured when CUDA is present.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    model = MIM(args).to(device)
    print(model)
    print('The model is loaded!\n')

    # Training / test data providers.
    # Original author's note: batches come out shaped n x 64 x 64 x 1.
    train_handle, test_handle = datasets_factory.data_provider(
        args.dataset_name,
        args.train_data_paths,
        args.valid_data_paths,
        args.batch_size * args.n_gpu,
        args.img_width,
        seq_length=args.total_length,
        is_training=True,
    )

    if args.pretrained_model:
        model.load(args.pretrained_model)

    # Sampling probability, updated each step by schedule_sampling.
    eta = args.sampling_start_value

    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    mse_loss = torch.nn.MSELoss()

    for step in range(1, args.max_iterations + 1):
        if train_handle.no_batch_left():
            train_handle.begin(do_shuffle=True)

        batch = train_handle.get_batch()
        # Optional copy of the raw batch flipped along axis 3, patched the
        # same way as the main batch.
        if args.reverse_img:
            flipped = preprocess.reshape_patch(batch[:, :, :, ::-1], args.patch_size)
        else:
            flipped = None
        batch = preprocess.reshape_patch(batch, args.patch_size)

        eta, real_input_flag = schedule_sampling(eta, step, args)

        loss = trainer.trainer(model, batch, real_input_flag, args, step,
                               flipped, device, optimizer, mse_loss)

        if step % args.snapshot_interval == 0:
            model.save(step)

        if step % args.test_interval == 0:
            trainer.test(model, test_handle, args, step)

        if step % args.display_interval == 0:
            stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            print(stamp, 'itr: ' + str(step))
            print('training loss: ' + str(loss))

        train_handle.next()

        del loss
def main(argv=None):
    """Train the frame-interpolation model and evaluate it periodically.

    In training, even frames (0, 2, 4, ...) of each sequence are always
    marked as given in ``mask_true``. Each odd frame is marked as given only
    where a uniform sample is below ``eta``. ``eta`` starts at 1, decreases
    by ``delta`` per iteration and is forced to 0 from iteration 80000.

    Every ``FLAGS.test_interval`` iterations, at most 11 test batches are
    scored on channel 0 with MSE, PSNR, frame MAE, SSIM and Laplacian
    sharpness. Images for the first 10 batches are written under
    ``FLAGS.gen_frm_dir/<itr>/<filename>/``.

    Args:
        argv: Unused.

    Raises:
        ValueError: at evaluation time if ``FLAGS.gen_num`` is not 5 or 11.
    """
    # Start from clean output directories.
    for directory in (FLAGS.save_dir, FLAGS.gen_frm_dir, FLAGS.log_dir):
        if tf.gfile.Exists(directory):
            tf.gfile.DeleteRecursively(directory)
        tf.gfile.MakeDirs(directory)

    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths,
        FLAGS.batch_size, [FLAGS.img_height, FLAGS.img_width, FLAGS.img_channel],
        FLAGS.seq_length)

    print('Initializing models')
    model = Model()
    lr = FLAGS.lr

    delta = 0.0000125
    eta = 1

    # Patch-space frame shape. Integer division is required: `/` yields
    # floats in Python 3, and numpy rejects float array dimensions.
    patch_h = FLAGS.img_height // FLAGS.patch_size
    patch_w = FLAGS.img_width // FLAGS.patch_size
    patch_c = FLAGS.patch_size ** 2 * FLAGS.img_channel

    def pick_frame(img_gen, test_ims, i, gt_for_even):
        """Return frame ``i`` (all batch items) of the prediction sequence.

        gen_num == 5: odd frame ``i`` is ``img_gen[:, i // 2]``; even
        frames are taken from ``test_ims``.
        gen_num == 11: odd frames come from ``img_gen[:, i]``; even frames
        come from ``test_ims`` if ``gt_for_even`` else from ``img_gen``.
        """
        if FLAGS.gen_num == 5:
            return img_gen[:, i // 2] if i % 2 == 1 else test_ims[:, i]
        if FLAGS.gen_num == 11:
            if i % 2 == 1 or not gt_for_even:
                return img_gen[:, i]
            return test_ims[:, i]
        raise ValueError(
            'Unsupported FLAGS.gen_num=%r; expected 5 or 11' % (FLAGS.gen_num,))

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        print('train get_batch:')
        ims, _ = train_input_handle.get_batch(False)
        ims = preprocess.reshape_patch(ims, FLAGS.patch_size)

        if itr < 80000:
            eta -= delta
        else:
            eta = 0.0
        random_flip = np.random.random_sample(
            (FLAGS.batch_size, FLAGS.seq_length))
        # mask_true[b, j] is all ones for even j, and for odd j where
        # random_flip < eta; all zeros otherwise.
        true_token = random_flip < eta
        true_token[:, ::2] = True
        mask_true = np.zeros(
            (FLAGS.batch_size, FLAGS.seq_length, patch_h, patch_w, patch_c))
        mask_true[true_token] = 1.0

        # Model.train takes the reversed sequence positionally. The original
        # code left it unbound (NameError) when reverse_input was off.
        # NOTE(review): confirm Model.train accepts ims_rev=None.
        ims_rev = ims[:, ::-1] if FLAGS.reverse_input else None
        cost = model.train(ims, ims_rev, lr, mask_true, itr)

        if itr % FLAGS.display_interval == 0:
            print('itr: ' + str(itr))
            print('training loss: ' + str(cost))

        if itr % FLAGS.test_interval == 0:
            print('test...')
            test_input_handle.begin(do_shuffle=False)
            res_path = os.path.join(FLAGS.gen_frm_dir, str(itr))
            os.makedirs(res_path, exist_ok=True)
            avg_mse = 0
            batch_id = 0
            img_mse = [0] * FLAGS.seq_length
            ssim = [0] * FLAGS.seq_length
            psnr = [0] * FLAGS.seq_length
            fmae = [0] * FLAGS.seq_length
            sharp = [0] * FLAGS.seq_length

            # Full-resolution test mask: even frames ones, odd frames zeros,
            # then converted to the patch layout.
            mask_true = np.ones((FLAGS.batch_size, FLAGS.seq_length,
                                 FLAGS.img_height, FLAGS.img_width,
                                 FLAGS.img_channel))
            mask_true[:, 1::2] = 0.0
            mask_true = preprocess.reshape_patch(mask_true, FLAGS.patch_size)

            # Up to 11 batches; stop early instead of reading past the end
            # of the test set.
            while batch_id <= 10 and not test_input_handle.no_batch_left():
                batch_id += 1
                print('test get_batch:')
                test_ims, filename = test_input_handle.get_batch(False)
                test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)
                # Same reversal as in training (None when disabled).
                test_ims_rev = test_dat[:, ::-1] if FLAGS.reverse_input else None

                img_gen, ims_watch, ims_rev_watch = model.test(
                    test_dat, test_ims_rev, mask_true, itr)

                # concat outputs of different gpus along batch
                img_gen = preprocess.reshape_patch_back(
                    np.concatenate(img_gen), FLAGS.patch_size)
                ims_watch = preprocess.reshape_patch_back(
                    np.concatenate(ims_watch), FLAGS.patch_size)
                ims_rev_watch = preprocess.reshape_patch_back(
                    np.concatenate(ims_rev_watch), FLAGS.patch_size)

                # Per-frame metrics on channel 0.
                for i in range(FLAGS.seq_length):
                    x = test_ims[:, i, :, :, 0]
                    gx = pick_frame(img_gen, test_ims, i, True)[:, :, :, 0]
                    fmae[i] += metrics.batch_mae_frame_float(gx, x)
                    gx = np.clip(gx, 0, 1)
                    mse = np.square(x - gx).sum()
                    img_mse[i] += mse
                    avg_mse += mse

                    real_frm = np.uint8(x * 255)
                    pred_frm = np.uint8(gx * 255)
                    psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
                    for b in range(FLAGS.batch_size):
                        sharp[i] += np.max(
                            cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
                        score, _ = compare_ssim(pred_frm[b], real_frm[b], full=True)
                        ssim[i] += score

                # save prediction examples (first item of the batch)
                if batch_id <= 10:
                    path = os.path.join(res_path, str(filename))
                    # exist_ok: two batches may share the same filename.
                    os.makedirs(path, exist_ok=True)
                    for i in range(FLAGS.seq_length):
                        img_gt = np.uint8(test_ims[0, i] * 255)
                        cv2.imwrite(os.path.join(path, 'gt%d.png' % (i + 1)), img_gt)
                    for i in range(FLAGS.seq_length):
                        img_pd = pick_frame(img_gen, test_ims, i, False)[0]
                        img_pd = np.uint8(np.clip(img_pd, 0, 1) * 255)
                        cv2.imwrite(os.path.join(path, 'pd%d.png' % (i + 1)), img_pd)
                        img_zwgt = np.uint8(ims_watch[0, i] * 255)
                        cv2.imwrite(os.path.join(path, 'zwgt%d.png' % (i + 1)), img_zwgt)
                        zwgtrev = np.uint8(ims_rev_watch[0, i] * 255)
                        cv2.imwrite(os.path.join(path, 'zwgtrev%d.png' % (i + 1)), zwgtrev)

                test_input_handle.next()

            if batch_id == 0:
                # Avoid dividing by zero when the test set yields no batches.
                print('no test batches available; skipping metrics')
            else:
                num_samples = batch_id * FLAGS.batch_size
                print('mse per seq: ' + str(avg_mse / num_samples))
                for frame_mse in img_mse:
                    print(frame_mse / num_samples)
                psnr = np.asarray(psnr, dtype=np.float32) / batch_id
                fmae = np.asarray(fmae, dtype=np.float32) / batch_id
                ssim = np.asarray(ssim, dtype=np.float32) / num_samples
                sharp = np.asarray(sharp, dtype=np.float32) / num_samples
                for label, values in (('psnr', psnr), ('fmae', fmae),
                                      ('ssim', ssim), ('sharpness', sharp)):
                    print(label + ' per frame: ' + str(np.mean(values)))
                    for value in values:
                        print(value)

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)

        train_input_handle.next()
# Example #3
# 0
def main(argv=None):
    """Train the next-frame prediction model and evaluate it periodically.

    The first ``FLAGS.input_length`` frames of each sequence are inputs.
    Frames ``input_length .. seq_length - 1`` are compared against
    ``img_gen[:, 0 .. seq_length - input_length - 1]`` at test time.

    In training, ``mask_true[b, j]`` is all ones where a uniform sample is
    below ``eta`` and all zeros otherwise. ``eta`` starts at 1, decreases by
    ``delta`` per iteration and is forced to 0 from iteration 50000. At test
    time the mask is all zeros.

    Every ``FLAGS.test_interval`` iterations the whole test set is scored on
    channel 0 with MSE, PSNR, frame MAE, SSIM and Laplacian sharpness.
    Images for the first 10 batches go to ``FLAGS.gen_frm_dir/<itr>/<batch>/``.

    Args:
        argv: Unused.
    """
    # Start from clean output directories.
    for directory in (FLAGS.save_dir, FLAGS.gen_frm_dir):
        if tf.gfile.Exists(directory):
            tf.gfile.DeleteRecursively(directory)
        tf.gfile.MakeDirs(directory)

    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths,
        FLAGS.batch_size, FLAGS.img_width)

    print('Initializing models')
    model = Model()
    lr = FLAGS.lr

    delta = 0.00002
    eta = 1

    num_out = FLAGS.seq_length - FLAGS.input_length
    # Patch-space frame shape. Frames are treated as square (only img_width
    # is used). Integer division is required: `/` yields floats in Python 3,
    # and numpy rejects float array dimensions.
    patch_side = FLAGS.img_width // FLAGS.patch_size
    patch_c = FLAGS.patch_size ** 2 * FLAGS.img_channel
    mask_shape = (FLAGS.batch_size, num_out - 1, patch_side, patch_side, patch_c)

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        ims = train_input_handle.get_batch()
        ims = preprocess.reshape_patch(ims, FLAGS.patch_size)

        if itr < 50000:
            eta -= delta
        else:
            eta = 0.0
        random_flip = np.random.random_sample(mask_shape[:2])
        true_token = random_flip < eta
        mask_true = np.zeros(mask_shape)
        mask_true[true_token] = 1.0

        cost = model.train(ims, lr, mask_true)
        if FLAGS.reverse_input:
            # Also train on the time-reversed sequence and average the costs.
            ims_rev = ims[:, ::-1]
            cost += model.train(ims_rev, lr, mask_true)
            cost = cost / 2

        if itr % FLAGS.display_interval == 0:
            print('itr: ' + str(itr))
            print('training loss: ' + str(cost))

        if itr % FLAGS.test_interval == 0:
            print('test...')
            test_input_handle.begin(do_shuffle=False)
            res_path = os.path.join(FLAGS.gen_frm_dir, str(itr))
            os.mkdir(res_path)
            avg_mse = 0
            batch_id = 0
            img_mse = [0] * num_out
            ssim = [0] * num_out
            psnr = [0] * num_out
            fmae = [0] * num_out
            sharp = [0] * num_out
            mask_true = np.zeros(mask_shape)
            while not test_input_handle.no_batch_left():
                batch_id += 1
                test_ims = test_input_handle.get_batch()
                test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)
                img_gen = model.test(test_dat, mask_true)

                # concat outputs of different gpus along batch
                img_gen = np.concatenate(img_gen)
                img_gen = preprocess.reshape_patch_back(img_gen, FLAGS.patch_size)
                # Per-frame metrics on channel 0.
                for i in range(num_out):
                    x = test_ims[:, i + FLAGS.input_length, :, :, 0]
                    gx = img_gen[:, i, :, :, 0]
                    fmae[i] += metrics.batch_mae_frame_float(gx, x)
                    gx = np.clip(gx, 0, 1)
                    mse = np.square(x - gx).sum()
                    img_mse[i] += mse
                    avg_mse += mse

                    real_frm = np.uint8(x * 255)
                    pred_frm = np.uint8(gx * 255)
                    psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
                    for b in range(FLAGS.batch_size):
                        sharp[i] += np.max(
                            cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
                        score, _ = compare_ssim(pred_frm[b], real_frm[b], full=True)
                        ssim[i] += score

                # save prediction examples (first item of the batch)
                if batch_id <= 10:
                    path = os.path.join(res_path, str(batch_id))
                    os.mkdir(path)
                    for i in range(FLAGS.seq_length):
                        img_gt = np.uint8(test_ims[0, i] * 255)
                        cv2.imwrite(os.path.join(path, 'gt%d.png' % (i + 1)), img_gt)
                    for i in range(num_out):
                        name = 'pd%d.png' % (i + 1 + FLAGS.input_length)
                        img_pd = np.uint8(np.clip(img_gen[0, i], 0, 1) * 255)
                        cv2.imwrite(os.path.join(path, name), img_pd)
                test_input_handle.next()

            if batch_id == 0:
                # Avoid dividing by zero when the test set yields no batches.
                print('no test batches available; skipping metrics')
            else:
                num_samples = batch_id * FLAGS.batch_size
                print('mse per seq: ' + str(avg_mse / num_samples))
                for frame_mse in img_mse:
                    print(frame_mse / num_samples)
                psnr = np.asarray(psnr, dtype=np.float32) / batch_id
                fmae = np.asarray(fmae, dtype=np.float32) / batch_id
                ssim = np.asarray(ssim, dtype=np.float32) / num_samples
                sharp = np.asarray(sharp, dtype=np.float32) / num_samples
                for label, values in (('psnr', psnr), ('fmae', fmae),
                                      ('ssim', ssim), ('sharpness', sharp)):
                    print(label + ' per frame: ' + str(np.mean(values)))
                    for value in values:
                        print(value)

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)

        train_input_handle.next()
# Example #4
# 0
def main(argv=None):
    """Train on one city's dataset and keep the best validation checkpoint.

    ``FLAGS.save_dir`` gets the dataset name, sequence length and
    ``FLAGS.num_hidden`` appended. If that directory already exists it is
    used as ``FLAGS.pretrained_model``. Each training batch is split into
    mini-batches of ``FLAGS.batch_size``.

    Every ``FLAGS.test_interval`` iterations the validation windows starting
    at ``indices`` are predicted. Masked MSE is then printed overall and for
    each loaded channel (volume, speed, direction). ``model.save_to_best_mode()``
    is called whenever the overall MSE improves.

    Args:
        argv: Unused.
    """
    FLAGS.save_dir += FLAGS.dataset_name + str(
        FLAGS.seq_length) + FLAGS.num_hidden
    FLAGS.best_model = FLAGS.save_dir + f'/best_channels{FLAGS.img_channel}_weighted.ckpt'
    FLAGS.gen_frm_dir += FLAGS.dataset_name
    if not tf.io.gfile.exists(FLAGS.save_dir):
        tf.io.gfile.makedirs(FLAGS.save_dir)
    else:
        # Existing checkpoint directory: resume from it.
        FLAGS.pretrained_model = FLAGS.save_dir
    if not tf.io.gfile.exists(FLAGS.gen_frm_dir):
        tf.io.gfile.makedirs(FLAGS.gen_frm_dir)

    process_data_dir = os.path.join(FLAGS.valid_data_paths, FLAGS.dataset_name,
                                    'process_0.5')
    node_pos_file_2in1 = os.path.join(process_data_dir, 'node_pos_0.5.npy')
    # NOTE(review): node_pos is not used below; the load is kept because it
    # fails fast if the file is missing.
    node_pos = np.load(node_pos_file_2in1)

    train_data_paths = os.path.join(FLAGS.train_data_paths, FLAGS.dataset_name,
                                    FLAGS.dataset_name + '_training')
    valid_data_paths = os.path.join(FLAGS.valid_data_paths, FLAGS.dataset_name,
                                    FLAGS.dataset_name + '_validation')
    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, train_data_paths, valid_data_paths,
        FLAGS.batch_file, True, FLAGS.input_length,
        FLAGS.seq_length - FLAGS.input_length)

    # Start indices (within the 288 daily time bins, 0..287) of the windows
    # to predict in each daily test file. They depend on the time zone:
    # Berlin is UTC+2, Istanbul and Moscow are UTC+3.
    utc_plus2 = [30, 69, 126, 186, 234]
    utc_plus3 = [57, 114, 174, 222, 258]
    indices = utc_plus2 if FLAGS.dataset_name == 'Berlin' else utc_plus3

    print("Initializing models", flush=True)
    model = Model()
    lr = FLAGS.lr

    delta = 0.2
    eta = 1
    # Best validation MSE so far. Starting at +inf guarantees the first
    # evaluation writes a best-model checkpoint, whatever the MSE scale.
    min_val_loss = float('inf')

    num_out = FLAGS.seq_length - FLAGS.input_length
    patch_c = int(FLAGS.patch_size_height * FLAGS.patch_size_width *
                  FLAGS.img_channel)

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        imss = train_input_handle.get_batch()
        # Keep only the selected channels.
        imss = imss[..., :FLAGS.img_channel]

        imss = preprocess.reshape_patch(imss, FLAGS.patch_size_width,
                                        FLAGS.patch_size_height)
        num_batches = imss.shape[0]
        for bi in range(0, num_batches, FLAGS.batch_size):
            ims = imss[bi:bi + FLAGS.batch_size]
            # Patch-space frame size, reused by the test mask below.
            FLAGS.img_height = ims.shape[2]
            FLAGS.img_width = ims.shape[3]
            batch_size = ims.shape[0]
            # NOTE(review): eta is decremented once per mini-batch, not once
            # per iteration, so with delta=0.2 it reaches ~0 after five
            # mini-batches. Kept as-is; confirm intent.
            if itr < 10:
                eta -= delta
            else:
                eta = 0.0
            random_flip = np.random.random_sample((batch_size, num_out - 1))
            true_token = random_flip < eta
            mask_true = np.zeros((batch_size, num_out - 1,
                                  int(FLAGS.img_height), int(FLAGS.img_width),
                                  patch_c))
            mask_true[true_token] = 1.0
            cost = model.train(ims, lr, mask_true, batch_size)

            if FLAGS.reverse_input:
                # Also train on the time-reversed sequence and average.
                ims_rev = ims[:, ::-1]
                cost += model.train(ims_rev, lr, mask_true, batch_size)
                cost = cost / 2

            if itr % FLAGS.display_interval == 0:
                print('itr: ' + str(itr), flush=True)
                print('training loss: ' + str(cost), flush=True)

        train_input_handle.next()
        if itr % FLAGS.test_interval == 0:
            print('test...', flush=True)
            batch_size = len(indices)
            test_input_handle.begin(do_shuffle=False)
            avg_mse = 0
            batch_id = 0
            img_mse = [0] * num_out
            # NOTE(review): psnr and fmae are accumulated but never reported.
            psnr = [0] * num_out
            fmae = [0] * num_out
            mask_true = np.zeros((batch_size, num_out - 1, FLAGS.img_height,
                                  FLAGS.img_width, patch_c))

            gt_list = []
            pred_list = []

            while not test_input_handle.no_batch_left():
                batch_id += 1
                test_ims = test_input_handle.get_test_batch(indices)
                # get the selected channels
                test_ims = test_ims[..., :FLAGS.img_channel]

                gt_list.append(test_ims[:, FLAGS.input_length:])
                test_dat = preprocess.reshape_patch(test_ims,
                                                    FLAGS.patch_size_width,
                                                    FLAGS.patch_size_height)

                img_gen = model.test(test_dat, mask_true, batch_size)

                # concat outputs of different gpus along batch
                img_gen = np.concatenate(img_gen)
                img_gen = preprocess.reshape_patch_back(
                    img_gen, FLAGS.patch_size_width, FLAGS.patch_size_height)
                pred_list.append(img_gen)
                # MSE per frame
                for i in range(num_out):
                    x = test_ims[:, i + FLAGS.input_length]
                    gx = img_gen[:, i]
                    fmae[i] += metrics.batch_mae_frame_float(gx, x)
                    gx = np.clip(gx, 0, 1)
                    mse = np.square(x - gx).sum()
                    img_mse[i] += mse
                    avg_mse += mse
                    psnr[i] += metrics.batch_psnr(np.uint8(gx * 255),
                                                  np.uint8(x * 255))

                test_input_handle.next()

            if batch_id == 0:
                # Avoid ZeroDivisionError and np.stack([]) on an empty test set.
                print('No test batches available; skipping evaluation.',
                      flush=True)
            else:
                # Scalar elements per output frame position, summed over
                # all test batches.
                elems_per_frame = (batch_id * batch_size * FLAGS.img_height *
                                   FLAGS.img_width * FLAGS.patch_size_height *
                                   FLAGS.patch_size_width * FLAGS.img_channel)
                print('mse per seq: ' +
                      str(avg_mse / (elems_per_frame * len(img_mse))), flush=True)
                for frame_mse in img_mse:
                    print(frame_mse / elems_per_frame, flush=True)

                gt = np.stack(gt_list, axis=0)
                pred = np.stack(pred_list, axis=0)
                mse = masked_mse_np(pred, gt, null_val=np.nan)
                volume_mse = masked_mse_np(pred[..., 0], gt[..., 0],
                                           null_val=np.nan)
                print("The output mse is ", mse, flush=True)
                print("The volume mse is ", volume_mse, flush=True)
                # Only report channels that were actually loaded; indexing
                # channel 1 with img_channel == 1 raised IndexError.
                if FLAGS.img_channel >= 2:
                    speed_mse = masked_mse_np(pred[..., 1], gt[..., 1],
                                              null_val=np.nan)
                    print("The speed mse is ", speed_mse, flush=True)
                if FLAGS.img_channel == 3:
                    direction_mse = masked_mse_np(pred[..., 2], gt[..., 2],
                                                  null_val=np.nan)
                    print("The direction mse is ", direction_mse, flush=True)

                if min_val_loss > mse:
                    min_val_loss = mse
                    print("Current Min Val Loss is ", min_val_loss)
                    model.save_to_best_mode()

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)
# Example #5
# 0
def main(argv=None):
    """Train the model with scheduled sampling and periodically validate it.

    Steps:
      * Delete and recreate ``FLAGS.save_dir`` and ``FLAGS.gen_frm_dir``.
      * Build train/validation input handles from the
        ``train_speed_down_sample<N>.npz`` / ``valid_speed_down_sample<N>.npz``
        files under ``<data path>/<dataset_name>``.
      * For ``FLAGS.max_iterations`` iterations, train on one batch (and, when
        ``FLAGS.reverse_input`` is set, also on the batch reversed along axis
        1, averaging the two costs).
      * Every ``FLAGS.test_interval`` iterations, print per-frame MSE, PSNR,
        MAE and sharpness on the validation handle, write ground-truth and
        predicted PNGs for the first 10 validation batches under
        ``<gen_frm_dir>/<itr>``, and print a masked MSE against the per-day
        files in ``<dataset_name>_validation``.
      * Every ``FLAGS.snapshot_interval`` iterations, save the model.

    Args:
        argv: Unused.
    """
    for directory in (FLAGS.save_dir, FLAGS.gen_frm_dir):
        if tf.gfile.Exists(directory):
            tf.gfile.DeleteRecursively(directory)
        tf.gfile.MakeDirs(directory)

    # Number of frames predicted after the `input_length` context frames.
    horizon = FLAGS.seq_length - FLAGS.input_length

    train_data_paths = os.path.join(
        FLAGS.train_data_paths, FLAGS.dataset_name,
        'train_speed_down_sample{}.npz'.format(FLAGS.down_sample))
    valid_data_paths = os.path.join(
        FLAGS.valid_data_paths, FLAGS.dataset_name,
        'valid_speed_down_sample{}.npz'.format(FLAGS.down_sample))
    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, train_data_paths, valid_data_paths,
        FLAGS.batch_size, True, FLAGS.down_sample, FLAGS.input_length,
        horizon)

    # Start indices of the images to predict within the 288 time bins
    # (0 to 287) of each daily test file. They are time zone dependent:
    # Berlin lies in UTC+2 whereas Istanbul and Moscow lie in UTC+3.
    utcPlus2 = [30, 69, 126, 186, 234]
    utcPlus3 = [57, 114, 174, 222, 258]
    indicies = utcPlus2 if FLAGS.dataset_name == 'Berlin' else utcPlus3

    dims = train_input_handle.dims
    FLAGS.img_height = dims[-2]
    FLAGS.img_width = dims[-1]
    print("Initializing models", flush=True)
    model = Model()
    lr = FLAGS.lr

    # Channel count of each mask frame (patch_size**2 * img_channel).
    patch_channels = int(FLAGS.patch_size**2 * FLAGS.img_channel)

    # Scheduled sampling: eta drops by `delta` every iteration before 50000
    # and is 0 afterwards; a mask frame is all-ones where a uniform draw in
    # [0, 1) is below eta, all-zeros otherwise.
    delta = 0.00002
    eta = 1

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        ims = train_input_handle.get_batch()
        ims = preprocess.reshape_patch(ims, FLAGS.patch_size)

        if itr < 50000:
            eta -= delta
        else:
            eta = 0.0
        random_flip = np.random.random_sample(
            (FLAGS.batch_size, horizon - 1))
        true_token = (random_flip < eta)
        # Broadcast each (sample, step) token over a whole frame; astype
        # copies the read-only broadcast view into a float64 array.
        mask_true = np.broadcast_to(
            true_token[:, :, np.newaxis, np.newaxis, np.newaxis],
            (FLAGS.batch_size, horizon - 1, int(FLAGS.img_height),
             int(FLAGS.img_width), patch_channels)).astype(np.float64)
        cost = model.train(ims, lr, mask_true)
        if FLAGS.reverse_input:
            # Also train on the sequence reversed along axis 1 and average.
            ims_rev = ims[:, ::-1]
            cost += model.train(ims_rev, lr, mask_true)
            cost = cost / 2

        if itr % FLAGS.display_interval == 0:
            print('itr: ' + str(itr), flush=True)
            print('training loss: ' + str(cost), flush=True)

        if itr % FLAGS.test_interval == 0:
            print('test...', flush=True)
            test_input_handle.begin(do_shuffle=False)
            res_path = os.path.join(FLAGS.gen_frm_dir, str(itr))
            os.mkdir(res_path)
            avg_mse = 0
            batch_id = 0
            img_mse = [0] * horizon
            psnr = [0] * horizon
            fmae = [0] * horizon
            sharp = [0] * horizon
            # All-zero mask for evaluation.
            mask_true = np.zeros(
                (FLAGS.batch_size, horizon - 1,
                 FLAGS.img_height, FLAGS.img_width, patch_channels))
            while not test_input_handle.no_batch_left():
                batch_id += 1
                test_ims = test_input_handle.get_batch()
                test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)
                img_gen = model.test(test_dat, mask_true)

                # concat outputs of different gpus along batch
                img_gen = np.concatenate(img_gen)
                img_gen = preprocess.reshape_patch_back(
                    img_gen, FLAGS.patch_size)
                # Per-frame metrics on channel 0.
                for i in range(horizon):
                    x = test_ims[:, i + FLAGS.input_length, :, :, 0]
                    gx = img_gen[:, i, :, :, 0]
                    # MAE is taken before clipping; the rest after.
                    fmae[i] += metrics.batch_mae_frame_float(gx, x)
                    gx = np.clip(gx, 0, 1)
                    mse = np.square(x - gx).sum()
                    img_mse[i] += mse
                    avg_mse += mse

                    real_frm = np.uint8(x * 255)
                    pred_frm = np.uint8(gx * 255)
                    psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
                    for b in range(FLAGS.batch_size):
                        sharp[i] += np.max(
                            cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))

                # save prediction examples (first sample of the first 10 batches)
                if batch_id <= 10:
                    path = os.path.join(res_path, str(batch_id))
                    os.mkdir(path)
                    for i in range(FLAGS.seq_length):
                        file_name = os.path.join(path, 'gt' + str(i + 1) + '.png')
                        img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                        cv2.imwrite(file_name, img_gt)
                    for i in range(horizon):
                        name = 'pd' + str(i + 1 + FLAGS.input_length) + '.png'
                        file_name = os.path.join(path, name)
                        img_pd = np.uint8(np.clip(img_gen[0, i, :, :, :], 0, 1) * 255)
                        cv2.imwrite(file_name, img_pd)
                test_input_handle.next()

            num_seqs = batch_id * FLAGS.batch_size
            avg_mse = avg_mse / num_seqs
            print('mse per seq: ' + str(avg_mse), flush=True)
            for frame_mse in img_mse:
                print(frame_mse / num_seqs, flush=True)
            psnr = np.asarray(psnr, dtype=np.float32) / batch_id
            fmae = np.asarray(fmae, dtype=np.float32) / batch_id
            sharp = np.asarray(sharp, dtype=np.float32) / num_seqs
            for label, values in (('psnr', psnr), ('fmae', fmae),
                                  ('sharpness', sharp)):
                print(label + ' per frame: ' + str(np.mean(values)), flush=True)
                for value in values:
                    print(value, flush=True)

            # test with file
            # NOTE(review): this directory is built from FLAGS.train_data_paths,
            # not FLAGS.valid_data_paths -- confirm that is intended.
            valid_data_path = os.path.join(
                FLAGS.train_data_paths, FLAGS.dataset_name,
                '{}_validation'.format(FLAGS.dataset_name))
            files = list_filenames(valid_data_path)
            output_all = []
            labels_all = []
            num_partitions = int(np.ceil(len(indicies) / FLAGS.batch_size))
            for f in files:
                valid_file = os.path.join(valid_data_path, f)
                valid_input, raw_output = datasets_factory.test_validation_provider(
                    valid_file,
                    indicies,
                    down_sample=FLAGS.down_sample,
                    seq_len=FLAGS.input_length,
                    horizon=horizon)
                # np.float (an alias of the builtin float, i.e. float64) was
                # removed in NumPy 1.24.
                valid_input = valid_input.astype(np.float64) / 255.0
                labels_all.append(raw_output)
                for i in range(num_partitions):
                    start = i * FLAGS.batch_size
                    valid_input_i = valid_input[start:start + FLAGS.batch_size]
                    num_input_i = valid_input_i.shape[0]
                    if num_input_i < FLAGS.batch_size:
                        # Zero-pad the last partial batch up to batch_size;
                        # the padded outputs are dropped below.
                        zeros_fill_in = np.zeros(
                            (FLAGS.batch_size - num_input_i, FLAGS.seq_length,
                             FLAGS.img_height, FLAGS.img_width,
                             FLAGS.img_channel))
                        valid_input_i = np.concatenate(
                            [valid_input_i, zeros_fill_in], axis=0)
                    img_gen = model.test(valid_input_i, mask_true)
                    output_all.append(img_gen[0][:num_input_i])

            output_all = np.concatenate(output_all, axis=0)
            labels_all = np.concatenate(labels_all, axis=0)
            # cv2.resize expects dsize as (width, height). NOTE(review): this
            # passes (labels.shape[-2], labels.shape[-3]), which is (W, H) only
            # if the labels are laid out (..., H, W, C) -- confirm.
            dsize = (labels_all.shape[-2], labels_all.shape[-3])
            # NOTE(review): only index 1 on axis 2 of each predicted frame is
            # compared -- confirm which channel this is in the output layout.
            output_resize = np.stack([
                np.stack([
                    np.expand_dims(
                        cv2.resize(output_all[i, j, 1, :, :], dsize), axis=0)
                    for j in range(output_all.shape[1])
                ], axis=0)
                for i in range(output_all.shape[0])
            ], axis=0)

            output_resize *= 255.0
            # Labels channel 1 only, with a singleton axis 2 to match.
            labels_all = np.expand_dims(labels_all[..., 1], axis=2)
            valid_mse = masked_mse_np(output_resize, labels_all, np.nan)

            print("validation mse is ", valid_mse, flush=True)

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)

        train_input_handle.next()
# Пример #6 (Example #6)
# 0
def main(argv=None):
    """Evaluate a trained model on the test set and print per-frame metrics.

    Runs the model over every batch of the test handle with an all-zero
    ``mask_true`` and accumulates, for each of the
    ``seq_length - input_length`` predicted frames (channel 0): MSE, PSNR,
    MAE, SSIM (``compare_ssim``) and sharpness (max of
    ``convertScaleAbs(Laplacian(...))``). Ground-truth and predicted PNGs of
    the first sample of the first 10 batches are written under
    ``<gen_frm_dir>/test``. Averages and per-frame values are printed.

    Args:
        argv: Unused.
    """
    # Added because an error is raised otherwise.
    tf.disable_eager_execution()

    # load data
    _, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.test_data_paths,
        FLAGS.batch_size, FLAGS.img_width)

    print("Initializing models")
    model = Model()

    print('test...')
    test_input_handle.begin(do_shuffle=False)
    res_path = os.path.join(FLAGS.gen_frm_dir, 'test')
    # Unlike os.mkdir, this works when gen_frm_dir does not exist yet and
    # when re-running over an existing 'test' folder.
    os.makedirs(res_path, exist_ok=True)

    # Number of predicted frames after the `input_length` context frames.
    horizon = FLAGS.seq_length - FLAGS.input_length
    avg_mse = 0
    batch_id = 0
    img_mse = [0] * horizon
    ssim = [0] * horizon
    psnr = [0] * horizon
    fmae = [0] * horizon
    sharp = [0] * horizon
    # All-zero mask shaped (H / p, W / p, p * p * C) per frame, p = patch_size.
    mask_true = np.zeros(
        (FLAGS.batch_size, horizon - 1,
         int(FLAGS.img_height / FLAGS.patch_size),
         int(FLAGS.img_width / FLAGS.patch_size),
         FLAGS.patch_size**2 * FLAGS.img_channel))
    while not test_input_handle.no_batch_left():
        batch_id += 1
        test_ims = test_input_handle.get_batch()
        test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)
        img_gen = model.test(test_dat, mask_true)

        # concat outputs of different gpus along batch
        img_gen = np.concatenate(img_gen)
        img_gen = preprocess.reshape_patch_back(img_gen, FLAGS.patch_size)
        # Per-frame metrics on channel 0.
        for i in range(horizon):
            x = test_ims[:, i + FLAGS.input_length, :, :, 0]
            gx = img_gen[:, i, :, :, 0]
            # MAE is taken before clipping; the rest after.
            fmae[i] += metrics.batch_mae_frame_float(gx, x)
            gx = np.clip(gx, 0, 1)
            mse = np.square(x - gx).sum()
            img_mse[i] += mse
            avg_mse += mse

            real_frm = np.uint8(x * 255)
            pred_frm = np.uint8(gx * 255)
            psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
            for b in range(FLAGS.batch_size):
                sharp[i] += np.max(
                    cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
                score, _ = compare_ssim(pred_frm[b], real_frm[b], full=True)
                ssim[i] += score

        # save prediction examples (first sample of the first 10 batches)
        if batch_id <= 10:
            path = os.path.join(res_path, str(batch_id))
            os.makedirs(path, exist_ok=True)
            for i in range(FLAGS.seq_length):
                file_name = os.path.join(path, 'gt' + str(i + 1) + '.png')
                img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                cv2.imwrite(file_name, img_gt)
            for i in range(horizon):
                name = 'pd' + str(i + 1 + FLAGS.input_length) + '.png'
                file_name = os.path.join(path, name)
                img_pd = np.uint8(np.clip(img_gen[0, i, :, :, :], 0, 1) * 255)
                cv2.imwrite(file_name, img_pd)
        test_input_handle.next()

    num_seqs = batch_id * FLAGS.batch_size
    avg_mse = avg_mse / num_seqs
    print('mse per seq: ' + str(avg_mse))
    for frame_mse in img_mse:
        print(frame_mse / num_seqs)
    psnr = np.asarray(psnr, dtype=np.float32) / batch_id
    fmae = np.asarray(fmae, dtype=np.float32) / batch_id
    ssim = np.asarray(ssim, dtype=np.float32) / num_seqs
    sharp = np.asarray(sharp, dtype=np.float32) / num_seqs
    for label, values in (('psnr', psnr), ('fmae', fmae), ('ssim', ssim),
                          ('sharpness', sharp)):
        print(label + ' per frame: ' + str(np.mean(values)))
        for value in values:
            print(value)
def main(argv=None):
    """Run single-pass inference on the test set and print per-frame metrics.

    Uses a mask that is 1 on even time steps and 0 on odd ones. For odd frame
    ``i``, the model output is compared with the ground truth: index ``i // 2``
    of ``img_gen`` when ``FLAGS.gen_num == 5``, index ``i`` when
    ``FLAGS.gen_num == 11``. Even frames use the input itself. Per-frame MSE,
    PSNR, MAE, SSIM and sharpness over all ``seq_length`` frames are printed,
    and ground-truth / predicted PNGs are written per file under
    ``<gen_frm_dir>/images0/<fileName>``.

    Args:
        argv: Unused.

    Raises:
        ValueError: if ``FLAGS.gen_num`` is neither 5 nor 11. This is checked
            before any directory is deleted.
    """
    if FLAGS.gen_num not in (5, 11):
        raise ValueError(
            'FLAGS.gen_num must be 5 or 11, got {!r}'.format(FLAGS.gen_num))
    # Divisor mapping frame index i to its index in img_gen.
    gen_stride = 2 if FLAGS.gen_num == 5 else 1

    for directory in (FLAGS.gen_frm_dir, FLAGS.log_dir):
        if tf.gfile.Exists(directory):
            tf.gfile.DeleteRecursively(directory)
        tf.gfile.MakeDirs(directory)

    # load data
    _, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths,
        FLAGS.batch_size,
        [FLAGS.img_height, FLAGS.img_width, FLAGS.img_channel],
        FLAGS.seq_length)

    print('Initializing models')
    model = Model()

    test_input_handle.begin(do_shuffle=False)
    totalDataLen = int(test_input_handle.total())
    print('totalDataLen=', totalDataLen)
    for itr in range(1):
        print('inference...')
        res_path = os.path.join(FLAGS.gen_frm_dir, 'images' + str(itr))
        os.mkdir(res_path)
        avg_mse = 0
        batch_id = 0
        img_mse = [0] * FLAGS.seq_length
        ssim = [0] * FLAGS.seq_length
        psnr = [0] * FLAGS.seq_length
        fmae = [0] * FLAGS.seq_length
        sharp = [0] * FLAGS.seq_length

        # Steps 0, 2, 4, ... keep ones; steps 1, 3, 5, ... are zeroed, for
        # every sample in the batch.
        mask_true = np.ones(
            (FLAGS.batch_size, FLAGS.seq_length, FLAGS.img_height,
             FLAGS.img_width, FLAGS.img_channel))
        mask_true[:, 1::2] = 0.0
        mask_true = preprocess.reshape_patch(mask_true, FLAGS.patch_size)

        while batch_id < totalDataLen:
            batch_id += 1
            print('test get_batch:')
            test_ims, fileName = test_input_handle.get_batch(False)
            test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)

            # Build the reversed input only when requested, rather than
            # referencing an unbound name otherwise.
            # NOTE(review): confirm model.test accepts None here.
            test_ims_rev = test_dat[:, ::-1] if FLAGS.reverse_input else None

            # The two extra "watch" outputs are not used by this evaluation.
            img_gen, _, _ = model.test(test_dat, test_ims_rev, mask_true, itr)

            # concat outputs of different gpus along batch
            img_gen = np.concatenate(img_gen)
            img_gen = preprocess.reshape_patch_back(img_gen, FLAGS.patch_size)
            # Per-frame metrics on channel 0.
            for i in range(FLAGS.seq_length):
                x = test_ims[:, i, :, :, 0]
                # Odd frames come from the model; even frames are the input.
                if i % 2 == 1:
                    gx = img_gen[:, i // gen_stride, :, :, 0]
                else:
                    gx = test_ims[:, i, :, :, 0]

                # MAE is taken before clipping; the rest after.
                fmae[i] += metrics.batch_mae_frame_float(gx, x)
                gx = np.clip(gx, 0, 1)
                mse = np.square(x - gx).sum()
                img_mse[i] += mse
                avg_mse += mse

                real_frm = np.uint8(x * 255)
                pred_frm = np.uint8(gx * 255)
                psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
                for b in range(FLAGS.batch_size):
                    sharp[i] += np.max(
                        cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
                    score, _ = compare_ssim(pred_frm[b],
                                            real_frm[b],
                                            full=True)
                    ssim[i] += score

            # save prediction examples
            # NOTE(review): batch_id == totalDataLen on the last pass, so the
            # final batch is never written -- confirm this is intended.
            if batch_id < totalDataLen:
                path = os.path.join(res_path, str(fileName))
                os.mkdir(path)
                for i in range(FLAGS.seq_length):
                    file_name = os.path.join(path, 'gt' + str(i + 1) + '.png')
                    img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                    cv2.imwrite(file_name, img_gt)

                for i in range(FLAGS.seq_length):
                    file_name = os.path.join(path, 'pd' + str(i + 1) + '.png')
                    # gen_num == 5 writes the input for even frames.
                    # gen_num == 11 writes the model output for every frame,
                    # unlike the metrics above, which use the input for even
                    # frames.
                    if FLAGS.gen_num == 5 and i % 2 == 0:
                        img_pd = test_ims[0, i, :, :, :]
                    else:
                        img_pd = img_gen[0, i // gen_stride, :, :, :]
                    img_pd = np.uint8(np.clip(img_pd, 0, 1) * 255)
                    cv2.imwrite(file_name, img_pd)

            test_input_handle.next()

        num_seqs = batch_id * FLAGS.batch_size
        avg_mse = avg_mse / num_seqs
        print('mse per seq: ' + str(avg_mse))
        for frame_mse in img_mse:
            print(frame_mse / num_seqs)
        psnr = np.asarray(psnr, dtype=np.float32) / batch_id
        fmae = np.asarray(fmae, dtype=np.float32) / batch_id
        ssim = np.asarray(ssim, dtype=np.float32) / num_seqs
        sharp = np.asarray(sharp, dtype=np.float32) / num_seqs
        for label, values in (('psnr', psnr), ('fmae', fmae), ('ssim', ssim),
                              ('sharpness', sharp)):
            print(label + ' per frame: ' + str(np.mean(values)))
            for value in values:
                print(value)
def main(argv=None):
    """Train on per-file sub-batches and evaluate a single heading class.

    Training frames are converted into a 2-component vector field.
    Channel 2 is decoded into a heading class 0-4, and channel 1 divided by
    sqrt(2) is multiplied by that class's sign pair from ``heading_table``.

    Every ``FLAGS.test_interval`` iterations, validation frames are converted
    the same way, but only pixels whose heading class equals ``FLAGS.heading``
    are kept. Predictions are decoded back into (speed, heading code) and
    ``masked_mse_np`` results are printed for:
      * all pixels (``null_val=np.nan``);
      * ``null_val=0.0`` with the moving-average heading substitution;
      * ``null_val=0.0`` without the moving-average heading substitution;
      * the moving-average baseline itself;
      * a combination of the baseline and the prediction.

    The model is saved every ``FLAGS.snapshot_interval`` iterations. If
    ``FLAGS.save_dir`` (after the suffix below is appended) already exists, it
    is assigned to ``FLAGS.pretrained_model`` before the model is built.

    Args:
        argv: Unused.
    """

    # Heading class -> raw channel value (x255) used to select that class in
    # the ground truth below.
    heading_dict = {1: 1, 2:85, 3: 170, 4: 255, 0:0}
    heading = FLAGS.heading
    # NOTE(review): FLAGS.num_hidden is string-concatenated, so it must be a
    # string flag here.
    FLAGS.save_dir += FLAGS.dataset_name + str(FLAGS.seq_length) + FLAGS.num_hidden + 'squash' + 'L1+L2+VALID' + 'multi-task'
    FLAGS.gen_frm_dir += FLAGS.dataset_name
    if not tf.io.gfile.exists(FLAGS.save_dir):
        # tf.io.gfile.rmtree(FLAGS.save_dir)
        tf.io.gfile.makedirs(FLAGS.save_dir)
    else:
        # Existing save dir: reuse it as the pretrained model path
        # (presumably read by Model() -- confirm).
        FLAGS.pretrained_model = FLAGS.save_dir
    if not tf.io.gfile.exists(FLAGS.gen_frm_dir):
        # tf.io.gfile.rmtree(FLAGS.gen_frm_dir)
        tf.io.gfile.makedirs(FLAGS.gen_frm_dir)

    train_data_paths = os.path.join(FLAGS.train_data_paths, FLAGS.dataset_name, FLAGS.dataset_name + '_training')
    valid_data_paths = os.path.join(FLAGS.valid_data_paths, FLAGS.dataset_name, FLAGS.dataset_name + '_validation')
    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, train_data_paths, valid_data_paths,
        FLAGS.batch_file, True, FLAGS.input_length, FLAGS.seq_length - FLAGS.input_length)

    # The following indices are the start indices of the images to predict in
    # the 288 time bins (0 to 287) of each daily test file (5 per day). They
    # are time zone dependent: Berlin lies in UTC+2 whereas Istanbul and
    # Moscow lie in UTC+3.
    utcPlus2 = [30, 69, 126, 186, 234]
    utcPlus3 = [57, 114, 174, 222, 258]
    # Row k is the sign pair (component 0, component 1) for heading class k;
    # class 0 maps to (0, 0).
    heading_table = np.array([[0, 0], [-1, 1], [1, 1], [-1, -1], [1, -1]], dtype=np.float32)

    indicies = utcPlus3
    if FLAGS.dataset_name == 'Berlin':
        indicies = utcPlus2

    # dims = train_input_handle.dims
    print("Initializing models", flush=True)
    model = Model()
    lr = FLAGS.lr

    # Scheduled-sampling parameters. `base` is only referenced by the
    # commented-out exponential schedule below.
    delta = 0.00002
    base = 0.99998
    eta = 1

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        imss = train_input_handle.get_batch()

        # # print("imss shape is ", imss.shape)
        # Heading class: channel 2 * 255 floor-divided by 85, plus 1 (classes
        # 1-4). Pixels whose channel 2 is exactly 0 get class 0. heading_table
        # then maps each class to its sign pair.
        tem_data = imss.copy()
        heading_image = imss[:, :, :, :, 2]*255
        heading_image = (heading_image // 85).astype(np.int8) + 1
        heading_image[tem_data[:, :, :, :, 2] == 0] = 0
        heading_image = heading_table[heading_image]

        # Channel 1 / sqrt(2) times the sign pair -> 2-component vector frame.
        speed_on_axis = np.expand_dims(imss[:, :, :, :, 1] / np.sqrt(2), axis=-1)
        imss = speed_on_axis * heading_image

        imss = preprocess.reshape_patch(imss, FLAGS.patch_size_width, FLAGS.patch_size_height)
        num_batches = imss.shape[0]
        # One get_batch() may hold more than FLAGS.batch_size rows; train on
        # consecutive slices of at most batch_size rows.
        for bi in range(0, num_batches, FLAGS.batch_size):
            ims = imss[bi:bi+FLAGS.batch_size]
            # img_height / img_width are set from the patched shape
            # (axes 2 and 3 after reshape_patch); the test section reuses them.
            FLAGS.img_height = ims.shape[2]
            FLAGS.img_width = ims.shape[3]
            batch_size = ims.shape[0]
            # NOTE(review): eta is decremented once per sub-batch, not once per
            # itr, so it decays faster than `delta` per iteration whenever a
            # file yields several sub-batches. A negative eta behaves like 0,
            # since random_sample draws from [0, 1).
            if itr < 50000:
                eta -= delta
            else:
                eta = 0.0
            random_flip = np.random.random_sample(
                (batch_size, FLAGS.seq_length-FLAGS.input_length-1))
            true_token = (random_flip < eta)
            #true_token = (random_flip < pow(base,itr))
            # One all-ones or all-zeros frame per (sample, step), selected by
            # true_token, then reshaped to (batch, steps, H, W, channels).
            ones = np.ones((FLAGS.img_height,
                            FLAGS.img_width,
                            int(FLAGS.patch_size_height*FLAGS.patch_size_width*FLAGS.img_channel)))
            zeros = np.zeros((int(FLAGS.img_height),
                              int(FLAGS.img_width),
                              int(FLAGS.patch_size_height*FLAGS.patch_size_width*FLAGS.img_channel)))
            mask_true = []
            for i in range(batch_size):
                for j in range(FLAGS.seq_length-FLAGS.input_length-1):
                    if true_token[i,j]:
                        mask_true.append(ones)
                    else:
                        mask_true.append(zeros)
            mask_true = np.array(mask_true)
            mask_true = np.reshape(mask_true, (batch_size,
                                               FLAGS.seq_length-FLAGS.input_length-1,
                                               int(FLAGS.img_height),
                                               int(FLAGS.img_width),
                                               int(FLAGS.patch_size_height*FLAGS.patch_size_width*FLAGS.img_channel)))
            cost, _ = model.train(ims, lr, mask_true, batch_size)

            if FLAGS.reverse_input:
                # Also train on the sequence reversed along axis 1; average costs.
                ims_rev = ims[:,::-1]
                cost2, _ = model.train(ims_rev, lr, mask_true, batch_size)
                cost = (cost + cost2) / 2

            # Printed for every sub-batch of a display iteration.
            if itr % FLAGS.display_interval == 0:
                print('itr: ' + str(itr), flush=True)
                print('training loss: ' + str(cost), flush=True)

        train_input_handle.next()
        if itr % FLAGS.test_interval == 0:
            print('test...', flush=True)
            # Moving-average speed threshold for heading substitution below.
            epsilon = 0.2
            # Validation batch size equals the number of daily start indices.
            batch_size = len(indicies)
            test_input_handle.begin(do_shuffle = False)
            # res_path = os.path.join(FLAGS.gen_frm_dir, str(itr))
            # os.mkdir(res_path)
            avg_mse = 0
            batch_id = 0
            gt_list = []
            pred_list = []
            pred_list_all = []
            # NOTE(review): pred_vec is collected but never read. Likewise, only
            # img_mse and fmae are updated (ssim/psnr/sharp stay 0), and fmae is
            # never printed.
            pred_vec = []
            move_avg = []
            img_mse, ssim, psnr, fmae, sharp= [],[],[],[],[]
            for i in range(FLAGS.seq_length - FLAGS.input_length):
                img_mse.append(0)
                ssim.append(0)
                psnr.append(0)
                fmae.append(0)
                sharp.append(0)
            # All-zero evaluation mask, in the patched shape from training.
            mask_true = np.zeros((batch_size,
                                  FLAGS.seq_length-FLAGS.input_length-1,
                                  FLAGS.img_height,
                                  FLAGS.img_width,
                                  FLAGS.patch_size_height*FLAGS.patch_size_width*FLAGS.img_channel))
            while(test_input_handle.no_batch_left() == False):
                batch_id = batch_id + 1
                test_ims = test_input_handle.get_test_batch(indicies)
                # get the ground truth
                gt_list.append(test_ims[:, FLAGS.input_length:, :, :, 1:])
                # cvt the heading to 0, 1, 2, 3, 4
                tem_data = test_ims.copy()
                heading_image = test_ims[:, :, :, :, 2] * 255
                heading_image = (heading_image // 85).astype(np.int8) + 1
                heading_image[tem_data[:, :, :, :, 2] == 0] = 0
                cvt_heading = heading_image.copy()

                # convert the data into speed vectors
                # Only pixels of the evaluated class keep their heading; all
                # others become class 0, i.e. a zero vector.
                heading_selected = np.zeros_like(heading_image, np.int8)
                heading_selected[heading_image == heading] = heading
                heading_image = heading_selected
                heading_image = heading_table[heading_image]
                speed_on_axis = np.expand_dims(test_ims[:, :, :, :, 1] / np.sqrt(2), axis=-1)
                test_ims = speed_on_axis * heading_image

                # mavg filtered results
                # Moving-average baseline computed from the input frames, kept
                # only where the future frames' heading class equals `heading`.
                # NOTE(review): assumes cast_moving_avg returns one frame per
                # predicted step, in tem_data's channel layout (channel 1 is
                # read as speed, channel 2 as heading below) -- confirm.
                mavg_results_all = cast_moving_avg(tem_data[:, :FLAGS.input_length, ...])
                mavg_results = np.zeros_like(mavg_results_all)
                # heading_image = np.expand_dims(heading_image, axis=-1)
                mavg_results[cvt_heading[:, FLAGS.input_length:, ...] == heading] = \
                    mavg_results_all[cvt_heading[:, FLAGS.input_length:, ...] == heading]
                move_avg.append(mavg_results)

                test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size_width, FLAGS.patch_size_height)
                img_gen = model.test(test_dat, mask_true, batch_size)
                # concat outputs of different gpus along batch
                img_gen = np.concatenate(img_gen)
                # reshape the prediction has ndims=5
                img_gen = np.reshape(img_gen, (img_gen.shape[0], FLAGS.seq_length - FLAGS.input_length,
                                               FLAGS.img_height, FLAGS.img_width, -1))
                img_gen = preprocess.reshape_patch_back(img_gen, FLAGS.patch_size_width, FLAGS.patch_size_height)
                # print("Image Generates Shape is ", img_gen.shape)
                # MSE per frame

                img_gen_list = []
                img_gen_origin_list = []
                for i in range(FLAGS.seq_length - FLAGS.input_length):
                    x = tem_data[:,i + FLAGS.input_length,:,:, 1:]
                    gx = img_gen[:,i,:, :, :]

                    # print("img_gen shape is ", gx.shape)
                    # Decode the predicted vector back into a magnitude and a
                    # heading code (heading_dict value / 255) picked by the
                    # sign quadrant. A zero component leaves heading 0.
                    val_results_speed = np.sqrt(gx[..., 0] ** 2 + gx[..., 1] ** 2)
                    # print("val speed: ", val_results_speed, flush=True)
                    val_results_heading = np.zeros_like(gx[..., 1])
                    val_results_heading[(gx[..., 0] > 0) & (gx[..., 1] > 0)] = 85.0 / 255.0
                    val_results_heading[(gx[..., 0] > 0) & (gx[..., 1] < 0)] = 255.0 / 255.0
                    val_results_heading[(gx[..., 0] < 0) & (gx[..., 1] < 0)] = 170.0 / 255.0
                    val_results_heading[(gx[..., 0] < 0) & (gx[..., 1] > 0)] = 1.0 / 255.0

                    # Untransformed (speed, heading) prediction. np.stack copies,
                    # so the in-place edit of val_results_heading below does not
                    # affect it.
                    gen_speed_heading = np.stack([val_results_speed, val_results_heading], axis=-1)
                    img_gen_origin_list.append(gen_speed_heading)

                    # Transformation according to moving average direction when mavg speed is small
                    val_results_heading[mavg_results[:, i, :, :, 1] < epsilon] = \
                        mavg_results[:, i, :, :, 2][mavg_results[:, i, :, :, 1] < epsilon]
                    gx = np.stack([val_results_speed, val_results_heading], axis=-1)
                    img_gen_list.append(gx)

                    # MAE on unclipped values; MSE after clipping to [0, 1].
                    fmae[i] += metrics.batch_mae_frame_float(gx, x)
                    gx = np.maximum(gx, 0)
                    gx = np.minimum(gx, 1)
                    mse = np.square(x - gx).sum()
                    img_mse[i] += mse
                    avg_mse += mse

                # Stack per-frame results along axis 1 (time).
                img_gen_list = np.stack(img_gen_list, axis=1)
                img_gen_origin_list = np.stack(img_gen_origin_list, axis=1)
                pred_list_all.append(img_gen_origin_list)
                pred_list.append(img_gen_list)
                pred_vec.append(img_gen)
                test_input_handle.next()

            # Normalize the summed squared error by the element count.
            avg_mse = avg_mse / (batch_id*batch_size*FLAGS.img_height *
                                 FLAGS.img_width * FLAGS.patch_size_height *
                                 FLAGS.patch_size_width * FLAGS.img_channel * len(img_mse))
            print('mse per seq: ' + str(avg_mse), flush=True)
            for i in range(FLAGS.seq_length - FLAGS.input_length):
                print(img_mse[i] / (batch_id*batch_size*FLAGS.img_height *
                                 FLAGS.img_width * FLAGS.patch_size_height *
                                 FLAGS.patch_size_width * FLAGS.img_channel))

            gt_list_all = np.stack(gt_list, axis=0)
            # GT filtered to the direction required
            # (pixels whose heading channel x255 equals heading_dict[heading];
            # all others are zeroed).
            gt_list = np.zeros_like(gt_list_all)
            gt_list[gt_list_all[..., 1]*255 == heading_dict[heading]] = \
                gt_list_all[gt_list_all[..., 1]*255 == heading_dict[heading]]

            pred_list = np.stack(pred_list, axis=0)
            pred_list_all = np.stack(pred_list_all, axis=0)

            print("Evaluate on every pixels....")
            mse = masked_mse_np(pred_list, gt_list, null_val=np.nan)
            speed_mse = masked_mse_np(pred_list[..., 0], gt_list[..., 0], null_val=np.nan)
            direction_mse = masked_mse_np(pred_list[..., 1], gt_list[..., 1], null_val=np.nan)
            print("The output mse is ", mse)
            print("The speed mse is ", speed_mse)
            print("The direction mse is ", direction_mse)

            # null_val=0.0 from here on; presumably masked_mse_np then skips
            # ground-truth entries equal to 0 -- confirm in its definition.
            print("Evaluate on valid pixels for Transformation...")
            mse = masked_mse_np(pred_list, gt_list, null_val=0.0)
            speed_mse = masked_mse_np(pred_list[..., 0], gt_list[..., 0], null_val=0.0)
            direction_mse = masked_mse_np(pred_list[..., 1], gt_list[..., 1], null_val=0.0)
            print("The output mse is ", mse)
            print("The speed mse is ", speed_mse)
            print("The direction mse is ", direction_mse)

            print("Evaluate on valid pixels for No Transformation...")
            mse = masked_mse_np(pred_list_all, gt_list, null_val=0.0)
            speed_mse = masked_mse_np(pred_list_all[..., 0], gt_list[..., 0], null_val=0.0)
            direction_mse = masked_mse_np(pred_list_all[..., 1], gt_list[..., 1], null_val=0.0)
            print("The output mse is ", mse)
            print("The speed mse is ", speed_mse)
            print("The direction mse is ", direction_mse)

            print("Evaluate on valid pixels for MAVG...")
            # Evaluate on large gt speeds for direction
            move_avg = np.stack(move_avg, axis=0)
            mse = masked_mse_np(move_avg[..., 1:], gt_list, null_val=0.0)
            speed_mse = masked_mse_np(move_avg[..., 1], gt_list[..., 0], null_val=0.0)
            direction_mse = masked_mse_np(move_avg[..., 2], gt_list[..., 1], null_val=0.0)
            print("The output mse is ", mse)
            print("The speed mse is ", speed_mse)
            print("The direction mse is ", direction_mse)

            # NOTE(review): despite its name, large_gt_speed thresholds the
            # moving-average speed (move_avg channel 1), not the ground truth.
            # The next line also overwrites move_avg's heading channel in place.
            large_gt_speed = move_avg[..., 1] >= epsilon
            move_avg[..., 2][large_gt_speed] = pred_list_all[large_gt_speed, 1]
            direction_mse = masked_mse_np(move_avg[..., 2], gt_list[..., 1], null_val=0.0)
            print(f"The direction of combined mavg and large speed~({epsilon}) prediction is ", direction_mse)

            direction_mse = masked_mse_np(pred_list_all[large_gt_speed, 1], gt_list[large_gt_speed, 1], null_val=0.0)
            print("The direction mse on large speed gt is ", direction_mse)


        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)
# Пример #9 (Example #9)
# 0
def _mim_train_step(model, optimizer, batch, real_input_flag, states, args,
                    device):
    """Run one forward/backward/optimizer step of ``model`` on ``batch``.

    Args:
        model: the MIM network; invoked as ``model.forward(...)``.
        optimizer: optimizer over ``model.parameters()``.
        batch: patched numpy batch (output of ``preprocess.reshape_patch``);
            frames ``[:, 1:]`` are used as the regression target.
        real_input_flag: scheduled-sampling mask from ``schedule_sampling``.
        states: recurrent states and peephole weights in the order expected
            by ``detachVariables``: (hidden_state, cell_state,
            hidden_state_diff, cell_state_diff, st_memory, conv_lstm_c,
            MIMB_ct_weight, MIMB_oc_weight, MIMN_ct_weight, MIMN_oc_weight).
        args: parsed arguments; ``args.img_channel`` and the optional
            ``args.dof_loss_weight`` are read.
        device: torch device the input tensors are created on.

    Returns:
        The detached loss of this step.
    """
    (hidden_state, cell_state, hidden_state_diff, cell_state_diff,
     st_memory, conv_lstm_c, mimb_ct_weight, mimb_oc_weight,
     mimn_ct_weight, mimn_oc_weight) = states

    # torch.tensor rejects numpy arrays with negative strides, which the
    # flipped / time-reversed views have, so take a contiguous copy first.
    batch = batch.copy()
    ims_tensor = torch.tensor(batch, device=device)
    gt_ims = torch.tensor(batch[:, 1:], device=device)

    # forward() takes the peephole weights in (oc, ct) order, whereas
    # detachVariables/saveVariables take them in (ct, oc) order.
    gen_images = model.forward(ims_tensor, real_input_flag, hidden_state,
                               cell_state, hidden_state_diff,
                               cell_state_diff, st_memory, conv_lstm_c,
                               mimb_oc_weight, mimb_ct_weight,
                               mimn_oc_weight, mimn_ct_weight)

    loss = F.mse_loss(gen_images, gt_ims)
    # Optional blend with the dense-optical-flow loss. It is only computed
    # when requested, because a weight of 0 makes it dead work.
    dof_weight = float(getattr(args, 'dof_loss_weight', 0.0))
    if dof_weight:
        gen_diff, gt_diff = DOFLoss.dense_optical_flow_loss(
            gen_images, gt_ims, args.img_channel)
        # Use the resolved device so the CPU fallback in main() is honoured.
        optical_loss = DOFLoss.calc_optical_flow_loss(gen_diff, gt_diff,
                                                      device)
        loss = (1.0 - dof_weight) * loss + dof_weight * optical_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Detach the carried states so the next step does not reuse (and
    # backpropagate into) this step's computation graph.
    del gen_images
    detachVariables(*states)
    return loss.detach()


def main():
    """Train MIM with optional horizontal-flip and time-reversal passes.

    Each iteration draws one batch and performs one optimizer step on it.
    It then performs one extra step for each enabled augmentation:

    * ``args.reverse_img``: horizontally flipped frames;
    * ``args.reverse_input``: time-reversed frames, plus time-reversed
      flipped frames when ``args.reverse_img`` is also set.

    The recurrent states and peephole weights are created here, reused
    across iterations, saved with every snapshot and passed to
    ``trainer.test``. Setting ``args.dof_loss_weight`` (optional, default 0)
    blends the dense-optical-flow loss into the MSE loss.
    """
    args = parse_args()

    # Fall back to the CPU when CUDA is unavailable.
    if torch.cuda.is_available():
        device = torch.device(args.device)
    else:
        device = torch.device("cpu")

    model = MIM(args).to(device)
    print(model)
    print('The model is loaded!\n')

    # Load the datasets (original note: frames come out as n 64 64 1).
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        args.dataset_name,
        args.train_data_paths,
        args.valid_data_paths,
        args.batch_size * args.n_gpu,
        args.img_width,
        seq_length=args.total_length,
        is_training=True)

    cell_state = [init_state(args) for _ in range(4)]
    hidden_state = [init_state(args) for _ in range(4)]
    cell_state_diff = [init_state(args) for _ in range(3)]
    hidden_state_diff = [init_state(args) for _ in range(3)]
    st_memory = init_state(args)
    conv_lstm_c = init_state(args)

    MIMB_ct_weight = nn.Parameter(
        torch.randn((args.num_hidden[0] * 2, args.img_height, args.img_width),
                    device=device))
    MIMB_oc_weight = nn.Parameter(
        torch.randn((args.num_hidden[0], args.img_height, args.img_width),
                    device=device))
    MIMN_ct_weight = nn.Parameter(
        torch.randn((args.num_hidden[0] * 2, args.img_height, args.img_width),
                    device=device))
    MIMN_oc_weight = nn.Parameter(
        torch.randn((args.num_hidden[0], args.img_height, args.img_width),
                    device=device))

    if args.pretrained_model:
        (hidden_state, cell_state, hidden_state_diff, cell_state_diff,
         st_memory, conv_lstm_c, MIMB_ct_weight, MIMB_oc_weight,
         MIMN_ct_weight, MIMN_oc_weight) = loadVariables(args)
        model.load(args.pretrained_model)

    # NOTE(review): the four peephole weights are nn.Parameters that live
    # outside `model`. They are therefore not in the optimizer below and are
    # never updated by optimizer.step(). Confirm that this is intended.
    recurrent_states = (hidden_state, cell_state, hidden_state_diff,
                        cell_state_diff, st_memory, conv_lstm_c,
                        MIMB_ct_weight, MIMB_oc_weight, MIMN_ct_weight,
                        MIMN_oc_weight)

    eta = args.sampling_start_value

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    for itr in range(1, args.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)

        raw_ims = train_input_handle.get_batch()
        ims = preprocess.reshape_patch(raw_ims, args.patch_size)
        ims_reverse = None
        if args.reverse_img:
            # Mirror the raw frames *before* patching. Flipping the patched
            # array would only reorder the patches, not mirror the image.
            ims_reverse = preprocess.reshape_patch(raw_ims[:, :, :, ::-1],
                                                   args.patch_size)
        eta, real_input_flag = schedule_sampling(eta, itr, args)

        # Training passes, in order: original, flipped, time-reversed,
        # time-reversed + flipped.
        batches = [ims]
        if args.reverse_img:
            batches.append(ims_reverse)
        if args.reverse_input:
            batches.append(ims[:, ::-1])
            if args.reverse_img:
                batches.append(ims_reverse[:, ::-1])

        losses = []
        for batch in batches:
            losses.append(
                _mim_train_step(model, optimizer, batch, real_input_flag,
                                recurrent_states, args, device))

        # Average the loss over the number of passes made this iteration.
        loss_print = sum(losses).item() / len(losses)

        if itr % args.snapshot_interval == 0:
            # Everything tracked by detachVariables must be saved with the
            # model so that training can be resumed.
            saveVariables(args, hidden_state, cell_state, hidden_state_diff,
                          cell_state_diff, st_memory, conv_lstm_c,
                          MIMB_ct_weight, MIMB_oc_weight, MIMN_ct_weight,
                          MIMN_oc_weight, itr)
            model.save(itr)

        if itr % args.test_interval == 0:
            trainer.test(model, test_input_handle, args, itr, hidden_state,
                         cell_state, hidden_state_diff, cell_state_diff,
                         st_memory, conv_lstm_c, MIMB_oc_weight,
                         MIMB_ct_weight, MIMN_oc_weight, MIMN_ct_weight)

        if itr % args.display_interval == 0:
            print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                  'itr: ' + str(itr))
            print('training loss: ' + str(loss_print))

        train_input_handle.next()
Пример #10
0
def main(argv=None):
    """Train ``Model``, periodically evaluating, logging and checkpointing it.

    Clears and recreates ``FLAGS.save_dir`` and ``FLAGS.gen_frm_dir``, then
    runs ``FLAGS.max_iterations`` training steps. When ``FLAGS.reverse_input``
    is set, each step also trains on the time-reversed batch and reports the
    mean of the two costs. Every ``FLAGS.test_interval`` steps the model's
    output is compared with the last ground-truth frame of each test
    sequence. MSE / PSNR / frame MAE / SSIM / sharpness are then printed and
    logged. Every ``FLAGS.snapshot_interval`` steps the model is saved.

    Args:
        argv: unused.
    """
    for directory in (FLAGS.save_dir, FLAGS.gen_frm_dir):
        if tf.gfile.Exists(directory):
            tf.gfile.DeleteRecursively(directory)
        tf.gfile.MakeDirs(directory)

    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths,
        FLAGS.batch_size, FLAGS.img_width, FLAGS.seq_length)

    print("Initializing models")
    model = Model()
    lr = FLAGS.lr
    n_frames = FLAGS.seq_length - FLAGS.input_length

    # Prepare tensorboard logging
    logger = Logger(os.path.join(FLAGS.gen_frm_dir, 'board'), model.sess)
    for scalar_name in ("loss", "lr", "mse", "psnr", "fmae", "ssim", "sharp"):
        logger.define_item(scalar_name, Logger.Scalar, ())
    # Preview image: ground truth stacked above prediction.
    logger.define_item(
        "image",
        Logger.Image,
        (1, 2 * FLAGS.img_width, FLAGS.img_width, FLAGS.img_channel),
        dtype='uint8')

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        ims = train_input_handle.get_batch()
        ims = preprocess.reshape_patch(ims, FLAGS.patch_size)

        logger.add('lr', lr, itr)
        cost = model.train(ims, lr)
        if FLAGS.reverse_input:
            # Same call signature as the forward pass above. The former third
            # argument `mask_true` was undefined and raised NameError.
            ims_rev = ims[:, ::-1]
            cost += model.train(ims_rev, lr)
            cost = cost / 2
        logger.add('loss', cost, itr)

        if itr % FLAGS.display_interval == 0:
            print('itr: ' + str(itr))
            print('training loss: ' + str(cost))

        if itr % FLAGS.test_interval == 0:
            print('test...')
            test_input_handle.begin(do_shuffle=False)
            res_path = os.path.join(FLAGS.gen_frm_dir, str(itr))
            os.mkdir(res_path)
            avg_mse = 0
            batch_id = 0
            img_mse = [0] * n_frames
            ssim = [0] * n_frames
            psnr = [0] * n_frames
            fmae = [0] * n_frames
            sharp = [0] * n_frames
            while not test_input_handle.no_batch_left():
                batch_id = batch_id + 1
                test_ims = test_input_handle.get_batch()
                test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)
                img_gen = model.test(test_dat)

                img_gen = preprocess.reshape_patch_back(
                    img_gen[:, np.newaxis, :, :, :], FLAGS.patch_size)

                # The output is compared with the last ground-truth frame
                # only, so only slot 0 of the per-frame accumulators is filled.
                # NOTE(review): the accumulators have n_frames slots, so the
                # np.mean(...) summaries below are diluted by the unused zero
                # slots whenever n_frames > 1. Confirm the configuration.
                i = 0
                x = test_ims[:, -1, :, :, 0]
                gx = img_gen[:, :, :, 0]
                fmae[i] += metrics.batch_mae_frame_float(gx, x)
                gx = np.maximum(gx, 0)
                gx = np.minimum(gx, 1)
                mse = np.square(x - gx).sum()
                img_mse[i] += mse
                avg_mse += mse

                real_frm = np.uint8(x * 255)
                pred_frm = np.uint8(gx * 255)
                psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
                for b in range(FLAGS.batch_size):
                    sharp[i] += np.max(
                        cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
                    score, _ = compare_ssim(pred_frm[b],
                                            real_frm[b],
                                            full=True)
                    ssim[i] += score

                # save prediction examples
                if batch_id == 1:
                    sel = np.random.randint(FLAGS.batch_size)
                    img_seq_pd = img_gen[sel]
                    img_seq_gt = test_ims[sel, -1]
                    h, w = img_gen.shape[1:3]
                    out_img = np.zeros((1, h * 2, w * 1, FLAGS.img_channel),
                                       dtype='uint8')
                    for row, img_seq in enumerate([img_seq_gt, img_seq_pd]):
                        # Clip to [0, 255] *before* the uint8 cast. Clamping
                        # after the cast is a no-op because values above 255
                        # would already have wrapped around.
                        img = np.uint8(np.clip(img_seq * 10, 0, 255))
                        out_img[0, (row * h):(row * h + h), :] = img
                    logger.add("image", out_img, itr)

                test_input_handle.next()
            avg_mse = avg_mse / (batch_id * FLAGS.batch_size)
            logger.add('mse', avg_mse, itr)
            print('mse per seq: ' + str(avg_mse))
            for i in range(n_frames):
                print(img_mse[i] / (batch_id * FLAGS.batch_size))
            psnr = np.asarray(psnr, dtype=np.float32) / batch_id
            fmae = np.asarray(fmae, dtype=np.float32) / batch_id
            ssim = np.asarray(ssim, dtype=np.float32) / \
                (FLAGS.batch_size * batch_id)
            sharp = np.asarray(sharp, dtype=np.float32) / \
                (FLAGS.batch_size * batch_id)
            print('psnr per frame: ' + str(np.mean(psnr)))
            logger.add('psnr', np.mean(psnr), itr)
            for i in range(n_frames):
                print(psnr[i])
            print('fmae per frame: ' + str(np.mean(fmae)))
            logger.add('fmae', np.mean(fmae), itr)
            for i in range(n_frames):
                print(fmae[i])
            print('ssim per frame: ' + str(np.mean(ssim)))
            logger.add('ssim', np.mean(ssim), itr)
            for i in range(n_frames):
                print(ssim[i])
            print('sharpness per frame: ' + str(np.mean(sharp)))
            logger.add('sharp', np.mean(sharp), itr)
            for i in range(n_frames):
                print(sharp[i])

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)

        train_input_handle.next()
Пример #11
0
# Parse the command line, then pin the settings used by this Moving-MNIST
# PredRNN++ example run.
args = parser.parse_args()
vars(args).update(
    dataset_name="mnist",
    train_data_paths='data/moving-mnist-example/moving-mnist-train.npz',
    valid_data_paths='data/moving-mnist-example/moving-mnist-valid.npz',
    save_dir='checkpoints/mnist_predrnn_pp',
    img_width=64,
    batch_size=8,
    patch_size=4,  # was 1 at some point
    seq_length=19,
    num_hidden=[128, 64, 64, 64, 16],
    lr=1e-5,
)
# One layer per entry of num_hidden.
args.num_layers = len(args.num_hidden)

# Train / validation input handles.
train_input_handle, test_input_handle = datasets_factory.data_provider(
    args.dataset_name,
    args.train_data_paths,
    args.valid_data_paths,
    args.batch_size,
    args.img_width,
)

# Recurrent stack. Per the original note, the two leading positional
# arguments are presumably (filter_size, num_dims). TODO confirm.
model = CausalLSTMStack(3, 2, args.num_hidden)
# 1x1 convolution, 16 -> 16 channels (presumably matching num_hidden[-1]).
decoder = torch.nn.Conv2d(16, 16, 1, 1)

# NOTE(review): only `model` is moved to the GPU here; `decoder` stays on
# the CPU.
model.cuda()
Пример #12
0
def main(argv=None):
    """Train the multi-GPU model, periodically evaluating and checkpointing.

    Each training and test batch is split into ``FLAGS.n_gpu`` equal parts
    with ``np.split``. When ``FLAGS.reverse_input`` is set, each step also
    trains on the time-reversed batch and reports the mean of the two costs.
    Every ``FLAGS.test_interval`` iterations, the per-frame MSE and MAE over
    the predicted frames of the test set are printed, together with timing
    information. Both are scaled by ``joint_num * joint_dim``. Every
    ``FLAGS.snapshot_interval`` iterations the model is saved.

    Args:
        argv: unused.
    """
    # Logical `not`, not bitwise `~`: ~True == -2 and ~False == -1 are both
    # truthy, so the previous test could never be false.
    if not tf.gfile.Exists(FLAGS.save_dir):
        tf.gfile.MakeDirs(FLAGS.save_dir)

    print('start training !', time.strftime('%Y-%m-%d %H:%M:%S\n\n\n', time.localtime(time.time())))
    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.train_data_paths, FLAGS.valid_data_paths,
        FLAGS.batch_size * FLAGS.n_gpu,
        FLAGS.joint_dim,
        FLAGS.joint_num,
        FLAGS.seq_length,
        is_training=True)

    print('Initializing models')
    model = Model()
    lr = FLAGS.lr
    train_time = 0
    test_time_all = 0
    n_pred = FLAGS.seq_length - FLAGS.input_length

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        start_time = time.time()
        ims = train_input_handle.get_batch()
        ims_list = np.split(ims, FLAGS.n_gpu)
        cost = model.train(ims_list, lr)

        if FLAGS.reverse_input:
            ims_rev = np.split(ims[:, ::-1], FLAGS.n_gpu)
            cost += model.train(ims_rev, lr)
            cost = cost / 2
        train_time += time.time() - start_time

        if itr % FLAGS.display_interval == 0:
            print('itr: ' + str(itr) + ' lr: ' + str(lr) + ' training loss: ' + str(cost))

        if itr % FLAGS.test_interval == 0:
            test_time = 0

            print('train time:' + str(train_time))
            print('test...')
            test_input_handle.begin(do_shuffle=False)

            mse_per_frame = []
            mae_per_frame = []
            while not test_input_handle.no_batch_left():
                curr_test_time_start = time.time()
                test_ims = test_input_handle.get_batch()
                test_dat = np.split(test_ims, FLAGS.n_gpu)
                img_gen = model.test(test_dat)

                test_time += time.time() - curr_test_time_start

                # concat outputs of different gpus along batch
                img_gen = np.concatenate(img_gen)
                abs_err = np.squeeze(np.abs(test_ims[:, FLAGS.input_length:] - img_gen))
                # Put the frame axis first, then flatten everything else so
                # that each row holds the errors of one predicted frame.
                abs_err_t = np.transpose(abs_err, [1, 0, 2, 3]).reshape([n_pred, -1])
                mae = np.mean(abs_err_t, axis=-1, keepdims=True) * FLAGS.joint_num * FLAGS.joint_dim
                mse = np.mean(abs_err_t ** 2, axis=-1, keepdims=True) * FLAGS.joint_num * FLAGS.joint_dim
                mae_per_frame.append(mae)
                mse_per_frame.append(mse)

                test_input_handle.next()

            test_time_all += test_time
            print('current test time:' + str(test_time))
            print('all test time: ' + str(test_time_all))
            # (n_batches * n_pred, 1) -> (n_batches, n_pred) -> mean per frame.
            mae_per_frame = np.concatenate(mae_per_frame).reshape([-1, n_pred]).mean(axis=0)
            mse_per_frame = np.concatenate(mse_per_frame).reshape([-1, n_pred]).mean(axis=0)

            print('average mse per frame: ', mse_per_frame.mean())
            for value in mse_per_frame:
                print(value)
            print('average mae per frame: ', mae_per_frame.mean())
            for value in mae_per_frame:
                print(value)

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)
            print('model saving done! ', time.strftime('%Y-%m-%d %H:%M:%S\n\n\n', time.localtime(time.time())))

        train_input_handle.next()