def save_hko_gif(im_dat, save_path):
    """Write an HKO image sequence out as an animated gif.

    Parameters
    ----------
    im_dat : np.ndarray
        Image sequence with shape (seqlen, H, W).
    save_path : str
        Destination path of the gif file.
    """
    # Only rank-3 (seqlen, H, W) arrays are accepted.
    assert im_dat.ndim == 3
    save_gif(im_dat, fname=save_path)
def analysis(args):
    """Run one qualitative analysis pass of a trained TrajRNN model.

    Builds the encoder/forecaster networks, loads parameters from
    ``cfg.MODEL.LOAD_DIR`` (which must be non-empty), samples a single
    MovingMNIST batch, and saves prediction/input/ground-truth gifs
    ("pred.gif", "in.gif", "gt.gif") under ``args.save_dir``.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``save_dir`` and ``ctx`` (list of MXNet contexts).
    """
    # Enable dumping of intermediate TrajRNN results for inspection.
    cfg.MODEL.TRAJRNN.SAVE_MID_RESULTS = True
    assert cfg.MODEL.FRAME_STACK == 1 and cfg.MODEL.FRAME_SKIP == 1
    base_dir = args.save_dir
    logging_config(folder=base_dir, name="testing")
    save_movingmnist_cfg(base_dir)
    # On-the-fly MovingMNIST sampler configured from the global cfg.
    mnist_iter = MovingMNISTAdvancedIterator(
        distractor_num=cfg.MOVINGMNIST.DISTRACTOR_NUM,
        initial_velocity_range=(cfg.MOVINGMNIST.VELOCITY_LOWER,
                                cfg.MOVINGMNIST.VELOCITY_UPPER),
        rotation_angle_range=(cfg.MOVINGMNIST.ROTATION_LOWER,
                              cfg.MOVINGMNIST.ROTATION_UPPER),
        scale_variation_range=(cfg.MOVINGMNIST.SCALE_VARIATION_LOWER,
                               cfg.MOVINGMNIST.SCALE_VARIATION_UPPER),
        illumination_factor_range=(cfg.MOVINGMNIST.ILLUMINATION_LOWER,
                                   cfg.MOVINGMNIST.ILLUMINATION_UPPER))
    # NOTE(review): factory is built with batch_size=1 and cfg.MODEL.IN_LEN,
    # but the sample below uses cfg.MODEL.TRAIN.BATCH_SIZE and slices with
    # cfg.MOVINGMNIST.IN_LEN — confirm these config keys always agree.
    mnist_rnn = MovingMNISTFactory(batch_size=1,
                                   in_seq_len=cfg.MODEL.IN_LEN,
                                   out_seq_len=cfg.MODEL.OUT_LEN)
    encoder_net, forecaster_net, loss_net = \
        encoder_forecaster_build_networks(
            factory=mnist_rnn,
            context=args.ctx)
    encoder_net.summary()
    forecaster_net.summary()
    loss_net.summary()
    states = EncoderForecasterStates(factory=mnist_rnn, ctx=args.ctx[0])
    states.reset_all()
    # Begin to load the model if load_dir is not empty
    assert len(cfg.MODEL.LOAD_DIR) > 0
    load_encoder_forecaster_params(load_dir=cfg.MODEL.LOAD_DIR,
                                   load_iter=cfg.MODEL.LOAD_ITER,
                                   encoder_net=encoder_net,
                                   forecaster_net=forecaster_net)
    # Single pass (range(1)) — kept as a loop so more samples can be drawn
    # by raising the bound.
    for iter_id in range(1):
        frame_dat, _ = mnist_iter.sample(
            batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
            seqlen=cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN)
        # Pixel values are normalized from [0, 255] to [0, 1].
        data_nd = mx.nd.array(frame_dat[0:cfg.MOVINGMNIST.IN_LEN, ...],
                              ctx=args.ctx[0]) / 255.0
        target_nd = mx.nd.array(
            frame_dat[cfg.MOVINGMNIST.IN_LEN:(
                cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN), ...],
            ctx=args.ctx[0]) / 255.0
        pred_nd = mnist_get_prediction(data_nd=data_nd,
                                       states=states,
                                       encoder_net=encoder_net,
                                       forecaster_net=forecaster_net)
        # Save the first sequence of the batch (index [:, 0, 0, :, :]).
        save_gif(pred_nd.asnumpy()[:, 0, 0, :, :],
                 os.path.join(base_dir, "pred.gif"))
        save_gif(data_nd.asnumpy()[:, 0, 0, :, :],
                 os.path.join(base_dir, "in.gif"))
        save_gif(target_nd.asnumpy()[:, 0, 0, :, :],
                 os.path.join(base_dir, "gt.gif"))
