def run_benchmark(hko_factory, context, encoder_net, forecaster_net,
                  loss_net=None, sample_num=1, finetune=False, mode="fixed",
                  save_dir="hko7_rnn", pd_path=cfg.HKO_PD.RAINY_TEST):
    """Evaluate an encoder/forecaster pair on the HKO7 benchmark.

    The benchmark environment is stepped until it reports ``done``. Each
    observation is pushed through ``encoder_net``; whenever the environment
    asks for a prediction, ``forecaster_net`` is run, its first output is
    clipped to [0, 1] and uploaded. When ``finetune`` is set, a sliding window
    of observations (and, unless TBPTT is disabled in the config, encoder
    states) is kept, and ``train_step`` is invoked whenever the balanced MSE of
    the oldest stored prediction exceeds
    ``cfg.MODEL.TEST.ONLINE.FINETUNE_MIN_MSE``.

    Parameters
    ----------
    hko_factory :
        Factory handed to ``EncoderForecasterStates``.
    context : mx.ctx
        Device context used for every NDArray created here.
    encoder_net : MyModule
    forecaster_net : MyModule
    loss_net : MyModule
        Only forwarded to ``train_step`` during finetuning.
    sample_num : int
        Only reported in the start-up log message.
    finetune : bool
        Enable online finetuning while evaluating.
    mode : str
        Passed to ``HKOBenchmarkEnv``; ``"fixed"`` additionally resets the
        recurrent states before every observation.
    save_dir : str
        Passed to ``HKOBenchmarkEnv``.
    pd_path : str
        Passed to ``HKOBenchmarkEnv``.

    Returns
    -------
    None
        Results are persisted through ``env.save_eval()``.
    """
    logging.info(
        "Begin Evaluation, mode=%s, finetune=%s, sample_num=%d, results will be saved to %s"
        % (mode, str(finetune), sample_num, save_dir))
    if finetune:
        logging.info(str(cfg.MODEL.TEST.ONLINE))
    env = HKOBenchmarkEnv(pd_path=pd_path, save_dir=save_dir, mode=mode)
    states = EncoderForecasterStates(factory=hko_factory, ctx=context)

    def tbptt_enabled():
        # Re-read the config on every call, exactly like the inline checks did.
        return not cfg.MODEL.TEST.DISABLE_TBPTT

    def snapshot_encoder_states():
        # Host-side copies of the current encoder states.
        return [ele.asnumpy() for ele in states.get_encoder_states()]

    frame_history = []
    if tbptt_enabled():
        state_history = []
    prediction_history = []
    num_uploads = 0
    finetune_iter = 0

    while not env.done:
        # ---- optional online finetuning on the oldest buffered window ----
        if finetune and len(frame_history) >= 5:
            window_in = frame_history[0]
            window_gt = np.concatenate(frame_history[1:], axis=0)
            window_mask = precompute_mask(window_gt)
            start_states = EncoderForecasterStates(factory=hko_factory,
                                                   ctx=context)
            if tbptt_enabled():
                start_states.update(states_nd=[
                    nd.array(ele, ctx=context) for ele in state_history[0]
                ])
            balancing = get_balancing_weights_numba(
                data=window_gt,
                mask=window_mask,
                base_balancing_weights=cfg.HKO.EVALUATION.BALANCING_WEIGHTS,
                thresholds=env._all_eval._thresholds)
            squared_error = np.square(prediction_history[0] - window_gt)
            mean_weighted_mse = \
                (balancing * squared_error).sum(axis=(2, 3, 4)).mean()
            print("mean_weighted_mse = %g" % mean_weighted_mse)
            if mean_weighted_mse > cfg.MODEL.TEST.ONLINE.FINETUNE_MIN_MSE:
                _, step_losses = train_step(
                    batch_size=1,
                    encoder_net=encoder_net,
                    forecaster_net=forecaster_net,
                    loss_net=loss_net,
                    init_states=start_states,
                    data_nd=nd.array(window_in, ctx=context),
                    gt_nd=nd.array(window_gt, ctx=context),
                    mask_nd=nd.array(window_mask, ctx=context),
                    iter_id=finetune_iter)
                finetune_iter += 1
            # Slide every buffer forward by one entry.
            del frame_history[:1]
            del prediction_history[:1]
            if tbptt_enabled():
                del state_history[:1]

        if mode == "fixed" or not tbptt_enabled():
            states.reset_all()

        (in_frame_dat, _in_clips, _out_clips,
         begin_new_episode, need_upload_prediction) = \
            env.get_observation(batch_size=1)

        # ---- bookkeeping for finetuning ----
        if finetune and begin_new_episode:
            frame_history = [in_frame_dat]
            prediction_history = []
            if tbptt_enabled():
                state_history = [snapshot_encoder_states()]
        elif finetune:
            frame_history.append(in_frame_dat)
            if tbptt_enabled():
                state_history.append(snapshot_encoder_states())

        # ---- encoder pass ----
        in_frame_nd = nd.array(in_frame_dat, ctx=context)
        encoder_inputs = [in_frame_nd] + states.get_encoder_states()
        encoder_net.forward(is_train=False,
                            data_batch=mx.io.DataBatch(data=encoder_inputs))
        states.update(states_nd=encoder_net.get_outputs())

        if not need_upload_prediction:
            continue

        # ---- forecaster pass + upload ----
        num_uploads += 1
        direct_output = cfg.MODEL.OUT_TYPE == "direct"
        forecaster_inputs = states.get_forecaster_state()
        if not direct_output:
            # Non-direct models additionally receive the last input frame.
            forecaster_inputs = forecaster_inputs + \
                [in_frame_nd[in_frame_nd.shape[0] - 1]]
        forecaster_net.forward(
            is_train=False,
            data_batch=mx.io.DataBatch(data=forecaster_inputs))
        pred_nd = forecaster_net.get_outputs()[0]
        if not direct_output:
            # Second output is fetched but not used any further here.
            flow_nd = forecaster_net.get_outputs()[1]
        pred_nd = nd.clip(pred_nd, a_min=0, a_max=1)
        env.upload_prediction(prediction=pred_nd.asnumpy())
        if finetune:
            prediction_history.append(pred_nd.asnumpy())

    env.save_eval()
