예제 #1
0
def run_batch(batch_idx, val, batch_loader, tracker_cnn, criterion, optimizer, history, save_debug_image):
    """Train or validate on a single batch.

    Fetches one batch from ``batch_loader``, runs the model (plus a backward
    pass and an optimizer step when training), records loss and accuracy in
    ``history``, prints a progress line and optionally saves a debug image.

    Args:
        batch_idx: Index of the current batch (used for history and logging).
        val: If True, validate (volatile Variables, no optimizer step);
            otherwise train.
        batch_loader: Object whose ``get_batch()`` returns
            ``(inputs, outputs_gt)``; ``outputs_gt`` is reduced with
            ``np.argmax(..., axis=1)`` to obtain target class indices.
        tracker_cnn: Model, called as ``tracker_cnn(inputs)``.
        criterion: Loss, called as ``criterion(predictions, class_indices)``.
        optimizer: Optimizer; only used when training.
        history: Receives "loss" and "acc" values on the "train"/"val" line
            (averaged when validating).
        save_debug_image: If True, write ``debug_img_<train|val>.jpg``.
    """
    train = not val
    split_name = "train" if train else "val"

    # --- fetch and wrap the batch ---
    time_cbatch_start = time.time()
    inputs, outputs_gt = batch_loader.get_batch()
    # Target class indices, computed from the numpy ground truth before it is
    # wrapped. They are reused below for the accuracy, so loss and accuracy
    # use the same targets.
    outputs_gt_bins_np = np.argmax(outputs_gt, axis=1)
    # Wrap unconditionally; only the device transfer depends on Config.GPU.
    # (Previously the wrapping happened only inside the GPU branch, which left
    # outputs_gt_bins undefined and raised a NameError on CPU-only runs.)
    inputs = to_variable(inputs, volatile=val)
    outputs_gt_bins = to_variable(outputs_gt_bins_np, volatile=val, requires_grad=False)
    outputs_gt = to_variable(outputs_gt, volatile=val, requires_grad=False)
    if Config.GPU >= 0:
        inputs = to_cuda(inputs, Config.GPU)
        outputs_gt_bins = to_cuda(outputs_gt_bins, Config.GPU)
        outputs_gt = to_cuda(outputs_gt, Config.GPU)
    time_cbatch_end = time.time()

    # --- forward pass (and backward pass + step when training) ---
    time_fwbw_start = time.time()
    if train:
        optimizer.zero_grad()
    outputs_pred = tracker_cnn(inputs)
    outputs_pred_sm = F.softmax(outputs_pred)
    loss = criterion(outputs_pred, outputs_gt_bins)
    if train:
        loss.backward()
        optimizer.step()
    time_fwbw_end = time.time()

    # --- metrics and logging ---
    loss = loss.data.cpu().numpy()[0]
    outputs_pred_np = to_numpy(outputs_pred_sm)
    nb_correct = np.sum(np.equal(np.argmax(outputs_pred_np, axis=1), outputs_gt_bins_np))
    # float(): under Python 2 an integer count divided by an int is floor
    # division, which made acc 0 for every batch that was not 100% correct.
    # Divide by the actual number of rows so a short batch is scored correctly.
    acc = float(nb_correct) / outputs_pred_np.shape[0]
    history.add_value("loss", split_name, batch_idx, loss, average=val)
    history.add_value("acc", split_name, batch_idx, acc, average=val)
    print("[%s] Batch %05d | loss %.8f | acc %.2f | cbatch %.04fs | fwbw %.04fs" % ("T" if train else "V", batch_idx, loss, acc, time_cbatch_end - time_cbatch_start, time_fwbw_end - time_fwbw_start))

    if save_debug_image:
        debug_img = generate_debug_image(inputs, outputs_gt, outputs_pred_sm)
        misc.imsave("debug_img_%s.jpg" % split_name, debug_img)
예제 #2
0
    def embed_state(self,
                    previous_states,
                    state,
                    volatile=False,
                    requires_grad=True,
                    gpu=-1):
        """Embed the current state together with its previous states.

        The current screenshot is downscaled with ``self.downscale`` and kept
        in color. Every previous screenshot is downscaled with
        ``self.downscale_prev`` and converted to grayscale, and the results
        are stacked along the channel axis. Both arrays are cast to float32,
        divided by 255, moved to channels-first order and given a leading
        batch axis of size 1. Each is then wrapped with ``to_variable`` and
        moved with ``to_cuda``, and the pair is handed to ``self.embed``.

        Args:
            previous_states: Iterable of states exposing ``screenshot_rs``.
            state: The current state, exposing ``screenshot_rs``.
            volatile: Forwarded to ``to_variable``.
            requires_grad: Forwarded to ``to_variable``.
            gpu: Forwarded to ``to_cuda``.

        Returns:
            The result of ``self.embed(current_input, previous_input)``.
        """
        def as_model_input(hwc):
            # HxWxC array -> 1xCxHxW Variable on the requested device
            nchw = hwc.transpose((2, 0, 1))[np.newaxis, ...]
            wrapped = to_variable(nchw,
                                  volatile=volatile,
                                  requires_grad=requires_grad)
            return to_cuda(wrapped, gpu)

        small_prev = [self.downscale_prev(s.screenshot_rs)
                      for s in previous_states]
        gray_prev = [cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
                     for img in small_prev]

        current = np.asarray(self.downscale(state.screenshot_rs),
                             dtype=np.float32) / 255.0
        current_var = as_model_input(current)

        prev_stack = np.dstack(gray_prev).astype(np.float32) / 255.0
        prev_var = as_model_input(prev_stack)

        return self.embed(current_var, prev_var)
