def save_hko_gif(im_dat, save_path):
    """Write a stack of HKO frames to a GIF file.

    Parameters
    ----------
    im_dat : np.ndarray
        Frame stack of shape (seqlen, H, W); anything that is not
        3-dimensional trips the assertion below.
    save_path : str
        Forwarded to ``save_gif`` as its ``fname`` argument.

    Returns
    -------
    None
    """
    rank = im_dat.ndim
    assert rank == 3
    save_gif(im_dat, fname=save_path)
Ejemplo n.º 2
0
def analysis(args):
    """Run a loaded encoder-forecaster on one Moving-MNIST sample and dump GIFs.

    The parameters named by ``cfg.MODEL.LOAD_DIR`` / ``cfg.MODEL.LOAD_ITER``
    are loaded into freshly built networks, one sequence is drawn from the
    Moving-MNIST iterator, the target frames are predicted from the input
    frames, and ``pred.gif``, ``in.gif`` and ``gt.gif`` are written into
    ``args.save_dir``.

    Parameters
    ----------
    args : object
        Must expose ``save_dir`` (output folder, also used for logging and
        the config dump) and ``ctx`` (indexable; ``ctx[0]`` hosts the data
        arrays and the RNN states).
    """
    # Set the SAVE_MID_RESULTS flag of the TrajRNN config before anything is
    # built from the configuration.
    cfg.MODEL.TRAJRNN.SAVE_MID_RESULTS = True
    assert cfg.MODEL.FRAME_STACK == 1 and cfg.MODEL.FRAME_SKIP == 1

    out_dir = args.save_dir
    logging_config(folder=out_dir, name="testing")
    save_movingmnist_cfg(out_dir)

    mm_cfg = cfg.MOVINGMNIST
    sampler = MovingMNISTAdvancedIterator(
        distractor_num=mm_cfg.DISTRACTOR_NUM,
        initial_velocity_range=(mm_cfg.VELOCITY_LOWER,
                                mm_cfg.VELOCITY_UPPER),
        rotation_angle_range=(mm_cfg.ROTATION_LOWER,
                              mm_cfg.ROTATION_UPPER),
        scale_variation_range=(mm_cfg.SCALE_VARIATION_LOWER,
                               mm_cfg.SCALE_VARIATION_UPPER),
        illumination_factor_range=(mm_cfg.ILLUMINATION_LOWER,
                                   mm_cfg.ILLUMINATION_UPPER))

    # NOTE(review): the factory is built for batch_size=1 while the sample
    # below is drawn with cfg.MODEL.TRAIN.BATCH_SIZE -- confirm these agree.
    factory = MovingMNISTFactory(batch_size=1,
                                 in_seq_len=cfg.MODEL.IN_LEN,
                                 out_seq_len=cfg.MODEL.OUT_LEN)
    networks = encoder_forecaster_build_networks(factory=factory,
                                                 context=args.ctx)
    encoder_net, forecaster_net, loss_net = networks
    for net in (encoder_net, forecaster_net, loss_net):
        net.summary()

    states = EncoderForecasterStates(factory=factory, ctx=args.ctx[0])
    states.reset_all()

    # A checkpoint location is mandatory for this entry point.
    assert len(cfg.MODEL.LOAD_DIR) > 0
    load_encoder_forecaster_params(load_dir=cfg.MODEL.LOAD_DIR,
                                   load_iter=cfg.MODEL.LOAD_ITER,
                                   encoder_net=encoder_net,
                                   forecaster_net=forecaster_net)

    # Draw a single sequence and split it into input / target frames.
    in_len = cfg.MOVINGMNIST.IN_LEN
    total_len = in_len + cfg.MOVINGMNIST.OUT_LEN
    frames, _ = sampler.sample(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                               seqlen=total_len)
    data_nd = mx.nd.array(frames[0:in_len, ...], ctx=args.ctx[0]) / 255.0
    target_nd = mx.nd.array(frames[in_len:total_len, ...],
                            ctx=args.ctx[0]) / 255.0
    pred_nd = mnist_get_prediction(data_nd=data_nd,
                                   states=states,
                                   encoder_net=encoder_net,
                                   forecaster_net=forecaster_net)

    # Only the first batch element / first channel of each array is saved.
    gif_jobs = (("pred.gif", pred_nd),
                ("in.gif", data_nd),
                ("gt.gif", target_nd))
    for gif_name, frames_nd in gif_jobs:
        save_gif(frames_nd.asnumpy()[:, 0, 0, :, :],
                 os.path.join(out_dir, gif_name))