def train(args):
    """Train the 3D-convolution generator on MovingMNIST sequences.

    Each iteration samples a fresh batch, reshapes it to the NCDHW layout
    expected by the 3D-conv encoder, takes one optimization step, logs the
    losses, and periodically dumps input/gt/pred gifs and a generator
    checkpoint under the run directory.
    """
    run_dir = get_base_dir(args)
    # Build the generator and its loss module.
    generator_net, loss_net = construct_modules(args)
    # Data source: on-the-fly MovingMNIST sequence sampler.
    sampler = MovingMNISTAdvancedIterator(
        distractor_num=cfg.MOVINGMNIST.DISTRACTOR_NUM,
        initial_velocity_range=(cfg.MOVINGMNIST.VELOCITY_LOWER,
                                cfg.MOVINGMNIST.VELOCITY_UPPER),
        rotation_angle_range=(cfg.MOVINGMNIST.ROTATION_LOWER,
                              cfg.MOVINGMNIST.ROTATION_UPPER),
        scale_variation_range=(cfg.MOVINGMNIST.SCALE_VARIATION_LOWER,
                               cfg.MOVINGMNIST.SCALE_VARIATION_UPPER),
        illumination_factor_range=(cfg.MOVINGMNIST.ILLUMINATION_LOWER,
                                   cfg.MOVINGMNIST.ILLUMINATION_UPPER))
    total_len = cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN
    for step in range(cfg.MODEL.TRAIN.MAX_ITER):
        frames, _flow = sampler.sample(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                                       seqlen=total_len)
        # Split the sampled sequence into context and ground-truth futures.
        context_frames = frames[:cfg.MOVINGMNIST.IN_LEN, ...]
        future_frames = frames[cfg.MOVINGMNIST.IN_LEN:total_len, ...]
        # Normalize to [0, 1] and transpose to the NCDHW layout needed by the
        # 3D-convolution encoder.
        context_nd = mx.nd.transpose(mx.nd.array(context_frames) / 255.0,
                                     axes=(1, 2, 0, 3, 4))
        gt_nd = mx.nd.transpose(mx.nd.array(future_frames) / 255.0,
                                axes=(1, 2, 0, 3, 4))
        # One optimization step; returns prediction plus scalar diagnostics.
        pred_nd, avg_l2, avg_real_mse, generator_grad_norm = train_step(
            generator_net, loss_net, context_nd, gt_nd)
        logging.info(
            ("Iter:{}, L2 Loss:{}, MSE Error:{}, Generator Grad Norm:{}"
             ).format(step, avg_l2, avg_real_mse, generator_grad_norm))
        logging.info("Iter:%d" % step)
        if (step + 1) % 100 == 0:
            # Dump the first sequence of the batch for visual inspection.
            save_gif(context_nd.asnumpy()[0, 0, :, :, :],
                     os.path.join(run_dir, "input.gif"))
            save_gif(gt_nd.asnumpy()[0, 0, :, :, :],
                     os.path.join(run_dir, "gt.gif"))
            save_gif(pred_nd.asnumpy()[0, 0, :, :, :],
                     os.path.join(run_dir, "pred.gif"))
        if cfg.MODEL.SAVE_ITER > 0 and (step + 1) % cfg.MODEL.SAVE_ITER == 0:
            generator_net.save_checkpoint(
                prefix=os.path.join(run_dir, "generator"), epoch=step)