예제 #3
0
 def create_hidden(self, batch_size, volatile=False, gpu=Config.GPU):
     """Return a pair of freshly zeroed recurrent state Variables.

     Both tensors have shape ``(self.nb_layers, batch_size, self.hidden_size)``.
     They are allocated with ``.new`` on the data of the module's first
     parameter, so they share that parameter's tensor type. Each is wrapped
     in ``Variable(volatile=volatile)`` and moved with ``to_cuda(..., gpu)``.
     Presumably this is the (h, c) pair of an LSTM — confirm with the
     forward pass.

     NOTE(review): the ``gpu`` default is evaluated once, when the class is
     defined, so later changes to ``Config.GPU`` are not picked up.
     """
     reference = next(self.parameters()).data

     def zero_state():
         # .new() creates a tensor of the same type as the reference param
         blank = reference.new(self.nb_layers, batch_size,
                               self.hidden_size).zero_()
         return to_cuda(Variable(blank, volatile=volatile), gpu)

     first_state = zero_state()
     second_state = zero_state()
     return first_state, second_state
예제 #4
0
def main():
    """Train the semi-supervised predictor model.

    Parses command line flags and loads the previous checkpoint, unless
    ``--nocontinue`` is given or no checkpoint file exists. Then builds:
    the loss history and plotter; the model (``models.Predictor``, or
    ``models.PredictorWithShortcuts`` with ``--withshortcuts``); an Adam
    optimizer; one loss function per output head; and an imgaug augmentation
    sequence. Finally it runs the training loop from the checkpointed batch
    index up to ``NB_BATCHES``. Along the way it periodically validates,
    plots losses, saves checkpoints and debug images, and reloads the
    autogenerated training examples.

    Command line flags:
        --nocontinue: Ignore any existing checkpoint and start from scratch.
        --withshortcuts: Train ``models.PredictorWithShortcuts``; all output
            file names get a ``_withshortcuts`` suffix.

    Files written (with ``_withshortcuts`` suffix when that flag is set):
        train_semisupervised_model*.tar: checkpoint (batch index, history,
            predictor weights), every ``SAVE_EVERY`` batches.
        train_semisupervised_plot*.jpg: loss plot target of the plotter,
            presumably refreshed every ``PLOT_EVERY`` batches.
        train_semisupervised_debug_img*.jpg: training debug image, every
            20 batches.
        train_semisupervised_debug_img_val*.jpg: validation debug image
            (first validation batch of each validation run).

    NOTE(review): the checkpoint does not store the optimizer state, so
    resuming restarts Adam's moment estimates from zero.
    """

    parser = argparse.ArgumentParser(description="Train semisupervised model")
    parser.add_argument('--nocontinue',
                        default=False,
                        action="store_true",
                        help="Whether to NOT continue the previous experiment",
                        required=False)
    parser.add_argument(
        '--withshortcuts',
        default=False,
        action="store_true",
        help=
        "Whether to train a model with shortcuts from downscaling to upscaling layers.",
        required=False)
    args = parser.parse_args()

    # resume from the previous checkpoint if one exists (unless --nocontinue)
    checkpoint_fp = "train_semisupervised_model%s.tar" % (
        "_withshortcuts" if args.withshortcuts else "", )
    if os.path.isfile(checkpoint_fp) and not args.nocontinue:
        checkpoint = torch.load(checkpoint_fp)
    else:
        checkpoint = None

    # load or initialize loss history
    # (one group per output head, each with a "train" and a "val" line;
    # increasing=False marks the losses as lower-is-better)
    if checkpoint is not None:
        history = plotting.History.from_string(checkpoint["history"])
    else:
        history = plotting.History()
        history.add_group("loss-ae", ["train", "val"], increasing=False)
        history.add_group("loss-grids", ["train", "val"], increasing=False)
        history.add_group("loss-atts", ["train", "val"], increasing=False)
        history.add_group("loss-multiactions", ["train", "val"],
                          increasing=False)
        history.add_group("loss-flow", ["train", "val"], increasing=False)
        history.add_group("loss-canny", ["train", "val"], increasing=False)
        history.add_group("loss-flipped", ["train", "val"], increasing=False)

    # initialize loss plotter
    loss_plotter = plotting.LossPlotter(
        history.get_group_names(),
        history.get_groups_increasing(),
        save_to_fp="train_semisupervised_plot%s.jpg" %
        ("_withshortcuts" if args.withshortcuts else "", ))
    # presumably skips the first 100 batches in the plot — confirm in
    # plotting.LossPlotter
    loss_plotter.start_batch_idx = 100

    # initialize and load model
    predictor = models.Predictor(
    ) if not args.withshortcuts else models.PredictorWithShortcuts()
    if checkpoint is not None:
        predictor.load_state_dict(checkpoint["predictor_state_dict"])
    predictor.train()

    # initialize optimizer
    optimizer_predictor = optim.Adam(predictor.parameters())

    # initialize losses
    # (MSE for the autoencoder reconstruction, BCE for every other head)
    criterion_ae = nn.MSELoss()
    criterion_grids = nn.BCELoss()
    criterion_atts = nn.BCELoss()
    criterion_multiactions = nn.BCELoss()
    criterion_flow = nn.BCELoss()
    criterion_canny = nn.BCELoss()
    criterion_flipped = nn.BCELoss()

    # send everything to gpu
    if GPU >= 0:
        predictor.cuda(GPU)
        criterion_ae.cuda(GPU)
        criterion_grids.cuda(GPU)
        criterion_atts.cuda(GPU)
        criterion_multiactions.cuda(GPU)
        criterion_flow.cuda(GPU)
        criterion_canny.cuda(GPU)
        criterion_flipped.cuda(GPU)

    # initialize image augmentation cascade
    # rarely/sometimes/often apply the wrapped augmenter with probability
    # 0.1 / 0.2 / 0.3 respectively
    rarely = lambda aug: iaa.Sometimes(0.1, aug)
    sometimes = lambda aug: iaa.Sometimes(0.2, aug)
    often = lambda aug: iaa.Sometimes(0.3, aug)
    # no hflips here, because that would mess up the optimal steering direction
    # no grayscale here, because that doesn't play well with the grayscale
    # previous images
    # no coarse dropout, because then the model would have to magically guess
    # things like edges or flow
    augseq = iaa.Sequential(
        [
            often(iaa.Crop(percent=(0, 0.05))),
            sometimes(iaa.GaussianBlur(
                (0, 0.2))),  # blur images with a sigma between 0 and 0.2
            often(
                iaa.AdditiveGaussianNoise(
                    loc=0, scale=(0.0, 0.01 * 255),
                    per_channel=0.5)),  # add gaussian noise to images
            often(iaa.Dropout((0.0, 0.05), per_channel=0.5)),
            rarely(iaa.Sharpen(alpha=(0, 0.7),
                               lightness=(0.75, 1.5))),  # sharpen images
            rarely(iaa.Emboss(alpha=(0, 0.7),
                              strength=(0, 2.0))),  # emboss images
            # inner Sometimes(0.5, then, else): either EdgeDetect or
            # DirectedEdgeDetect, blended in with alpha up to 0.4
            rarely(
                iaa.Sometimes(
                    0.5,
                    iaa.EdgeDetect(alpha=(0, 0.4)),
                    iaa.DirectedEdgeDetect(alpha=(0, 0.4),
                                           direction=(0.0, 1.0)),
                )),
            often(iaa.Add(
                (-20, 20), per_channel=0.5
            )),  # change brightness of images (add -20 to 20 to pixel values)
            often(iaa.Multiply((0.8, 1.2), per_channel=0.25)
                  ),  # change brightness of images (80-120% of original value)
            often(iaa.ContrastNormalization(
                (0.8, 1.2),
                per_channel=0.5)),  # improve or worsen the contrast
            # scale 90-110%, translate up to +/-7%, no rotation or shear
            sometimes(
                iaa.Affine(scale={
                    "x": (0.9, 1.1),
                    "y": (0.9, 1.1)
                },
                           translate_percent={
                               "x": (-0.07, 0.07),
                               "y": (-0.07, 0.07)
                           },
                           rotate=(0, 0),
                           shear=(0, 0),
                           order=[0, 1],
                           cval=(0, 255),
                           mode=ia.ALL))
        ],
        random_order=True  # do all of the above in random order
    )

    # load datasets
    print("Loading dataset...")
    if USE_COMPRESSED_ANNOTATIONS:
        examples = load_dataset_annotated_compressed()
    else:
        examples = load_dataset_annotated()
    # NOTE: excluding the annotated examples from the autogen datasets is
    # disabled; previously this was
    # set([ex.state_idx for ex in examples]). With an empty set, presumably
    # nothing is excluded (confirm in load_dataset_autogen).
    examples_annotated_ids = set()
    examples_autogen_val = load_dataset_autogen(val=True,
                                                nb_load=NB_AUTOGEN_VAL,
                                                not_in=examples_annotated_ids)
    examples_autogen_train = load_dataset_autogen(
        val=False, nb_load=NB_AUTOGEN_TRAIN, not_in=examples_annotated_ids)
    random.shuffle(examples)
    random.shuffle(examples_autogen_val)
    random.shuffle(examples_autogen_train)
    # the first NB_VAL_SPLIT (shuffled) annotated examples form the val set
    examples_val = examples[0:NB_VAL_SPLIT]
    examples_train = examples[NB_VAL_SPLIT:]

    # initialize background batch loaders
    # (validation uses no augmentation: iaa.Noop())
    #memory = replay_memory.ReplayMemory.get_instance_supervised()
    batch_loader_train = BatchLoader(examples_train,
                                     examples_autogen_train,
                                     augseq=augseq,
                                     queue_size=15,
                                     nb_workers=4,
                                     threaded=False)
    batch_loader_val = BatchLoader(examples_val,
                                   examples_autogen_val,
                                   augseq=iaa.Noop(),
                                   queue_size=NB_VAL_BATCHES,
                                   nb_workers=1,
                                   threaded=False)

    # training loop
    print("Training...")
    # resume right after the last checkpointed batch
    start_batch_idx = 0 if checkpoint is None else checkpoint["batch_idx"] + 1
    for batch_idx in xrange(start_batch_idx, NB_BATCHES):
        # train on batch

        # load batch data
        # get_batch() yields (model inputs), (ground truths per head),
        # (per-example annotation flags for grids and attributes)
        time_cbatch_start = time.time()
        (inputs,
         inputs_prev), (outputs_ae_gt, outputs_grids_gt_orig,
                        outputs_atts_gt_orig, outputs_multiactions_gt,
                        outputs_flow_gt, outputs_canny_gt,
                        outputs_flipped_gt), (
                            grids_annotated,
                            atts_annotated) = batch_loader_train.get_batch()
        inputs = to_cuda(to_variable(inputs), GPU)
        inputs_prev = to_cuda(to_variable(inputs_prev), GPU)
        outputs_ae_gt = to_cuda(
            to_variable(outputs_ae_gt, requires_grad=False), GPU)
        outputs_multiactions_gt = to_cuda(
            to_variable(outputs_multiactions_gt, requires_grad=False), GPU)
        outputs_flow_gt = to_cuda(
            to_variable(outputs_flow_gt, requires_grad=False), GPU)
        outputs_canny_gt = to_cuda(
            to_variable(outputs_canny_gt, requires_grad=False), GPU)
        outputs_flipped_gt = to_cuda(
            to_variable(outputs_flipped_gt, requires_grad=False), GPU)
        time_cbatch_end = time.time()

        # predict and compute losses
        time_fwbw_start = time.time()
        optimizer_predictor.zero_grad()
        (outputs_ae_pred, outputs_grids_pred, outputs_atts_pred,
         outputs_multiactions_pred, outputs_flow_pred, outputs_canny_pred,
         outputs_flipped_pred, emb) = predictor(inputs, inputs_prev)
        # zero-grad some outputs where annotations are not available for specific examples
        # (the helpers receive the predictions, presumably to substitute them
        # as targets for unannotated examples — confirm in the helpers)
        outputs_grids_gt = remove_unannotated_grids_gt(outputs_grids_pred,
                                                       outputs_grids_gt_orig,
                                                       grids_annotated)
        outputs_grids_gt = to_cuda(
            to_variable(outputs_grids_gt, requires_grad=False), GPU)
        outputs_atts_gt = remove_unannotated_atts_gt(outputs_atts_pred,
                                                     outputs_atts_gt_orig,
                                                     atts_annotated)
        outputs_atts_gt = to_cuda(
            to_variable(outputs_atts_gt, requires_grad=False), GPU)
        loss_ae = criterion_ae(outputs_ae_pred, outputs_ae_gt)
        loss_grids = criterion_grids(outputs_grids_pred, outputs_grids_gt)
        loss_atts = criterion_atts(outputs_atts_pred, outputs_atts_gt)
        loss_multiactions = criterion_multiactions(outputs_multiactions_pred,
                                                   outputs_multiactions_gt)
        loss_flow = criterion_flow(outputs_flow_pred, outputs_flow_gt)
        loss_canny = criterion_canny(outputs_canny_pred, outputs_canny_gt)
        loss_flipped = criterion_flipped(outputs_flipped_pred,
                                         outputs_flipped_gt)
        # Backpropagate all losses in one call. Each loss's gradient is
        # seeded with its weighting constant instead of 1, which is
        # equivalent to minimizing the weighted sum of the losses.
        losses_grad_lst = [
            loss.data.new().resize_as_(loss.data).fill_(w) for loss, w in zip([
                loss_ae, loss_grids, loss_atts, loss_multiactions, loss_flow,
                loss_canny, loss_flipped
            ], [
                LOSS_AE_WEIGHTING, LOSS_GRIDS_WEIGHTING,
                LOSS_ATTRIBUTES_WEIGHTING, LOSS_MULTIACTIONS_WEIGHTING,
                LOSS_FLOW_WEIGHTING, LOSS_CANNY_WEIGHTING,
                LOSS_FLIPPED_WEIGHTING
            ])
        ]
        torch.autograd.backward([
            loss_ae, loss_grids, loss_atts, loss_multiactions, loss_flow,
            loss_canny, loss_flipped
        ], losses_grad_lst)
        optimizer_predictor.step()
        time_fwbw_end = time.time()

        # add losses to history and output a message
        # (unweighted loss values are logged)
        loss_ae_value = to_numpy(loss_ae)[0]
        loss_grids_value = to_numpy(loss_grids)[0]
        loss_atts_value = to_numpy(loss_atts)[0]
        loss_multiactions_value = to_numpy(loss_multiactions)[0]
        loss_flow_value = to_numpy(loss_flow)[0]
        loss_canny_value = to_numpy(loss_canny)[0]
        loss_flipped_value = to_numpy(loss_flipped)[0]
        history.add_value("loss-ae", "train", batch_idx, loss_ae_value)
        history.add_value("loss-grids", "train", batch_idx, loss_grids_value)
        history.add_value("loss-atts", "train", batch_idx, loss_atts_value)
        history.add_value("loss-multiactions", "train", batch_idx,
                          loss_multiactions_value)
        history.add_value("loss-flow", "train", batch_idx, loss_flow_value)
        history.add_value("loss-canny", "train", batch_idx, loss_canny_value)
        history.add_value("loss-flipped", "train", batch_idx,
                          loss_flipped_value)
        print(
            "[T] Batch %05d L[ae=%.4f, grids=%.4f, atts=%.4f, multiactions=%.4f, flow=%.4f, canny=%.4f, flipped=%.4f] T[cbatch=%.04fs, fwbw=%.04fs]"
            % (batch_idx, loss_ae_value, loss_grids_value, loss_atts_value,
               loss_multiactions_value, loss_flow_value, loss_canny_value,
               loss_flipped_value, time_cbatch_end - time_cbatch_start,
               time_fwbw_end - time_fwbw_start))

        # generate a debug image showing batch predictions and ground truths
        if (batch_idx + 1) % 20 == 0:
            debug_img = generate_debug_image(
                inputs, inputs_prev, outputs_ae_gt, outputs_grids_gt_orig,
                outputs_atts_gt_orig, outputs_multiactions_gt, outputs_flow_gt,
                outputs_canny_gt, outputs_flipped_gt, outputs_ae_pred,
                outputs_grids_pred, outputs_atts_pred,
                outputs_multiactions_pred, outputs_flow_pred,
                outputs_canny_pred, outputs_flipped_pred, grids_annotated,
                atts_annotated)
            misc.imsave(
                "train_semisupervised_debug_img%s.jpg" %
                ("_withshortcuts" if args.withshortcuts else "", ), debug_img)

        # run N validation batches
        # (model in eval mode, volatile Variables, no optimizer step; the
        # per-head losses are averaged over NB_VAL_BATCHES batches)
        # TODO merge this with training stuff above (one function for both)
        if (batch_idx + 1) % VAL_EVERY == 0:
            predictor.eval()
            loss_ae_total = 0
            loss_grids_total = 0
            loss_atts_total = 0
            loss_multiactions_total = 0
            loss_flow_total = 0
            loss_canny_total = 0
            loss_flipped_total = 0
            for i in xrange(NB_VAL_BATCHES):
                time_cbatch_start = time.time()
                (inputs, inputs_prev), (
                    outputs_ae_gt, outputs_grids_gt_orig, outputs_atts_gt_orig,
                    outputs_multiactions_gt, outputs_flow_gt, outputs_canny_gt,
                    outputs_flipped_gt), (
                        grids_annotated,
                        atts_annotated) = batch_loader_val.get_batch()
                inputs = to_cuda(to_variable(inputs, volatile=True), GPU)
                inputs_prev = to_cuda(to_variable(inputs_prev, volatile=True),
                                      GPU)
                outputs_ae_gt = to_cuda(
                    to_variable(outputs_ae_gt, volatile=True), GPU)
                outputs_multiactions_gt = to_cuda(
                    to_variable(outputs_multiactions_gt, volatile=True), GPU)
                outputs_flow_gt = to_cuda(
                    to_variable(outputs_flow_gt, volatile=True), GPU)
                outputs_canny_gt = to_cuda(
                    to_variable(outputs_canny_gt, volatile=True), GPU)
                outputs_flipped_gt = to_cuda(
                    to_variable(outputs_flipped_gt, volatile=True), GPU)
                time_cbatch_end = time.time()

                time_fwbw_start = time.time()
                (outputs_ae_pred, outputs_grids_pred, outputs_atts_pred,
                 outputs_multiactions_pred, outputs_flow_pred,
                 outputs_canny_pred, outputs_flipped_pred,
                 emb) = predictor(inputs, inputs_prev)
                outputs_grids_gt = remove_unannotated_grids_gt(
                    outputs_grids_pred, outputs_grids_gt_orig, grids_annotated)
                outputs_grids_gt = to_cuda(
                    to_variable(outputs_grids_gt, volatile=True), GPU)
                outputs_atts_gt = remove_unannotated_atts_gt(
                    outputs_atts_pred, outputs_atts_gt_orig, atts_annotated)
                outputs_atts_gt = to_cuda(
                    to_variable(outputs_atts_gt, volatile=True), GPU)
                loss_ae = criterion_ae(outputs_ae_pred, outputs_ae_gt)
                loss_grids = criterion_grids(outputs_grids_pred,
                                             outputs_grids_gt)
                loss_atts = criterion_atts(outputs_atts_pred, outputs_atts_gt)
                loss_multiactions = criterion_multiactions(
                    outputs_multiactions_pred, outputs_multiactions_gt)
                loss_flow = criterion_flow(outputs_flow_pred, outputs_flow_gt)
                loss_canny = criterion_canny(outputs_canny_pred,
                                             outputs_canny_gt)
                loss_flipped = criterion_flipped(outputs_flipped_pred,
                                                 outputs_flipped_gt)
                time_fwbw_end = time.time()

                loss_ae_value = to_numpy(loss_ae)[0]
                loss_grids_value = to_numpy(loss_grids)[0]
                loss_atts_value = to_numpy(loss_atts)[0]
                loss_multiactions_value = to_numpy(loss_multiactions)[0]
                loss_flow_value = to_numpy(loss_flow)[0]
                loss_canny_value = to_numpy(loss_canny)[0]
                loss_flipped_value = to_numpy(loss_flipped)[0]
                loss_ae_total += loss_ae_value
                loss_grids_total += loss_grids_value
                loss_atts_total += loss_atts_value
                loss_multiactions_total += loss_multiactions_value
                loss_flow_total += loss_flow_value
                loss_canny_total += loss_canny_value
                loss_flipped_total += loss_flipped_value
                print(
                    "[V] Batch %05d L[ae=%.4f, grids=%.4f, atts=%.4f, multiactions=%.4f, flow=%.4f, canny=%.4f, flipped=%.4f] T[cbatch=%.04fs, fwbw=%.04fs]"
                    %
                    (batch_idx, loss_ae_value, loss_grids_value,
                     loss_atts_value, loss_multiactions_value, loss_flow_value,
                     loss_canny_value, loss_flipped_value, time_cbatch_end -
                     time_cbatch_start, time_fwbw_end - time_fwbw_start))

                # only the first validation batch gets a debug image
                if i == 0:
                    debug_img = generate_debug_image(
                        inputs, inputs_prev, outputs_ae_gt,
                        outputs_grids_gt_orig, outputs_atts_gt_orig,
                        outputs_multiactions_gt, outputs_flow_gt,
                        outputs_canny_gt, outputs_flipped_gt, outputs_ae_pred,
                        outputs_grids_pred, outputs_atts_pred,
                        outputs_multiactions_pred, outputs_flow_pred,
                        outputs_canny_pred, outputs_flipped_pred,
                        grids_annotated, atts_annotated)
                    misc.imsave(
                        "train_semisupervised_debug_img_val%s.jpg" %
                        ("_withshortcuts" if args.withshortcuts else "", ),
                        debug_img)
            history.add_value("loss-ae", "val", batch_idx,
                              loss_ae_total / NB_VAL_BATCHES)
            history.add_value("loss-grids", "val", batch_idx,
                              loss_grids_total / NB_VAL_BATCHES)
            history.add_value("loss-atts", "val", batch_idx,
                              loss_atts_total / NB_VAL_BATCHES)
            history.add_value("loss-multiactions", "val", batch_idx,
                              loss_multiactions_total / NB_VAL_BATCHES)
            history.add_value("loss-flow", "val", batch_idx,
                              loss_flow_total / NB_VAL_BATCHES)
            history.add_value("loss-canny", "val", batch_idx,
                              loss_canny_total / NB_VAL_BATCHES)
            history.add_value("loss-flipped", "val", batch_idx,
                              loss_flipped_total / NB_VAL_BATCHES)
            predictor.train()

        # generate loss plot
        if (batch_idx + 1) % PLOT_EVERY == 0:
            loss_plotter.plot(history)

        # every N batches, save a checkpoint
        # (batch index, serialized history and predictor weights only; the
        # optimizer state is not saved)
        if (batch_idx + 1) % SAVE_EVERY == 0:
            checkpoint_fp = "train_semisupervised_model%s.tar" % (
                "_withshortcuts" if args.withshortcuts else "", )
            torch.save(
                {
                    "batch_idx": batch_idx,
                    "history": history.to_string(),
                    "predictor_state_dict": predictor.state_dict(),
                }, checkpoint_fp)

        # refresh automatically generated examples (autoencoder, canny edge stuff etc.)
        # The old training loader is joined before a new one is built on the
        # freshly loaded autogen examples (annotated examples are reused).
        if (batch_idx + 1) % 1000 == 0:
            print("Refreshing autogen dataset...")
            batch_loader_train.join()
            examples_autogen_train = load_dataset_autogen(
                val=False,
                nb_load=NB_AUTOGEN_TRAIN,
                not_in=examples_annotated_ids)
            batch_loader_train = BatchLoader(examples_train,
                                             examples_autogen_train,
                                             augseq=augseq,
                                             queue_size=15,
                                             nb_workers=4,
                                             threaded=False)