Ejemplo n.º 3
0
def train(args):
    """Train the generator module on randomly sampled Moving-MNIST sequences.

    Each iteration draws a batch, splits it into context and ground-truth
    frames, runs one ``train_step`` and logs the returned metrics.  Every
    100 iterations the first sample of the batch is dumped as
    ``input.gif`` / ``gt.gif`` / ``pred.gif`` in the base directory, and a
    generator checkpoint is written every ``cfg.MODEL.SAVE_ITER``
    iterations when that value is positive.

    Parameters
    ----------
    args : object
        Forwarded to ``get_base_dir`` and ``construct_modules``.
    """
    base_dir = get_base_dir(args)

    # Networks
    generator_net, loss_net = construct_modules(args)

    # Data source
    mm_cfg = cfg.MOVINGMNIST
    sampler = MovingMNISTAdvancedIterator(
        distractor_num=mm_cfg.DISTRACTOR_NUM,
        initial_velocity_range=(mm_cfg.VELOCITY_LOWER,
                                mm_cfg.VELOCITY_UPPER),
        rotation_angle_range=(mm_cfg.ROTATION_LOWER,
                              mm_cfg.ROTATION_UPPER),
        scale_variation_range=(mm_cfg.SCALE_VARIATION_LOWER,
                               mm_cfg.SCALE_VARIATION_UPPER),
        illumination_factor_range=(mm_cfg.ILLUMINATION_LOWER,
                                   mm_cfg.ILLUMINATION_UPPER))

    # Axis order applied to both tensors; per the original comment this is
    # the NCDHW layout required by the 3D-convolution encoder.
    ncdhw_axes = (1, 2, 0, 3, 4)

    for i in range(cfg.MODEL.TRAIN.MAX_ITER):
        in_len = cfg.MOVINGMNIST.IN_LEN
        total_len = in_len + cfg.MOVINGMNIST.OUT_LEN
        seq, _ = sampler.sample(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                                seqlen=total_len)

        # Normalise to [0, 1] and reorder the axes.
        context_nd = mx.nd.array(seq[:in_len, ...]) / 255.0
        gt_nd = mx.nd.array(seq[in_len:total_len, ...]) / 255.0
        context_nd = mx.nd.transpose(context_nd, axes=ncdhw_axes)
        gt_nd = mx.nd.transpose(gt_nd, axes=ncdhw_axes)

        # One optimisation step
        step_result = train_step(generator_net, loss_net, context_nd, gt_nd)
        pred_nd, avg_l2, avg_real_mse, generator_grad_norm = step_result

        # Logging
        message = "Iter:{}, L2 Loss:{}, MSE Error:{}, Generator Grad Norm:{}"
        logging.info(
            message.format(i, avg_l2, avg_real_mse, generator_grad_norm))
        logging.info("Iter:%d" % i)

        step = i + 1
        if step % 100 == 0:
            # Dump the first sample / first channel of each tensor.
            snapshots = (("input.gif", context_nd),
                         ("gt.gif", gt_nd),
                         ("pred.gif", pred_nd))
            for gif_name, tensor in snapshots:
                save_gif(tensor.asnumpy()[0, 0, :, :, :],
                         os.path.join(base_dir, gif_name))

        save_every = cfg.MODEL.SAVE_ITER
        if save_every > 0 and step % save_every == 0:
            checkpoint_prefix = os.path.join(base_dir, "generator")
            generator_net.save_checkpoint(prefix=checkpoint_prefix, epoch=i)