def hko_benchmark(ctx, generator_net, loss_net, sample_num, finetune=False,
                  mode="fixed", save_dir="hko7_rnn",
                  pd_path=cfg.HKO_PD.RAINY_TEST):
    """Run the HKO7 benchmark for a single generator network.

    The benchmark environment is stepped until it reports ``done``. Every
    observation is forwarded through ``generator_net``; when the environment
    requests a prediction, the output named ``"pred_output"`` is cropped to
    the real batch size, clipped to [0, 1] and uploaded. With ``finetune``
    enabled, a sliding window of observations and predictions is kept and
    ``train_step`` is called whenever the balanced MSE of the oldest stored
    prediction exceeds ``cfg.MODEL.TEST.ONLINE.FINETUNE_MIN_MSE``.

    Parameters
    ----------
    ctx :
        Device context used for every NDArray created here.
    generator_net :
        Module providing ``forward``, ``get_outputs`` and ``output_names``;
        must expose an output called ``"pred_output"``.
    loss_net :
        Only forwarded to ``train_step`` during finetuning.
    sample_num : int
        Only reported in the start-up log message.
    finetune : bool
        Enable online finetuning; requires ``mode == "online"``.
    mode : str
        ``"online"`` (one sequence per step, repeated to fill the training
        batch size) or ``"fixed"`` (a full batch per step, padded if short).
    save_dir : str
        Passed to ``HKOBenchmarkEnv``.
    pd_path : str
        Passed to ``HKOBenchmarkEnv``.

    Raises
    ------
    AssertionError
        If ``finetune`` is set while ``mode`` is not ``"online"``.
    NotImplementedError
        If ``mode`` is neither ``"online"`` nor ``"fixed"``.
    """
    logging.info("Begin Evaluation, sample_num=%d,"
                 " results will be saved to %s" % (sample_num, save_dir))
    if finetune:
        logging.info(str(cfg.MODEL.TEST.ONLINE))
    env = HKOBenchmarkEnv(pd_path=pd_path, save_dir=save_dir, mode=mode)
    if finetune:
        assert (mode == "online")

    data_buffer = []          # observations of the current episode
    stored_prediction = []    # predictions matching data_buffer entries
    context_nd = None
    i = 0
    while not env.done:
        logging.info("Iter {} of evaluation.".format(i))
        i += 1

        # ---- optional online finetuning on the oldest buffered window ----
        if finetune and len(data_buffer) >= 5:
            # Oldest observation is the context (presumably
            # HKO.BENCHMARK.IN_LEN frames -- confirm against the env).
            context_np = data_buffer[0]
            gt_np = np.concatenate(data_buffer[1:], axis=0)
            gt_np = gt_np[:cfg.HKO.BENCHMARK.OUT_LEN]
            mask_np = precompute_mask(gt_np)
            weights = get_balancing_weights_numba(
                data=gt_np,
                mask=mask_np,
                base_balancing_weights=cfg.HKO.EVALUATION.BALANCING_WEIGHTS,
                thresholds=env._all_eval._thresholds)
            weighted_mse = (weights *
                            np.square(stored_prediction[0] - gt_np)).sum(
                                axis=(2, 3, 4))
            mean_weighted_mse = weighted_mse.mean()
            print("mean_weighted_mse = %g" % mean_weighted_mse)
            if mean_weighted_mse > cfg.MODEL.TEST.ONLINE.FINETUNE_MIN_MSE:
                # Move the batch axis (axis 1 of the numpy data) to the front,
                # the layout the networks are fed with below.
                context_nd = mx.nd.array(context_np, ctx=ctx)
                context_nd = mx.nd.transpose(context_nd,
                                             axes=(1, 2, 0, 3, 4))
                gt_nd = mx.nd.array(gt_np, ctx=ctx)
                gt_nd = mx.nd.transpose(gt_nd, axes=(1, 2, 0, 3, 4))
                mask_nd = mx.nd.array(mask_np, ctx=ctx)
                mask_nd = mx.nd.transpose(mask_nd, axes=(1, 2, 0, 3, 4))
                train_step(generator_net=generator_net,
                           loss_net=loss_net,
                           context_nd=context_nd,
                           gt_nd=gt_nd,
                           mask_nd=mask_nd)
            # Slide both buffers forward by one entry.
            del data_buffer[0]
            del stored_prediction[0]

        # ---- fetch the next observation ----
        if mode == "online":
            context_np, in_datetime_clips, out_datetime_clips,\
                begin_new_episode, need_upload_prediction = \
                env.get_observation(batch_size=1)
            # Replicate the single sequence so the network sees a full batch.
            context_np = np.repeat(context_np,
                                   cfg.MODEL.TRAIN.BATCH_SIZE,
                                   axis=1)
            orig_size = 1
        elif mode == "fixed":
            context_np, in_datetime_clips, out_datetime_clips,\
                begin_new_episode, need_upload_prediction = \
                env.get_observation(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE)
            context_nd = mx.nd.array(context_np, ctx=ctx)
            context_nd = mx.nd.transpose(context_nd, axes=(1, 2, 0, 3, 4))
            # Pad context_nd up to batch size if needed
            orig_size = context_nd.shape[0]
            while context_nd.shape[0] < cfg.MODEL.TRAIN.BATCH_SIZE:
                context_nd = mx.nd.concat(context_nd,
                                          context_nd[0:1],
                                          num_args=2,
                                          dim=0)
        else:
            raise NotImplementedError

        # ---- bookkeeping for finetuning ----
        if finetune:
            if begin_new_episode:
                data_buffer = [context_np]
                # Drop predictions of the previous episode so they are never
                # scored against ground truth from the new one.
                stored_prediction = []
            else:
                data_buffer.append(context_np)

        # ---- generator pass ----
        if mode != "fixed":
            context_nd = mx.nd.array(context_np, ctx=ctx)
            context_nd = mx.nd.transpose(context_nd, axes=(1, 2, 0, 3, 4))
        generator_net.forward(
            is_train=False,
            data_batch=mx.io.DataBatch(data=[context_nd]))

        if need_upload_prediction:
            generator_outputs = dict(
                zip(generator_net.output_names,
                    generator_net.get_outputs()))
            pred_nd = generator_outputs["pred_output"]
            # Discard the entries that only exist because of batch padding.
            pred_nd = pred_nd[0:orig_size]
            pred_nd = mx.nd.clip(pred_nd, a_min=0, a_max=1)
            # Back to the axis order of the numpy observations.
            pred_nd = mx.nd.transpose(pred_nd, axes=(2, 0, 1, 3, 4))
            env.upload_prediction(prediction=pred_nd.asnumpy())
            if finetune:
                stored_prediction.append(pred_nd.asnumpy())

    env.save_eval()