예제 #5
0
 def forward_image(self, subimg, softmax=False, volatile=False, requires_grad=True, gpu=GPU):
     """Run the model on a single image.

     Args:
         subimg: One image as an HxWxC array (presumably uint8 in [0, 255] —
             confirm with callers). It is divided by 255, wrapped into a
             batch of one and reordered to channels-first (1xCxHxW).
         softmax: Forwarded to ``self.forward``.
         volatile: Forwarded to ``to_variable``.
         requires_grad: Forwarded to ``to_variable``.
         gpu: Device index passed to ``to_cuda``. It defaults to the
             module-level ``GPU`` value captured at class-definition time.
             Previously this argument was ignored and ``GPU`` was always used.

     Returns:
         The result of ``self.forward`` on the single-image batch.
     """
     # Divide by a float: for an integer image, ``/ 255`` is floor division
     # under Python 2 and collapses every pixel below 255 to 0.
     batch = np.float32([subimg / 255.0]).transpose((0, 3, 1, 2))  # 1xCxHxW
     batch = to_cuda(to_variable(batch, volatile=volatile, requires_grad=requires_grad), gpu)
     return self.forward(batch, softmax=softmax)
예제 #6
0
    def forward(self, inputs, inputs_prev, only_embed=False):
        """Run the network on the current frame and the previous frames.

        Args:
            inputs: Current-frame batch. The shape comments below assume
                Bx3x90x160 — TODO confirm against callers.
            inputs_prev: Previous-frame batch. The shape comments below
                assume BxNx45x80 (N previous grayscale frames; the inline
                note says 2) — TODO confirm.
            only_embed: If True, stop after the encoder and return only the
                embedding.

        Returns:
            If ``only_embed``: the embedding ``x_emb``. It is viewed below as
            512 * 3 * 5 values per example.
            Otherwise the tuple ``(x_ae, x_grids, x_atts, x_ma, x_flow,
            x_canny, x_flipped, x_emb)``:
              * ``x_grids``, ``x_ae``, ``x_flow`` and ``x_canny`` are channel
                slices of one sigmoid map tensor: 8 grid channels, then
                ``3 + self.nb_previous_images`` autoencoder channels, then
                1 flow channel, then 1 canny channel.
              * ``x_atts`` (48 values), ``x_ma`` (36 values) and
                ``x_flipped`` (the rest) are slices of one sigmoid vector
                produced by ``self.vec_fc1``.
        """
        # NOTE(review): `act` is defined but not used below.
        def act(x):
            return F.relu(x, inplace=True)

        def lrelu(x, negative_slope=0.2):
            return F.leaky_relu(x, negative_slope=negative_slope, inplace=True)

        # nearest-neighbour upsampling by factor f (builds the module per call)
        def up(x, f=2):
            m = nn.UpsamplingNearest2d(scale_factor=f)
            return m(x)

        # 2x2 max pooling (halves height and width)
        def maxp(x):
            return F.max_pool2d(x, 2)

        # Coordinate channels for the 23x40 feature map: pos_x ramps 0..1
        # along the width (40), pos_y ramps 0..1 along the height (23).
        # Both end up with shape (B, 2, 23, 40).
        # NOTE(review): np.fliplr flips axis 1 (size 1 here) and np.flipud
        # flips axis 0 (the batch axis, whose entries are identical tiles).
        # Both concatenations therefore duplicate the ramp instead of adding a
        # reversed (1 - x) channel. Changing this would alter the inputs seen
        # by already-trained weights, so it is left as is.
        B = inputs.size(0)
        pos_x = np.tile(
            np.linspace(0, 1, 40).astype(np.float32).reshape(1, 1, 40),
            (B, 1, 23, 1))
        pos_x = np.concatenate([pos_x, np.fliplr(pos_x)], axis=1)
        pos_y = np.tile(
            np.linspace(0, 1, 23).astype(np.float32).reshape(1, 23, 1),
            (B, 1, 1, 40))
        pos_y = np.concatenate([pos_y, np.flipud(pos_y)], axis=1)

        # The coordinate Variables mirror inputs' volatile/requires_grad
        # flags. NOTE: they are placed on Config.GPU, not necessarily on the
        # device that holds `inputs`.
        pos_x = to_cuda(
            to_variable(pos_x,
                        volatile=inputs.volatile,
                        requires_grad=inputs.requires_grad), Config.GPU)
        pos_y = to_cuda(
            to_variable(pos_y,
                        volatile=inputs.volatile,
                        requires_grad=inputs.requires_grad), Config.GPU)

        # Encoder branch for the current frame.
        x_emb0_curr = inputs  # 3x90x160
        x_emb1_curr = lrelu(
            self.emb_c1_sd_curr(
                self.emb_c1_bn_curr(self.emb_c1_curr(x_emb0_curr))))  # 45x80
        x_emb2_curr = lrelu(
            self.emb_c2_sd_curr(
                self.emb_c2_bn_curr(self.emb_c2_curr(x_emb1_curr))))  # 45x80
        x_emb2_curr = F.pad(x_emb2_curr, (0, 0, 1, 0))  # 45x80 -> 46x80
        x_emb2_curr_pool = maxp(x_emb2_curr)  # 23x40
        x_emb3_curr = lrelu(
            self.emb_c3_sd_curr(
                self.emb_c3_bn_curr(
                    self.emb_c3_curr(x_emb2_curr_pool))))  # 23x40

        # Encoder branch for the previous frames.
        x_emb0_prev = inputs_prev  # 2x45x80
        x_emb1_prev = lrelu(
            self.emb_c1_sd_prev(
                self.emb_c1_bn_prev(self.emb_c1_prev(x_emb0_prev))))  # 45x80
        x_emb1_prev = F.pad(x_emb1_prev, (0, 0, 1, 0))  # 45x80 -> 46x80
        x_emb1_prev = maxp(x_emb1_prev)  # 23x40
        x_emb2_prev = lrelu(
            self.emb_c2_sd_prev(
                self.emb_c2_bn_prev(self.emb_c2_prev(x_emb1_prev))))  # 23x40

        # Merge both branches plus the coordinate channels along dim 1.
        x_emb3 = torch.cat([x_emb3_curr, x_emb2_prev, pos_x, pos_y], 1)
        x_emb3 = F.pad(x_emb3, (0, 0, 1, 0))  # 23x40 -> 24x40

        x_emb4 = lrelu(self.emb_c4_sd(self.emb_c4_bn(
            self.emb_c4(x_emb3))))  # 12x20
        x_emb5 = lrelu(self.emb_c5_sd(self.emb_c5_bn(
            self.emb_c5(x_emb4))))  # 6x10
        x_emb6 = lrelu(self.emb_c6_sd(self.emb_c6_bn(
            self.emb_c6(x_emb5))))  # 3x5
        x_emb7 = lrelu(self.emb_c7_sd(self.emb_c7_bn(
            self.emb_c7(x_emb6))))  # 3x5
        x_emb = x_emb7

        if only_embed:
            return x_emb
        else:
            # Decoder for the image-shaped outputs. It upsamples back to the
            # input resolution, concatenating earlier feature maps (x_emb4,
            # padded x_emb2_curr, inputs) along the way.
            x_maps = x_emb  # 3x5
            x_maps = up(x_maps, 4)  # 12x20
            x_maps = lrelu(
                self.maps_c1_bn(self.maps_c1(torch.cat([x_maps, x_emb4],
                                                       1))))  # 12x20
            x_maps = up(x_maps, 4)  # 48x80
            x_maps = lrelu(
                self.maps_c2_bn(
                    self.maps_c2(
                        torch.cat(
                            [x_maps, F.pad(x_emb2_curr, (0, 0, 1, 1))],
                            1))))  # 48x80 -> 44x80
            x_maps = F.pad(x_maps, (0, 0, 1, 0))  # 45x80
            x_maps = up(x_maps)  # 90x160
            x_maps = F.sigmoid(self.maps_c3(torch.cat([x_maps, inputs],
                                                      1)))  # 90x160

            # Split the map channels into the per-head outputs.
            ae_size = 3 + self.nb_previous_images
            x_grids = x_maps[:, 0:8, ...]
            x_ae = x_maps[:, 8:8 + ae_size, ...]
            x_flow = x_maps[:, 8 + ae_size:8 + ae_size + 1, ...]
            x_canny = x_maps[:, 8 + ae_size + 1:8 + ae_size + 2, ...]

            # Vector outputs: flatten the embedding, dropout (training only),
            # one fully connected layer with sigmoid.
            x_vec = x_emb
            x_vec = x_vec.view(-1, 512 * 3 * 5)
            x_vec = F.dropout(x_vec, p=0.5, training=self.training)
            x_vec = F.sigmoid(self.vec_fc1(x_vec))

            # atts_size = 48, ma_size = 36; everything after them is "flipped"
            atts_size = 10 + 7 + 3 + 5 + 8 + 4 + 4 + 4 + 3
            ma_size = 9 + 9 + 9 + 9
            x_atts = x_vec[:, 0:atts_size]
            x_ma = x_vec[:, atts_size:atts_size + ma_size]
            x_flipped = x_vec[:, atts_size + ma_size:]

            return x_ae, x_grids, x_atts, x_ma, x_flow, x_canny, x_flipped, x_emb