Ejemplo n.º 4
0
                is_train=True)
    outputs = net.get_outputs()
    net.backward()
    norm_val = get_global_norm_val(net)
    # norm_clipping(params_grad=[grad[0] for grad in net._exec_group.grad_arrays],
    #               threshold=100, batch_size=batch_size)
    logging.info(
        "Iter:%d, Error:%f, Norm:%f" %
        (i, outputs[0].asnumpy().sum() / batch_size / 64 / 64, norm_val))

    for k, v, grad_v in zip(net._param_names, net._exec_group.param_arrays,
                            net._exec_group.grad_arrays):
        if "bn" not in k:
            print k, v[0].shape, nd.norm(v[0]).asnumpy(), nd.norm(
                grad_v[0] / batch_size).asnumpy()
    net.update()
    if (i + 1) % 100 == 0:
        test_net.forward(data_batch=mx.io.DataBatch(
            data=[mx.nd.array(in_seq) / 255.0], label=None),
                         is_train=False)
        test_prediction = test_net.get_outputs()[0].asnumpy()
        logging.info(
            "Iter:%d, Test Error:%f" %
            (i, -cross_entropy_npy(gt_seq / 255.0, test_prediction).sum() /
             out_seq_len / batch_size))
        save_gif(test_prediction[:, 0, 0, :, :], "test.gif")
    if (i + 1) % 2000 == 0:
        net.save_checkpoint(prefix=os.path.join(
            base_dir, "%s_%s" % (conv_rnn_typ, transform_typ)),
                            epoch=i)
Ejemplo n.º 5
0
    batch_size = 1
    if args.mode == 'test':
        seqlen = 100
    elif args.mode == 'save':
        if args.path:
            fname = args.path
        else:
            fname = "params.npz"

        print("Generating {} sequences of length {}. Saving to {}.".format(
            args.sequences, args.length, fname))
        seqlen = args.length
        mnist_generator.save(seqlen=seqlen,
                             num_samples=args.sequences,
                             file=fname)
    elif args.mode == 'load':
        if args.path:
            fname = args.path
        else:
            fname = "params.npz"
        num_sequences, seqlen = mnist_generator.load(file=fname)
        print("Loaded {} sequences of length {}. Saving to {}.".format(
            num_sequences, seqlen, fname))

    seq, _ = mnist_generator.sample(batch_size=batch_size, seqlen=seqlen)

    print(seq.sum())

    save_gif(seq[:, 0, 0, :, :].astype(np.float32) / 255.0, "test.gif")
def _split_frames_to_tensors(frame_dat, in_len, seq_len, device):
    """Split sampled frames into normalised (input, label) tensors on ``device``."""

    def _as_tensor(frames):
        return torch.from_numpy(np.array(frames)).to(device) / 255.0

    return (_as_tensor(frame_dat[0:in_len, ...]),
            _as_tensor(frame_dat[in_len:seq_len, ...]))