# NOTE(review): orphaned fragment — the enclosing `def` / training-loop header
# is not visible in this chunk, so the code below is kept byte-identical.
# From what is visible it: finishes a forward call, back-propagates, logs the
# per-pixel error and global gradient norm, prints parameter/gradient norms
# for all non-batchnorm weights, updates the network, periodically runs
# `test_net` and saves "test.gif", and checkpoints every 2000 iterations.
# NOTE(review): `print k, v[0].shape, ...` is Python 2 statement syntax — this
# fragment will not parse under Python 3; confirm the intended interpreter.
is_train=True) outputs = net.get_outputs() net.backward() norm_val = get_global_norm_val(net) # norm_clipping(params_grad=[grad[0] for grad in net._exec_group.grad_arrays], # threshold=100, batch_size=batch_size) logging.info( "Iter:%d, Error:%f, Norm:%f" % (i, outputs[0].asnumpy().sum() / batch_size / 64 / 64, norm_val)) for k, v, grad_v in zip(net._param_names, net._exec_group.param_arrays, net._exec_group.grad_arrays): if "bn" not in k: print k, v[0].shape, nd.norm(v[0]).asnumpy(), nd.norm( grad_v[0] / batch_size).asnumpy() net.update() if (i + 1) % 100 == 0: test_net.forward(data_batch=mx.io.DataBatch( data=[mx.nd.array(in_seq) / 255.0], label=None), is_train=False) test_prediction = test_net.get_outputs()[0].asnumpy() logging.info( "Iter:%d, Test Error:%f" % (i, -cross_entropy_npy(gt_seq / 255.0, test_prediction).sum() / out_seq_len / batch_size)) save_gif(test_prediction[:, 0, 0, :, :], "test.gif") if (i + 1) % 2000 == 0: net.save_checkpoint(prefix=os.path.join( base_dir, "%s_%s" % (conv_rnn_typ, transform_typ)), epoch=i)
# Script-level driver fragment: depending on args.mode, generate ("save"),
# reload ("load"), or just sample ("test") MovingMNIST parameter sequences,
# then draw one sequence and dump it as "test.gif".
# NOTE(review): `args` and `mnist_generator` are defined outside this
# fragment — presumably an argparse namespace and a MovingMNIST iterator;
# confirm against the full file.
batch_size = 1
if args.mode == 'test':
    seqlen = 100
elif args.mode == 'save':
    if args.path:
        fname = args.path
    else:
        fname = "params.npz"
    print("Generating {} sequences of length {}. Saving to {}.".format(
        args.sequences, args.length, fname))
    seqlen = args.length
    mnist_generator.save(seqlen=seqlen, num_samples=args.sequences,
                         file=fname)
elif args.mode == 'load':
    if args.path:
        fname = args.path
    else:
        fname = "params.npz"
    num_sequences, seqlen = mnist_generator.load(file=fname)
    print("Loaded {} sequences of length {}. Saving to {}.".format(
        num_sequences, seqlen, fname))