def run(pd_path=cfg.HKO_PD.RAINY_TEST, mode="fixed", interp_type="bilinear",
        nonlinear_transform=True):
    """Run the optical-flow (rover) baseline on the HKO7 benchmark.

    For every observation that needs a prediction, a flow field is computed
    from the last two input frames with ``VarFlowFactory`` and the last frame
    is advected ``cfg.HKO.BENCHMARK.OUT_LEN`` times along that flow. The
    result is uploaded to the benchmark environment; every tenth upload the
    input and prediction are also written as GIFs and the running statistics
    are printed.

    Parameters
    ----------
    pd_path : str
        Passed to ``HKOBenchmarkEnv``.
    mode : str
        Passed to ``HKOBenchmarkEnv``.
    interp_type : str
        Must be ``"bilinear"`` (enforced by an assertion).
    nonlinear_transform : bool
        Apply ``NonLinearRoverTransform`` before computing the flow and undo
        it on the prediction; also selects the output directory.
    """
    transformer = NonLinearRoverTransform()
    flow_factory = VarFlowFactory(max_level=6, start_level=0, n1=2, n2=2,
                                  rho=1.5, alpha=2000, sigma=4.5)
    assert interp_type == "bilinear", \
        "Nearest interpolation is implemented in CPU and is too slow. We only support bilinear interpolation for rover."
    base_dir = os.path.join(
        'hko7_benchmark',
        'rover-nonlinear' if nonlinear_transform else 'rover-linear')
    logging_config(base_dir)
    batch_size = 1
    env = HKOBenchmarkEnv(pd_path=pd_path, save_dir=base_dir, mode=mode)
    out_len_cfg = cfg.HKO.BENCHMARK

    def with_channel_axis(frame):
        # Insert a singleton axis at position 1 of a 3-D array.
        return frame.reshape(
            (frame.shape[0], 1, frame.shape[1], frame.shape[2]))

    uploads = 0
    while not env.done:
        (in_frame_dat, _in_clips, _out_clips,
         _new_episode, need_upload_prediction) = \
            env.get_observation(batch_size=batch_size)
        if not need_upload_prediction:
            continue

        uploads += 1
        prediction = np.zeros(
            shape=(out_len_cfg.OUT_LEN, ) + in_frame_dat.shape[1:],
            dtype=np.float32)

        # Last two input frames, masked before the flow is estimated.
        prev_frame = in_frame_dat[-2, :, 0, :, :]
        last_frame = in_frame_dat[-1, :, 0, :, :]
        prev_mask = precompute_mask(prev_frame)
        last_mask = precompute_mask(last_frame)
        prev_frame = prev_frame * prev_mask
        last_frame = last_frame * last_mask
        if nonlinear_transform:
            prev_frame = transformer.transform(prev_frame)
            last_frame = transformer.transform(last_frame)
        flow = flow_factory.batch_calc_flow(I1=prev_frame, I2=last_frame)

        if interp_type == "bilinear":
            current_im = nd.array(with_channel_axis(last_frame),
                                  ctx=mx.gpu())
            # Second flow component is negated before advection.
            signed_flow = np.concatenate(
                (flow[:, :1, :, :], -flow[:, 1:, :, :]), axis=1)
            nd_flow = nd.array(signed_flow, ctx=mx.gpu())
            nd_pred_im = nd.zeros(shape=prediction.shape)
            for step in range(out_len_cfg.OUT_LEN):
                advected = nd_advection(current_im, flow=nd_flow)
                nd_pred_im[step][:] = advected
                current_im[:] = advected
            prediction = nd_pred_im.asnumpy()
        elif interp_type == "nearest":
            current_im = with_channel_axis(last_frame)
            for step in range(out_len_cfg.OUT_LEN):
                advected = nearest_neighbor_advection(current_im, flow)
                prediction[step, ...] = advected
                current_im = advected

        if nonlinear_transform:
            prediction = transformer.rev_transform(prediction)
        env.upload_prediction(prediction=prediction)

        if uploads % 10 == 0:
            save_hko_gif(in_frame_dat[:, 0, 0, :, :],
                         save_path=os.path.join(base_dir, 'in.gif'))
            save_hko_gif(prediction[:, 0, 0, :, :],
                         save_path=os.path.join(base_dir, 'pred.gif'))
            env.print_stat_readable()
            # For debugging, the flow field can be inspected here, e.g. with
            # a matplotlib quiver plot of flow[1, 0] / flow[1, 1].
    env.save_eval()