def train_mnist(encoder_forecaster,
                optimizer,
                criterion,
                lr_scheduler,
                batch_size,
                max_iterations,
                test_iteration_interval,
                test_and_save_checkpoint_iterations,
                folder_name,
                base_dir,
                probToPixel=None):
    """Train ``encoder_forecaster`` on Moving-MNIST with periodic validation.

    Every ``test_iteration_interval`` iterations the model is evaluated on
    10 batches sampled with ``random=False``; the mean MSE is appended to
    ``<base_dir>/result.txt`` and the first sample of the last validation
    batch is written to ``<base_dir>/gif`` as ``in-/gt-/pred-<iter>.gif``.
    Every ``test_and_save_checkpoint_iterations`` iterations the model's
    ``state_dict`` is saved under ``<base_dir>/<folder_name>/models``.

    Existing ``logs``, ``models`` and ``all_scalars.json`` entries inside
    ``<base_dir>/<folder_name>`` are deleted on start.

    Parameters
    ----------
    encoder_forecaster : torch.nn.Module
        Model called as ``encoder_forecaster(input_frames)``.
    optimizer, lr_scheduler
        Stepped once per training iteration.
    criterion : callable
        Called as ``criterion(output, label, mask)``; must return a tensor
        supporting ``backward()``.
    batch_size : int
        Accepted for interface compatibility; the sampled batch size is
        taken from ``cfg.MODEL.TRAIN.BATCH_SIZE``.
    max_iterations : int
        Number of training iterations.
    test_iteration_interval : int
        Validation period, in iterations.
    test_and_save_checkpoint_iterations : int
        Checkpoint period, in iterations.
    folder_name : str
        Sub-folder of ``base_dir`` holding logs and checkpoints.
    base_dir : str
        Root output directory.
    probToPixel : callable, optional
        When given, called as ``probToPixel(output_numpy, train_label, mask,
        lr)`` to convert the network output before evaluation; otherwise
        the output is clipped to ``[0, 1]``.
    """
    device = cfg.GLOBAL.DEVICE
    # The sampled sequences are sized by the MOVINGMNIST config, so both the
    # input and the label slice are derived from it (the label slice used to
    # start at cfg.MODEL.IN_LEN, which only works when the two agree).
    in_len = cfg.MOVINGMNIST.IN_LEN
    seq_len = in_len + cfg.MOVINGMNIST.OUT_LEN
    num_valid_batches = 10

    evaluater = HKOEvaluation(seq_len=cfg.MODEL.OUT_LEN, use_central=False)

    save_dir = osp.join(base_dir, folder_name)
    # makedirs also creates base_dir when it does not exist yet.
    os.makedirs(save_dir, exist_ok=True)
    model_save_dir = osp.join(save_dir, 'models')
    log_dir = osp.join(save_dir, 'logs')
    all_scalars_file_name = osp.join(save_dir, "all_scalars.json")

    # Start from a clean slate: drop artefacts of a previous run.
    if osp.exists(all_scalars_file_name):
        os.remove(all_scalars_file_name)
    if osp.exists(log_dir):
        shutil.rmtree(log_dir)
    if osp.exists(model_save_dir):
        shutil.rmtree(model_save_dir)
    os.mkdir(model_save_dir)

    mnist_iter = MovingMNISTAdvancedIterator(
        distractor_num=cfg.MOVINGMNIST.DISTRACTOR_NUM,
        initial_velocity_range=(cfg.MOVINGMNIST.VELOCITY_LOWER,
                                cfg.MOVINGMNIST.VELOCITY_UPPER),
        rotation_angle_range=(cfg.MOVINGMNIST.ROTATION_LOWER,
                              cfg.MOVINGMNIST.ROTATION_UPPER),
        scale_variation_range=(cfg.MOVINGMNIST.SCALE_VARIATION_LOWER,
                               cfg.MOVINGMNIST.SCALE_VARIATION_UPPER),
        illumination_factor_range=(cfg.MOVINGMNIST.ILLUMINATION_LOWER,
                                   cfg.MOVINGMNIST.ILLUMINATION_UPPER))

    writer = SummaryWriter(log_dir)
    try:
        itera = 0
        while itera < max_iterations:
            frame_dat, _ = mnist_iter.sample(
                batch_size=cfg.MODEL.TRAIN.BATCH_SIZE, seqlen=seq_len)
            train_data, train_label = _split_frames_to_tensors(
                frame_dat, in_len, seq_len, device)

            # eval() is switched on during validation, so re-enter training
            # mode at the start of every iteration.
            encoder_forecaster.train()
            optimizer.zero_grad()
            output = encoder_forecaster(train_data)
            # All-ones mask: every pixel contributes to the loss.
            mask = torch.from_numpy(
                np.ones(train_label.size()).astype(int)).to(device)
            loss = criterion(output, train_label, mask)
            loss.backward()
            torch.nn.utils.clip_grad_value_(encoder_forecaster.parameters(),
                                            clip_value=50.0)
            optimizer.step()
            lr_scheduler.step()

            train_label_numpy = train_label.cpu().numpy()
            if probToPixel is None:
                output_numpy = np.clip(output.detach().cpu().numpy(), 0.0,
                                       1.0)
            else:
                # if classification, output: S*B*C*H*W
                output_numpy = probToPixel(output.detach().cpu().numpy(),
                                           train_label, mask,
                                           lr_scheduler.get_lr()[0])
            evaluater.update(train_label_numpy, output_numpy,
                             mask.cpu().numpy())

            if (itera + 1) % test_iteration_interval == 0:
                overall_mse = 0
                with torch.no_grad():
                    encoder_forecaster.eval()
                    for _batch_id in range(num_valid_batches):
                        valid_frame, _ = mnist_iter.sample(
                            batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                            seqlen=seq_len,
                            random=False)
                        valid_data, valid_label = _split_frames_to_tensors(
                            valid_frame, in_len, seq_len, device)
                        valid_output = encoder_forecaster(valid_data)
                        overall_mse += torch.mean(
                            (valid_label - valid_output)**2)
                avg_mse = overall_mse / num_valid_batches
                with open(os.path.join(base_dir, 'result.txt'), 'a') as f:
                    f.write(str(avg_mse) + '\n')
                print(base_dir, avg_mse)

                gif_dir = os.path.join(base_dir, "gif")
                os.makedirs(gif_dir, exist_ok=True)
                # Save a matching (input, ground truth, prediction) triple
                # from the last validation batch.  Previously the prediction
                # came from validation while input/ground truth came from
                # the training batch, so the GIFs did not correspond.
                gif_sources = (("pred", valid_output),
                               ("in", valid_data),
                               ("gt", valid_label))
                for tag, tensor in gif_sources:
                    save_gif(
                        tensor.detach().cpu().numpy()[:, 0, 0, :, :],
                        os.path.join(gif_dir,
                                     "{}-{}.gif".format(tag, itera)))

            if (itera + 1) % test_and_save_checkpoint_iterations == 0:
                torch.save(
                    encoder_forecaster.state_dict(),
                    osp.join(model_save_dir,
                             'encoder_forecaster_{}.pth'.format(itera)))
            itera += 1
    finally:
        # Release the event file even when training is interrupted.
        writer.close()