# NOTE(review): reconstructed control flow — the trailing sample/save is
# assumed to run for every mode; `seqlen` is only bound when args.mode is
# one of test/save/load, so any other mode would raise NameError here.
seq, _ = mnist_generator.sample(batch_size=batch_size, seqlen=seqlen)
print(seq.sum())
save_gif(seq[:, 0, 0, :, :].astype(np.float32) / 255.0, "test.gif")
def train_mnist(encoder_forecaster, optimizer, criterion, lr_scheduler,
                batch_size, max_iterations, test_iteration_interval,
                test_and_save_checkpoint_iterations, folder_name, base_dir,
                probToPixel=None):
    """Train the PyTorch encoder-forecaster on MovingMNIST.

    Samples fresh sequences every iteration, takes one clipped gradient
    step, and periodically (a) evaluates MSE over 10 fixed validation
    batches, appending the result to ``base_dir/result.txt`` and dumping
    pred/in/gt gifs, and (b) saves a ``state_dict`` checkpoint.

    Parameters
    ----------
    encoder_forecaster : torch.nn.Module
        Model mapping an input sequence to a predicted sequence.
    optimizer, criterion, lr_scheduler
        Standard torch training objects; ``criterion(output, label, mask)``.
    batch_size, max_iterations : int
    test_iteration_interval : int
        Evaluate every this many iterations.
    test_and_save_checkpoint_iterations : int
        Checkpoint every this many iterations.
    folder_name, base_dir : str
        Outputs go under ``base_dir/folder_name`` (models/, logs/).
    probToPixel : callable, optional
        If given, converts classification probabilities (S*B*C*H*W) to
        pixel values; otherwise the raw output is clipped to [0, 1].
    """
    OUT_LEN = cfg.MODEL.OUT_LEN
    evaluater = HKOEvaluation(seq_len=OUT_LEN, use_central=False)
    train_loss = 0.0
    # Prepare a clean output directory layout: models/, logs/, scalar dump.
    save_dir = osp.join(base_dir, folder_name)
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    model_save_dir = osp.join(save_dir, 'models')
    log_dir = osp.join(save_dir, 'logs')
    all_scalars_file_name = osp.join(save_dir, "all_scalars.json")
    # pkl_save_dir = osp.join(save_dir, 'pkl')
    if osp.exists(all_scalars_file_name):
        os.remove(all_scalars_file_name)
    if osp.exists(log_dir):
        shutil.rmtree(log_dir)
    if osp.exists(model_save_dir):
        shutil.rmtree(model_save_dir)
    os.mkdir(model_save_dir)
    # NOTE(review): the writer is only ever closed, never written to —
    # confirm whether scalar logging was intentionally removed.
    writer = SummaryWriter(log_dir)
    mnist_iter = MovingMNISTAdvancedIterator(
        distractor_num=cfg.MOVINGMNIST.DISTRACTOR_NUM,
        initial_velocity_range=(cfg.MOVINGMNIST.VELOCITY_LOWER,
                                cfg.MOVINGMNIST.VELOCITY_UPPER),
        rotation_angle_range=(cfg.MOVINGMNIST.ROTATION_LOWER,
                              cfg.MOVINGMNIST.ROTATION_UPPER),
        scale_variation_range=(cfg.MOVINGMNIST.SCALE_VARIATION_LOWER,
                               cfg.MOVINGMNIST.SCALE_VARIATION_UPPER),
        illumination_factor_range=(cfg.MOVINGMNIST.ILLUMINATION_LOWER,
                                   cfg.MOVINGMNIST.ILLUMINATION_UPPER))
    itera = 0
    while itera < max_iterations:
        frame_dat, _ = mnist_iter.sample(
            batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
            seqlen=cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN)
        train_data = torch.from_numpy(
            np.array(frame_dat[0:cfg.MOVINGMNIST.IN_LEN, ...])).to(
                cfg.GLOBAL.DEVICE) / 255.0
        # BUGFIX: the label slice previously started at cfg.MODEL.IN_LEN while
        # the input slice ended at cfg.MOVINGMNIST.IN_LEN; if the two config
        # keys ever differ, the target would overlap or skip frames.  Use the
        # same key for both.
        train_label = torch.from_numpy(frame_dat[cfg.MOVINGMNIST.IN_LEN:(
            cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN), ...]).to(
                cfg.GLOBAL.DEVICE) / 255.0
        encoder_forecaster.train()
        optimizer.zero_grad()
        output = encoder_forecaster(train_data)
        # All pixels are weighted equally in the loss.
        mask = torch.from_numpy(np.ones(train_label.size()).astype(int)).to(
            cfg.GLOBAL.DEVICE)
        loss = criterion(output, train_label, mask)
        loss.backward()
        # Clip element-wise to stabilize recurrent training.
        torch.nn.utils.clip_grad_value_(encoder_forecaster.parameters(),
                                        clip_value=50.0)
        optimizer.step()
        lr_scheduler.step()
        train_loss += loss.item()
        train_label_numpy = train_label.cpu().numpy()
        if probToPixel is None:
            output_numpy = np.clip(output.detach().cpu().numpy(), 0.0, 1.0)
        else:
            # if classification, output: S*B*C*H*W
            output_numpy = probToPixel(output.detach().cpu().numpy(),
                                       train_label, mask,
                                       lr_scheduler.get_lr()[0])
        evaluater.update(train_label_numpy, output_numpy, mask.cpu().numpy())
        if (itera + 1) % test_iteration_interval == 0:
            with torch.no_grad():
                encoder_forecaster.eval()
                overall_mse = 0
                # Average MSE over 10 deterministic (random=False) batches.
                for iter_id in range(10):
                    valid_frame, _ = mnist_iter.sample(
                        batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                        seqlen=cfg.MOVINGMNIST.IN_LEN +
                        cfg.MOVINGMNIST.OUT_LEN,
                        random=False)
                    valid_data = torch.from_numpy(
                        np.array(valid_frame[0:cfg.MOVINGMNIST.IN_LEN,
                                             ...])).to(
                            cfg.GLOBAL.DEVICE) / 255.0
                    # BUGFIX: same cfg.MODEL.IN_LEN / cfg.MOVINGMNIST.IN_LEN
                    # mismatch as the training slice above.
                    valid_label = torch.from_numpy(
                        valid_frame[cfg.MOVINGMNIST.IN_LEN:(
                            cfg.MOVINGMNIST.IN_LEN +
                            cfg.MOVINGMNIST.OUT_LEN), ...]).to(
                                cfg.GLOBAL.DEVICE) / 255.0
                    output = encoder_forecaster(valid_data)
                    overall_mse += torch.mean((valid_label - output) ** 2)
                # BUGFIX: extract the python float so result.txt contains a
                # plain number rather than a tensor repr like "tensor(...)".
                avg_mse = (overall_mse / 10).item()
                with open(os.path.join(base_dir, 'result.txt'), 'a') as f:
                    f.write(str(avg_mse) + '\n')
                print(base_dir, avg_mse)
                gif_dir = os.path.join(base_dir, "gif")
                if not os.path.exists(gif_dir):
                    os.mkdir(gif_dir)
                # Dump the first sequence of the last validation batch plus
                # the last training input/label for visual inspection.
                save_gif(output.detach().cpu().numpy()[:, 0, 0, :, :],
                         os.path.join(gif_dir, "pred-{}.gif".format(itera)))
                save_gif(train_data.detach().cpu().numpy()[:, 0, 0, :, :],
                         os.path.join(gif_dir, "in-{}.gif".format(itera)))
                save_gif(train_label.detach().cpu().numpy()[:, 0, 0, :, :],
                         os.path.join(gif_dir, "gt-{}.gif".format(itera)))
        if (itera + 1) % test_and_save_checkpoint_iterations == 0:
            torch.save(
                encoder_forecaster.state_dict(),
                osp.join(model_save_dir,
                         'encoder_forecaster_{}.pth'.format(itera)))
        itera += 1
    writer.close()
def train(args):
    """Train the MXNet encoder-forecaster model on MovingMNIST.

    Builds multi-context training networks (plus single-context shared
    finetune networks), optionally resumes from the latest checkpoint or
    loads pretrained parameters, then runs the training loop.  Every 100
    iterations a fresh batch is visualized as pred/in/gt gifs, and every
    ``cfg.MODEL.SAVE_ITER`` iterations both networks are checkpointed.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``save_dir``, ``ctx`` (list of MXNet contexts) and
        ``resume`` (bool).
    """
    assert cfg.MODEL.FRAME_STACK == 1 and cfg.MODEL.FRAME_SKIP == 1
    base_dir = args.save_dir
    logging_config(folder=base_dir, name="training")
    save_movingmnist_cfg(base_dir)
    mnist_iter = MovingMNISTAdvancedIterator(
        distractor_num=cfg.MOVINGMNIST.DISTRACTOR_NUM,
        initial_velocity_range=(cfg.MOVINGMNIST.VELOCITY_LOWER,
                                cfg.MOVINGMNIST.VELOCITY_UPPER),
        rotation_angle_range=(cfg.MOVINGMNIST.ROTATION_LOWER,
                              cfg.MOVINGMNIST.ROTATION_UPPER),
        scale_variation_range=(cfg.MOVINGMNIST.SCALE_VARIATION_LOWER,
                               cfg.MOVINGMNIST.SCALE_VARIATION_UPPER),
        illumination_factor_range=(cfg.MOVINGMNIST.ILLUMINATION_LOWER,
                                   cfg.MOVINGMNIST.ILLUMINATION_UPPER))
    # Per-device batch size: the global batch is split across contexts.
    mnist_rnn = MovingMNISTFactory(
        batch_size=cfg.MODEL.TRAIN.BATCH_SIZE // len(args.ctx),
        in_seq_len=cfg.MODEL.IN_LEN,
        out_seq_len=cfg.MODEL.OUT_LEN)
    encoder_net, forecaster_net, loss_net = \
        encoder_forecaster_build_networks(
            factory=mnist_rnn,
            context=args.ctx)
    # Single-context copies sharing parameters with the training networks
    # (kept alive for finetuning even though unused in this loop).
    t_encoder_net, t_forecaster_net, t_loss_net = \
        encoder_forecaster_build_networks(
            factory=mnist_rnn,
            context=args.ctx[0],
            shared_encoder_net=encoder_net,
            shared_forecaster_net=forecaster_net,
            shared_loss_net=loss_net,
            for_finetune=True)
    encoder_net.summary()
    forecaster_net.summary()
    loss_net.summary()
    # Resume last checkpoint
    if args.resume:
        encoder_net.load(prefix=os.path.join(base_dir, 'encoder_net'),
                         epoch=latest_iter_id(base_dir),
                         load_optimizer_states=True,
                         data_names=[
                             'data', 'ebrnn1_begin_state_h',
                             'ebrnn2_begin_state_h', 'ebrnn3_begin_state_h'
                         ],
                         label_names=[])  # change it next time
        forecaster_net.load(prefix=os.path.join(base_dir, 'forecaster_net'),
                            epoch=latest_iter_id(base_dir),
                            load_optimizer_states=True,
                            data_names=[
                                'fbrnn1_begin_state_h',
                                'fbrnn2_begin_state_h',
                                'fbrnn3_begin_state_h'
                            ],
                            label_names=[])  # change it next time
    # Begin to load the model if load_dir is not empty
    if len(cfg.MODEL.LOAD_DIR) > 0:
        load_mnist_params(load_dir=cfg.MODEL.LOAD_DIR,
                          load_iter=cfg.MODEL.LOAD_ITER,
                          encoder_net=encoder_net,
                          forecaster_net=forecaster_net)
    states = EncoderForecasterStates(factory=mnist_rnn, ctx=args.ctx[0])
    states.reset_all()
    for info in mnist_rnn.init_encoder_state_info:
        assert info["__layout__"].find(
            'N') == 0, "Layout=%s is not supported!" % info["__layout__"]
    iter_id = 0
    while iter_id < cfg.MODEL.TRAIN.MAX_ITER:
        frame_dat, _ = mnist_iter.sample(
            batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
            seqlen=cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN)
        data_nd = mx.nd.array(frame_dat[0:cfg.MOVINGMNIST.IN_LEN, ...],
                              ctx=args.ctx[0]) / 255.0
        # BUGFIX: the target slice previously started at cfg.MODEL.IN_LEN
        # while the input slice ended at cfg.MOVINGMNIST.IN_LEN; if the two
        # config keys ever differ the target overlaps or skips frames.
        target_nd = mx.nd.array(
            frame_dat[cfg.MOVINGMNIST.IN_LEN:(
                cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN), ...],
            ctx=args.ctx[0]) / 255.0
        train_step(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                   encoder_net=encoder_net,
                   forecaster_net=forecaster_net,
                   loss_net=loss_net,
                   init_states=states,
                   data_nd=data_nd,
                   gt_nd=target_nd,
                   mask_nd=None,
                   iter_id=iter_id)
        if (iter_id + 1) % 100 == 0:
            # Visualize predictions on a freshly sampled batch.
            new_frame_dat, _ = mnist_iter.sample(
                batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                seqlen=cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN)
            # BUGFIX: the fresh sample was drawn into new_frame_dat but the
            # stale frame_dat was sliced, so the gifs always showed the last
            # training batch instead of the new one.
            data_nd = mx.nd.array(new_frame_dat[0:cfg.MOVINGMNIST.IN_LEN,
                                                ...],
                                  ctx=args.ctx[0]) / 255.0
            target_nd = mx.nd.array(
                new_frame_dat[cfg.MOVINGMNIST.IN_LEN:(
                    cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN), ...],
                ctx=args.ctx[0]) / 255.0
            pred_nd = mnist_get_prediction(data_nd=data_nd,
                                           states=states,
                                           encoder_net=encoder_net,
                                           forecaster_net=forecaster_net)
            save_gif(pred_nd.asnumpy()[:, 0, 0, :, :],
                     os.path.join(base_dir, "pred.gif"))
            save_gif(data_nd.asnumpy()[:, 0, 0, :, :],
                     os.path.join(base_dir, "in.gif"))
            save_gif(target_nd.asnumpy()[:, 0, 0, :, :],
                     os.path.join(base_dir, "gt.gif"))
        if (iter_id + 1) % cfg.MODEL.SAVE_ITER == 0:
            encoder_net.save_checkpoint(
                prefix=os.path.join(base_dir, "encoder_net"),
                epoch=iter_id,
                save_optimizer_states=True)
            forecaster_net.save_checkpoint(
                prefix=os.path.join(base_dir, "forecaster_net"),
                epoch=iter_id,
                save_optimizer_states=True)
        iter_id += 1