Ejemplo n.º 7
0
def train(args):
    """Train the MXNet encoder-forecaster on Moving-MNIST sequences.

    Builds the networks, optionally resumes from the latest checkpoint in
    ``args.save_dir`` and/or loads parameters from ``cfg.MODEL.LOAD_DIR``,
    then runs ``cfg.MODEL.TRAIN.MAX_ITER`` training steps.  Every 100
    iterations a freshly sampled sequence is predicted and written to
    ``pred.gif`` / ``in.gif`` / ``gt.gif``; every ``cfg.MODEL.SAVE_ITER``
    iterations (when positive) both networks are checkpointed together
    with their optimizer states.

    Parameters
    ----------
    args : object
        Must expose ``save_dir`` (output folder), ``ctx`` (non-empty list of
        contexts; ``ctx[0]`` hosts the data arrays and RNN states) and
        ``resume`` (truthy to reload the latest checkpoint).
    """
    assert cfg.MODEL.FRAME_STACK == 1 and cfg.MODEL.FRAME_SKIP == 1
    base_dir = args.save_dir
    logging_config(folder=base_dir, name="training")
    save_movingmnist_cfg(base_dir)
    mnist_iter = MovingMNISTAdvancedIterator(
        distractor_num=cfg.MOVINGMNIST.DISTRACTOR_NUM,
        initial_velocity_range=(cfg.MOVINGMNIST.VELOCITY_LOWER,
                                cfg.MOVINGMNIST.VELOCITY_UPPER),
        rotation_angle_range=(cfg.MOVINGMNIST.ROTATION_LOWER,
                              cfg.MOVINGMNIST.ROTATION_UPPER),
        scale_variation_range=(cfg.MOVINGMNIST.SCALE_VARIATION_LOWER,
                               cfg.MOVINGMNIST.SCALE_VARIATION_UPPER),
        illumination_factor_range=(cfg.MOVINGMNIST.ILLUMINATION_LOWER,
                                   cfg.MOVINGMNIST.ILLUMINATION_UPPER))

    # The per-device batch size is the global batch split across contexts.
    mnist_rnn = MovingMNISTFactory(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE //
                                   len(args.ctx),
                                   in_seq_len=cfg.MODEL.IN_LEN,
                                   out_seq_len=cfg.MODEL.OUT_LEN)

    encoder_net, forecaster_net, loss_net = \
        encoder_forecaster_build_networks(
            factory=mnist_rnn,
            context=args.ctx)
    # NOTE(review): the fine-tune networks returned here are never used in
    # this function; the call is kept in case building them has side effects
    # on the shared modules -- confirm and drop if not.
    t_encoder_net, t_forecaster_net, t_loss_net = \
        encoder_forecaster_build_networks(
            factory=mnist_rnn,
            context=args.ctx[0],
            shared_encoder_net=encoder_net,
            shared_forecaster_net=forecaster_net,
            shared_loss_net=loss_net,
            for_finetune=True)
    encoder_net.summary()
    forecaster_net.summary()
    loss_net.summary()

    # Resume last checkpoint
    if args.resume:
        # NOTE(review): iter_id below still restarts at 0 after a resume, so
        # later checkpoints reuse early epoch numbers -- confirm intent.
        resume_iter = latest_iter_id(base_dir)
        encoder_net.load(prefix=os.path.join(base_dir, 'encoder_net'),
                         epoch=resume_iter,
                         load_optimizer_states=True,
                         data_names=[
                             'data', 'ebrnn1_begin_state_h',
                             'ebrnn2_begin_state_h', 'ebrnn3_begin_state_h'
                         ],
                         label_names=[])  # change it next time
        forecaster_net.load(prefix=os.path.join(base_dir, 'forecaster_net'),
                            epoch=resume_iter,
                            load_optimizer_states=True,
                            data_names=[
                                'fbrnn1_begin_state_h', 'fbrnn2_begin_state_h',
                                'fbrnn3_begin_state_h'
                            ],
                            label_names=[])  # change it next time
    # Begin to load the model if load_dir is not empty
    if len(cfg.MODEL.LOAD_DIR) > 0:
        load_mnist_params(load_dir=cfg.MODEL.LOAD_DIR,
                          load_iter=cfg.MODEL.LOAD_ITER,
                          encoder_net=encoder_net,
                          forecaster_net=forecaster_net)
    states = EncoderForecasterStates(factory=mnist_rnn, ctx=args.ctx[0])
    states.reset_all()
    # Only layouts whose first axis is the batch axis ('N') are supported.
    for info in mnist_rnn.init_encoder_state_info:
        assert info["__layout__"].find(
            'N') == 0, "Layout=%s is not supported!" % info["__layout__"]

    # The sampled sequences are sized by the MOVINGMNIST config, so every
    # slice boundary is derived from it (the training target slice used to
    # start at cfg.MODEL.IN_LEN, which only works when the two agree).
    in_len = cfg.MOVINGMNIST.IN_LEN
    seq_len = in_len + cfg.MOVINGMNIST.OUT_LEN
    primary_ctx = args.ctx[0]

    iter_id = 0
    while iter_id < cfg.MODEL.TRAIN.MAX_ITER:
        frame_dat, _ = mnist_iter.sample(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                                         seqlen=seq_len)
        data_nd = mx.nd.array(frame_dat[0:in_len, ...],
                              ctx=primary_ctx) / 255.0
        target_nd = mx.nd.array(frame_dat[in_len:seq_len, ...],
                                ctx=primary_ctx) / 255.0
        train_step(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                   encoder_net=encoder_net,
                   forecaster_net=forecaster_net,
                   loss_net=loss_net,
                   init_states=states,
                   data_nd=data_nd,
                   gt_nd=target_nd,
                   mask_nd=None,
                   iter_id=iter_id)
        if (iter_id + 1) % 100 == 0:
            # Visualise on a fresh sample.  The original code drew
            # new_frame_dat but then sliced the training batch again, so the
            # new sample was never used.
            new_frame_dat, _ = mnist_iter.sample(
                batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                seqlen=seq_len)
            vis_data_nd = mx.nd.array(new_frame_dat[0:in_len, ...],
                                      ctx=primary_ctx) / 255.0
            vis_target_nd = mx.nd.array(new_frame_dat[in_len:seq_len, ...],
                                        ctx=primary_ctx) / 255.0
            pred_nd = mnist_get_prediction(data_nd=vis_data_nd,
                                           states=states,
                                           encoder_net=encoder_net,
                                           forecaster_net=forecaster_net)
            save_gif(pred_nd.asnumpy()[:, 0, 0, :, :],
                     os.path.join(base_dir, "pred.gif"))
            save_gif(vis_data_nd.asnumpy()[:, 0, 0, :, :],
                     os.path.join(base_dir, "in.gif"))
            save_gif(vis_target_nd.asnumpy()[:, 0, 0, :, :],
                     os.path.join(base_dir, "gt.gif"))
        # A non-positive SAVE_ITER disables checkpointing instead of raising
        # ZeroDivisionError on the modulo.
        if cfg.MODEL.SAVE_ITER > 0 and \
                (iter_id + 1) % cfg.MODEL.SAVE_ITER == 0:
            encoder_net.save_checkpoint(prefix=os.path.join(
                base_dir, "encoder_net"),
                                        epoch=iter_id,
                                        save_optimizer_states=True)
            forecaster_net.save_checkpoint(prefix=os.path.join(
                base_dir, "forecaster_net"),
                                           epoch=iter_id,
                                           save_optimizer_states=True)
        iter_id += 1