def run_batch(batch_idx, val, batch_loader, tracker_cnn, criterion, optimizer, history, save_debug_image):
    """Train or validate on a single batch.

    Args:
        batch_idx: Index of the batch; used as x-value in ``history`` and in
            the log line.
        val: If True, run a validation pass: inputs are created with
            ``volatile=True``, no backward pass / optimizer step happens, and
            values are added to the history with ``average=True``.
            If False, run a training step.
        batch_loader: Object whose ``get_batch()`` returns
            ``(inputs, outputs_gt)`` as numpy arrays; ``outputs_gt`` is
            reduced with ``argmax`` over axis 1 to obtain class indices.
        tracker_cnn: Model called as ``tracker_cnn(inputs)``.
        criterion: Loss called as ``criterion(predictions, class_indices)``.
        optimizer: Optimizer; only used when training.
        history: Object with ``add_value(group, split, x, value, average=...)``.
        save_debug_image: If True, write ``debug_img_<train|val>.jpg``.
    """
    train = not val
    phase = "train" if train else "val"

    time_cbatch_start = time.time()
    inputs, outputs_gt = batch_loader.get_batch()
    # Convert unconditionally. The old ``if Config.GPU >= 0`` guard left
    # ``inputs`` as numpy and ``outputs_gt_bins`` undefined on CPU.
    # main() also calls to_cuda(..., GPU) without a guard, so to_cuda presumably
    # handles GPU < 0 itself -- NOTE(review): confirm.
    outputs_gt_bins = np.argmax(outputs_gt, axis=1)
    inputs = to_cuda(to_variable(inputs, volatile=val), Config.GPU)
    outputs_gt_bins = to_cuda(
        to_variable(outputs_gt_bins, volatile=val, requires_grad=False),
        Config.GPU)
    outputs_gt = to_cuda(
        to_variable(outputs_gt, volatile=val, requires_grad=False),
        Config.GPU)
    time_cbatch_end = time.time()

    time_fwbw_start = time.time()
    if train:
        optimizer.zero_grad()
    outputs_pred = tracker_cnn(inputs)
    outputs_pred_sm = F.softmax(outputs_pred)
    loss = criterion(outputs_pred, outputs_gt_bins)
    if train:
        loss.backward()
        optimizer.step()
    time_fwbw_end = time.time()

    # Extract the scalar loss value. ravel() accepts both shape-(1,) and
    # 0-d arrays.
    loss = float(np.ravel(to_numpy(loss))[0])
    outputs_pred_np = to_numpy(outputs_pred_sm)
    outputs_gt_np = to_numpy(outputs_gt)
    # Fraction of rows whose predicted bin equals the ground-truth bin.
    # np.mean yields a float and uses the real batch size. ``sum / BATCH_SIZE``
    # was integer division under Python 2 and assumed full batches.
    acc = float(np.mean(np.equal(np.argmax(outputs_pred_np, axis=1),
                                 np.argmax(outputs_gt_np, axis=1))))

    history.add_value("loss", phase, batch_idx, loss, average=val)
    history.add_value("acc", phase, batch_idx, acc, average=val)
    print("[%s] Batch %05d | loss %.8f | acc %.2f | cbatch %.04fs | fwbw %.04fs" % (
        "T" if train else "V", batch_idx, loss, acc,
        time_cbatch_end - time_cbatch_start,
        time_fwbw_end - time_fwbw_start))

    if save_debug_image:
        debug_img = generate_debug_image(inputs, outputs_gt, outputs_pred_sm)
        misc.imsave("debug_img_%s.jpg" % phase, debug_img)
def embed_state(self, previous_states, state, volatile=False, requires_grad=True, gpu=-1):
    """Embed ``state`` together with its preceding states.

    The current screenshot (``state.screenshot_rs``) is downscaled with
    ``self.downscale`` and kept as-is. Every screenshot in
    ``previous_states`` is downscaled with ``self.downscale_prev`` and
    converted RGB -> grayscale, then the results are stacked depth-wise.
    Both arrays are cast to float32, divided by 255.0, and transposed from
    HxWxC to CxHxW. Each gets a leading batch axis of size 1 and is wrapped
    through ``to_variable``/``to_cuda``.

    Returns:
        The result of ``self.embed(current_input, previous_input)``.
    """
    def wrap(arr):
        # HxWxC -> 1xCxHxW, then hand over to the autograd/GPU wrappers.
        batched = arr.transpose((2, 0, 1))[np.newaxis, ...]
        return to_cuda(
            to_variable(batched, volatile=volatile, requires_grad=requires_grad),
            gpu)

    downscaled_prev = [self.downscale_prev(s.screenshot_rs) for s in previous_states]
    gray_prev = [cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) for img in downscaled_prev]

    current = np.array(self.downscale(state.screenshot_rs), dtype=np.float32) / 255.0
    current_var = wrap(current)

    previous = np.dstack(gray_prev).astype(np.float32) / 255.0
    previous_var = wrap(previous)

    return self.embed(current_var, previous_var)
def create_hidden(self, batch_size, volatile=False, gpu=Config.GPU, nb_layers=None):
    """Create a pair of zero-filled hidden-state tensors.

    Both tensors have shape ``(nb_layers, batch_size, self.hidden_size)``.
    They are allocated with ``weight.new(...)`` from the model's first
    parameter, so they share its tensor type. Each is wrapped in
    ``Variable(..., volatile=volatile)`` and passed through ``to_cuda(..., gpu)``.

    Args:
        batch_size: Size of the batch dimension.
        volatile: Forwarded to ``Variable``.
        gpu: Forwarded to ``to_cuda``. Note that the default ``Config.GPU`` is
            evaluated once, when the function is defined.
        nb_layers: Size of the first dimension. If None, uses
            ``self.nb_layers`` when that attribute exists, otherwise 1.
            NOTE(review): the original first dimension was lost from the
            source text (``Variable(, batch_size, ...)``); confirm it matches
            the recurrent layer's ``num_layers``.

    Returns:
        A 2-tuple of independent zero tensors (two separate allocations).
    """
    if nb_layers is None:
        nb_layers = getattr(self, "nb_layers", 1)
    weight = next(self.parameters()).data

    def zeros():
        return to_cuda(
            Variable(weight.new(nb_layers, batch_size, self.hidden_size).zero_(),
                     volatile=volatile),
            gpu)

    return (zeros(), zeros())
def main():
    """Initialize/load model, dataset, optimizers, history and loss plotter,
    augmentation sequence. Then start training loop.

    Command line flags:
        --nocontinue: Start from scratch even if a checkpoint file exists.
        --withshortcuts: Train ``models.PredictorWithShortcuts`` instead of
            ``models.Predictor``. This also appends ``_withshortcuts`` to
            the checkpoint, loss plot and debug image filenames.

    Every ``SAVE_EVERY`` batches, a checkpoint is written with ``torch.save``.
    It contains the batch index, the serialized loss history and the
    predictor's state dict. A later run without ``--nocontinue`` resumes from
    the batch after the saved one.
    """
    parser = argparse.ArgumentParser(description="Train semisupervised model")
    parser.add_argument(
        '--nocontinue',
        default=False,
        action="store_true",
        help="Whether to NOT continue the previous experiment",
        required=False)
    parser.add_argument(
        '--withshortcuts',
        default=False,
        action="store_true",
        help=
        "Whether to train a model with shortcuts from downscaling to upscaling layers.",
        required=False)
    args = parser.parse_args()

    # filename suffix shared by every file this script writes
    fp_suffix = "_withshortcuts" if args.withshortcuts else ""

    checkpoint_fp = "train_semisupervised_model%s.tar" % (fp_suffix, )
    if os.path.isfile(checkpoint_fp) and not args.nocontinue:
        checkpoint = torch.load(checkpoint_fp)
    else:
        checkpoint = None

    # load or initialize loss history
    if checkpoint is not None:
        history = plotting.History.from_string(checkpoint["history"])
    else:
        history = plotting.History()
        history.add_group("loss-ae", ["train", "val"], increasing=False)
        history.add_group("loss-grids", ["train", "val"], increasing=False)
        history.add_group("loss-atts", ["train", "val"], increasing=False)
        history.add_group("loss-multiactions", ["train", "val"], increasing=False)
        history.add_group("loss-flow", ["train", "val"], increasing=False)
        history.add_group("loss-canny", ["train", "val"], increasing=False)
        history.add_group("loss-flipped", ["train", "val"], increasing=False)

    # initialize loss plotter
    loss_plotter = plotting.LossPlotter(
        history.get_group_names(),
        history.get_groups_increasing(),
        save_to_fp="train_semisupervised_plot%s.jpg" % (fp_suffix, ))
    loss_plotter.start_batch_idx = 100

    # initialize and load model
    if args.withshortcuts:
        predictor = models.PredictorWithShortcuts()
    else:
        predictor = models.Predictor()
    if checkpoint is not None:
        predictor.load_state_dict(checkpoint["predictor_state_dict"])
    predictor.train()

    # initialize optimizer
    optimizer_predictor = optim.Adam(predictor.parameters())

    # initialize losses
    criterion_ae = nn.MSELoss()
    criterion_grids = nn.BCELoss()
    criterion_atts = nn.BCELoss()
    criterion_multiactions = nn.BCELoss()
    criterion_flow = nn.BCELoss()
    criterion_canny = nn.BCELoss()
    criterion_flipped = nn.BCELoss()

    # send everything to gpu
    if GPU >= 0:
        predictor.cuda(GPU)
        criterion_ae.cuda(GPU)
        criterion_grids.cuda(GPU)
        criterion_atts.cuda(GPU)
        criterion_multiactions.cuda(GPU)
        criterion_flow.cuda(GPU)
        criterion_canny.cuda(GPU)
        criterion_flipped.cuda(GPU)

    # initialize image augmentation cascade
    rarely = lambda aug: iaa.Sometimes(0.1, aug)
    sometimes = lambda aug: iaa.Sometimes(0.2, aug)
    often = lambda aug: iaa.Sometimes(0.3, aug)
    # no hflips here, because that would mess up the optimal steering direction
    # no grayscale here, because that doesn't play well with the grayscale
    # previous images
    # no coarse dropout, because then the model would have to magically guess
    # things like edges or flow
    augseq = iaa.Sequential(
        [
            often(iaa.Crop(percent=(0, 0.05))),
            # blur images with a sigma between 0 and 0.2
            sometimes(iaa.GaussianBlur((0, 0.2))),
            # add gaussian noise to images
            often(
                iaa.AdditiveGaussianNoise(
                    loc=0, scale=(0.0, 0.01 * 255), per_channel=0.5)),
            often(iaa.Dropout((0.0, 0.05), per_channel=0.5)),
            # sharpen images
            rarely(iaa.Sharpen(alpha=(0, 0.7), lightness=(0.75, 1.5))),
            # emboss images
            rarely(iaa.Emboss(alpha=(0, 0.7), strength=(0, 2.0))),
            rarely(
                iaa.Sometimes(
                    0.5,
                    iaa.EdgeDetect(alpha=(0, 0.4)),
                    iaa.DirectedEdgeDetect(
                        alpha=(0, 0.4), direction=(0.0, 1.0)),
                )),
            # change brightness of images (add -20 to 20 to pixel values)
            often(iaa.Add((-20, 20), per_channel=0.5)),
            # change brightness of images (80-120% of original value)
            often(iaa.Multiply((0.8, 1.2), per_channel=0.25)),
            # improve or worsen the contrast
            often(iaa.ContrastNormalization((0.8, 1.2), per_channel=0.5)),
            sometimes(
                iaa.Affine(
                    scale={
                        "x": (0.9, 1.1),
                        "y": (0.9, 1.1)
                    },
                    translate_percent={
                        "x": (-0.07, 0.07),
                        "y": (-0.07, 0.07)
                    },
                    rotate=(0, 0),
                    shear=(0, 0),
                    order=[0, 1],
                    cval=(0, 255),
                    mode=ia.ALL))
        ],
        random_order=True  # do all of the above in random order
    )

    # load datasets
    print("Loading dataset...")
    if USE_COMPRESSED_ANNOTATIONS:
        examples = load_dataset_annotated_compressed()
    else:
        examples = load_dataset_annotated()
    # empty set: autogenerated examples are not filtered against annotated ones
    examples_annotated_ids = set()
    examples_autogen_val = load_dataset_autogen(
        val=True, nb_load=NB_AUTOGEN_VAL, not_in=examples_annotated_ids)
    examples_autogen_train = load_dataset_autogen(
        val=False, nb_load=NB_AUTOGEN_TRAIN, not_in=examples_annotated_ids)
    random.shuffle(examples)
    random.shuffle(examples_autogen_val)
    random.shuffle(examples_autogen_train)
    examples_val = examples[0:NB_VAL_SPLIT]
    examples_train = examples[NB_VAL_SPLIT:]

    # initialize background batch loaders
    batch_loader_train = BatchLoader(
        examples_train,
        examples_autogen_train,
        augseq=augseq,
        queue_size=15,
        nb_workers=4,
        threaded=False)
    batch_loader_val = BatchLoader(
        examples_val,
        examples_autogen_val,
        augseq=iaa.Noop(),
        queue_size=NB_VAL_BATCHES,
        nb_workers=1,
        threaded=False)

    # training loop
    print("Training...")
    start_batch_idx = 0 if checkpoint is None else checkpoint["batch_idx"] + 1
    for batch_idx in xrange(start_batch_idx, NB_BATCHES):
        # train on batch

        # load batch data
        time_cbatch_start = time.time()
        (inputs, inputs_prev), (
            outputs_ae_gt, outputs_grids_gt_orig, outputs_atts_gt_orig,
            outputs_multiactions_gt, outputs_flow_gt, outputs_canny_gt,
            outputs_flipped_gt), (
                grids_annotated, atts_annotated) = batch_loader_train.get_batch()
        inputs = to_cuda(to_variable(inputs), GPU)
        inputs_prev = to_cuda(to_variable(inputs_prev), GPU)
        outputs_ae_gt = to_cuda(
            to_variable(outputs_ae_gt, requires_grad=False), GPU)
        outputs_multiactions_gt = to_cuda(
            to_variable(outputs_multiactions_gt, requires_grad=False), GPU)
        outputs_flow_gt = to_cuda(
            to_variable(outputs_flow_gt, requires_grad=False), GPU)
        outputs_canny_gt = to_cuda(
            to_variable(outputs_canny_gt, requires_grad=False), GPU)
        outputs_flipped_gt = to_cuda(
            to_variable(outputs_flipped_gt, requires_grad=False), GPU)
        time_cbatch_end = time.time()

        # predict and compute losses
        time_fwbw_start = time.time()
        optimizer_predictor.zero_grad()
        (outputs_ae_pred, outputs_grids_pred, outputs_atts_pred,
         outputs_multiactions_pred, outputs_flow_pred, outputs_canny_pred,
         outputs_flipped_pred, emb) = predictor(inputs, inputs_prev)
        # zero-grad some outputs where annotations are not available for specific examples
        outputs_grids_gt = remove_unannotated_grids_gt(
            outputs_grids_pred, outputs_grids_gt_orig, grids_annotated)
        outputs_grids_gt = to_cuda(
            to_variable(outputs_grids_gt, requires_grad=False), GPU)
        outputs_atts_gt = remove_unannotated_atts_gt(
            outputs_atts_pred, outputs_atts_gt_orig, atts_annotated)
        outputs_atts_gt = to_cuda(
            to_variable(outputs_atts_gt, requires_grad=False), GPU)
        loss_ae = criterion_ae(outputs_ae_pred, outputs_ae_gt)
        loss_grids = criterion_grids(outputs_grids_pred, outputs_grids_gt)
        loss_atts = criterion_atts(outputs_atts_pred, outputs_atts_gt)
        loss_multiactions = criterion_multiactions(outputs_multiactions_pred,
                                                   outputs_multiactions_gt)
        loss_flow = criterion_flow(outputs_flow_pred, outputs_flow_gt)
        loss_canny = criterion_canny(outputs_canny_pred, outputs_canny_gt)
        loss_flipped = criterion_flipped(outputs_flipped_pred,
                                         outputs_flipped_gt)
        losses = [
            loss_ae, loss_grids, loss_atts, loss_multiactions, loss_flow,
            loss_canny, loss_flipped
        ]
        weightings = [
            LOSS_AE_WEIGHTING, LOSS_GRIDS_WEIGHTING, LOSS_ATTRIBUTES_WEIGHTING,
            LOSS_MULTIACTIONS_WEIGHTING, LOSS_FLOW_WEIGHTING,
            LOSS_CANNY_WEIGHTING, LOSS_FLIPPED_WEIGHTING
        ]
        # Weight each loss by seeding its backward pass with a gradient tensor
        # of the loss's shape, filled with the weighting. This is equivalent
        # to backpropagating sum(w_i * loss_i).
        losses_grad_lst = [
            loss.data.new().resize_as_(loss.data).fill_(w)
            for loss, w in zip(losses, weightings)
        ]
        torch.autograd.backward(losses, losses_grad_lst)
        optimizer_predictor.step()
        time_fwbw_end = time.time()

        # add losses to history and output a message
        loss_ae_value = to_numpy(loss_ae)[0]
        loss_grids_value = to_numpy(loss_grids)[0]
        loss_atts_value = to_numpy(loss_atts)[0]
        loss_multiactions_value = to_numpy(loss_multiactions)[0]
        loss_flow_value = to_numpy(loss_flow)[0]
        loss_canny_value = to_numpy(loss_canny)[0]
        loss_flipped_value = to_numpy(loss_flipped)[0]
        history.add_value("loss-ae", "train", batch_idx, loss_ae_value)
        history.add_value("loss-grids", "train", batch_idx, loss_grids_value)
        history.add_value("loss-atts", "train", batch_idx, loss_atts_value)
        history.add_value("loss-multiactions", "train", batch_idx,
                          loss_multiactions_value)
        history.add_value("loss-flow", "train", batch_idx, loss_flow_value)
        history.add_value("loss-canny", "train", batch_idx, loss_canny_value)
        history.add_value("loss-flipped", "train", batch_idx,
                          loss_flipped_value)
        print(
            "[T] Batch %05d L[ae=%.4f, grids=%.4f, atts=%.4f, multiactions=%.4f, flow=%.4f, canny=%.4f, flipped=%.4f] T[cbatch=%.04fs, fwbw=%.04fs]"
            % (batch_idx, loss_ae_value, loss_grids_value, loss_atts_value,
               loss_multiactions_value, loss_flow_value, loss_canny_value,
               loss_flipped_value, time_cbatch_end - time_cbatch_start,
               time_fwbw_end - time_fwbw_start))

        # generate a debug image showing batch predictions and ground truths
        if (batch_idx + 1) % 20 == 0:
            debug_img = generate_debug_image(
                inputs, inputs_prev, outputs_ae_gt, outputs_grids_gt_orig,
                outputs_atts_gt_orig, outputs_multiactions_gt,
                outputs_flow_gt, outputs_canny_gt, outputs_flipped_gt,
                outputs_ae_pred, outputs_grids_pred, outputs_atts_pred,
                outputs_multiactions_pred, outputs_flow_pred,
                outputs_canny_pred, outputs_flipped_pred, grids_annotated,
                atts_annotated)
            misc.imsave(
                "train_semisupervised_debug_img%s.jpg" % (fp_suffix, ),
                debug_img)

        # run N validation batches
        # TODO merge this with training stuff above (one function for both)
        if (batch_idx + 1) % VAL_EVERY == 0:
            predictor.eval()
            loss_ae_total = 0
            loss_grids_total = 0
            loss_atts_total = 0
            loss_multiactions_total = 0
            loss_flow_total = 0
            loss_canny_total = 0
            loss_flipped_total = 0
            for i in xrange(NB_VAL_BATCHES):
                time_cbatch_start = time.time()
                (inputs, inputs_prev), (
                    outputs_ae_gt, outputs_grids_gt_orig, outputs_atts_gt_orig,
                    outputs_multiactions_gt, outputs_flow_gt, outputs_canny_gt,
                    outputs_flipped_gt), (
                        grids_annotated,
                        atts_annotated) = batch_loader_val.get_batch()
                inputs = to_cuda(to_variable(inputs, volatile=True), GPU)
                inputs_prev = to_cuda(
                    to_variable(inputs_prev, volatile=True), GPU)
                outputs_ae_gt = to_cuda(
                    to_variable(outputs_ae_gt, volatile=True), GPU)
                outputs_multiactions_gt = to_cuda(
                    to_variable(outputs_multiactions_gt, volatile=True), GPU)
                outputs_flow_gt = to_cuda(
                    to_variable(outputs_flow_gt, volatile=True), GPU)
                outputs_canny_gt = to_cuda(
                    to_variable(outputs_canny_gt, volatile=True), GPU)
                outputs_flipped_gt = to_cuda(
                    to_variable(outputs_flipped_gt, volatile=True), GPU)
                time_cbatch_end = time.time()

                time_fwbw_start = time.time()
                (outputs_ae_pred, outputs_grids_pred, outputs_atts_pred,
                 outputs_multiactions_pred, outputs_flow_pred,
                 outputs_canny_pred, outputs_flipped_pred,
                 emb) = predictor(inputs, inputs_prev)
                outputs_grids_gt = remove_unannotated_grids_gt(
                    outputs_grids_pred, outputs_grids_gt_orig, grids_annotated)
                outputs_grids_gt = to_cuda(
                    to_variable(outputs_grids_gt, volatile=True), GPU)
                outputs_atts_gt = remove_unannotated_atts_gt(
                    outputs_atts_pred, outputs_atts_gt_orig, atts_annotated)
                outputs_atts_gt = to_cuda(
                    to_variable(outputs_atts_gt, volatile=True), GPU)
                loss_ae = criterion_ae(outputs_ae_pred, outputs_ae_gt)
                loss_grids = criterion_grids(outputs_grids_pred,
                                             outputs_grids_gt)
                loss_atts = criterion_atts(outputs_atts_pred, outputs_atts_gt)
                loss_multiactions = criterion_multiactions(
                    outputs_multiactions_pred, outputs_multiactions_gt)
                loss_flow = criterion_flow(outputs_flow_pred, outputs_flow_gt)
                loss_canny = criterion_canny(outputs_canny_pred,
                                             outputs_canny_gt)
                loss_flipped = criterion_flipped(outputs_flipped_pred,
                                                 outputs_flipped_gt)
                time_fwbw_end = time.time()

                loss_ae_value = to_numpy(loss_ae)[0]
                loss_grids_value = to_numpy(loss_grids)[0]
                loss_atts_value = to_numpy(loss_atts)[0]
                loss_multiactions_value = to_numpy(loss_multiactions)[0]
                loss_flow_value = to_numpy(loss_flow)[0]
                loss_canny_value = to_numpy(loss_canny)[0]
                loss_flipped_value = to_numpy(loss_flipped)[0]
                loss_ae_total += loss_ae_value
                loss_grids_total += loss_grids_value
                loss_atts_total += loss_atts_value
                loss_multiactions_total += loss_multiactions_value
                loss_flow_total += loss_flow_value
                loss_canny_total += loss_canny_value
                loss_flipped_total += loss_flipped_value
                print(
                    "[V] Batch %05d L[ae=%.4f, grids=%.4f, atts=%.4f, multiactions=%.4f, flow=%.4f, canny=%.4f, flipped=%.4f] T[cbatch=%.04fs, fwbw=%.04fs]"
                    % (batch_idx, loss_ae_value, loss_grids_value,
                       loss_atts_value, loss_multiactions_value,
                       loss_flow_value, loss_canny_value, loss_flipped_value,
                       time_cbatch_end - time_cbatch_start,
                       time_fwbw_end - time_fwbw_start))

                # only the first validation batch gets a debug image
                if i == 0:
                    debug_img = generate_debug_image(
                        inputs, inputs_prev, outputs_ae_gt,
                        outputs_grids_gt_orig, outputs_atts_gt_orig,
                        outputs_multiactions_gt, outputs_flow_gt,
                        outputs_canny_gt, outputs_flipped_gt, outputs_ae_pred,
                        outputs_grids_pred, outputs_atts_pred,
                        outputs_multiactions_pred, outputs_flow_pred,
                        outputs_canny_pred, outputs_flipped_pred,
                        grids_annotated, atts_annotated)
                    misc.imsave(
                        "train_semisupervised_debug_img_val%s.jpg" %
                        (fp_suffix, ), debug_img)

            # record the mean over all validation batches
            history.add_value("loss-ae", "val", batch_idx,
                              loss_ae_total / NB_VAL_BATCHES)
            history.add_value("loss-grids", "val", batch_idx,
                              loss_grids_total / NB_VAL_BATCHES)
            history.add_value("loss-atts", "val", batch_idx,
                              loss_atts_total / NB_VAL_BATCHES)
            history.add_value("loss-multiactions", "val", batch_idx,
                              loss_multiactions_total / NB_VAL_BATCHES)
            history.add_value("loss-flow", "val", batch_idx,
                              loss_flow_total / NB_VAL_BATCHES)
            history.add_value("loss-canny", "val", batch_idx,
                              loss_canny_total / NB_VAL_BATCHES)
            history.add_value("loss-flipped", "val", batch_idx,
                              loss_flipped_total / NB_VAL_BATCHES)
            predictor.train()

        # generate loss plot
        if (batch_idx + 1) % PLOT_EVERY == 0:
            loss_plotter.plot(history)

        # every N batches, save a checkpoint
        if (batch_idx + 1) % SAVE_EVERY == 0:
            checkpoint_fp = "train_semisupervised_model%s.tar" % (fp_suffix, )
            torch.save({
                "batch_idx": batch_idx,
                "history": history.to_string(),
                "predictor_state_dict": predictor.state_dict(),
            }, checkpoint_fp)

        # refresh automatically generated examples (autoencoder, canny edge stuff etc.)
        if (batch_idx + 1) % 1000 == 0:
            print("Refreshing autogen dataset...")
            batch_loader_train.join()
            examples_autogen_train = load_dataset_autogen(
                val=False,
                nb_load=NB_AUTOGEN_TRAIN,
                not_in=examples_annotated_ids)
            batch_loader_train = BatchLoader(
                examples_train,
                examples_autogen_train,
                augseq=augseq,
                queue_size=15,
                nb_workers=4,
                threaded=False)
def forward_image(self, subimg, softmax=False, volatile=False, requires_grad=True, gpu=GPU):
    """Run ``self.forward`` on a single HxWxC image.

    The image is divided by 255.0 in floating point and cast to float32. It
    gets a leading batch axis and is transposed to 1xCxHxW, then wrapped with
    ``to_variable``/``to_cuda``.

    Args:
        subimg: Image array indexable as HxWxC (rank fixed by the transpose).
        softmax: Forwarded to ``self.forward``.
        volatile: Forwarded to ``to_variable``.
        requires_grad: Forwarded to ``to_variable``.
        gpu: Device index handed to ``to_cuda``. Previously this argument was
            ignored in favor of the module-level ``GPU``.

    Returns:
        Whatever ``self.forward(batch, softmax=softmax)`` returns.
    """
    # Divide in float64 explicitly. ``subimg/255`` on a uint8 array is integer
    # division under Python 2 (the file uses xrange), which zeroes almost
    # every pixel. Under Python 3 this is value-identical to the old code.
    batch = (np.array([subimg], dtype=np.float64) / 255.0).astype(np.float32)
    batch = batch.transpose((0, 3, 1, 2))
    batch = to_cuda(
        to_variable(batch, volatile=volatile, requires_grad=requires_grad),
        gpu)
    return self.forward(batch, softmax=softmax)
def forward(self, inputs, inputs_prev, only_embed=False):
    """Run the predictor network.

    ``inputs`` (current frame) and ``inputs_prev`` (previous frames) are
    embedded by separate conv stacks. The results are concatenated with
    positional channels and reduced to the embedding ``x_emb``; per the
    ``view`` below, this holds 512*3*5 values per example. Spatial sizes in
    the comments come from the original author.

    Args:
        inputs: Current-frame batch (comment: 3x90x160 per example).
        inputs_prev: Previous-frames batch (comment: 2x45x80 per example).
        only_embed: If True, return only ``x_emb``.

    Returns:
        ``x_emb`` if ``only_embed``; otherwise the tuple
        ``(x_ae, x_grids, x_atts, x_ma, x_flow, x_canny, x_flipped, x_emb)``.
        The first, second, fifth and sixth entries are channel slices of a
        sigmoid map output. ``x_atts``, ``x_ma`` and ``x_flipped`` are column
        slices of a sigmoid fully-connected output.
    """
    def lrelu(x, negative_slope=0.2):
        return F.leaky_relu(x, negative_slope=negative_slope, inplace=True)

    def up(x, f=2):
        m = nn.UpsamplingNearest2d(scale_factor=f)
        return m(x)

    def maxp(x):
        return F.max_pool2d(x, 2)

    # positional channels, each of shape (B, 1, 23, 40)
    B = inputs.size(0)
    pos_x = np.tile(
        np.linspace(0, 1, 40).astype(np.float32).reshape(1, 1, 40),
        (B, 1, 23, 1))
    # NOTE(review): on these 4-D arrays, fliplr flips axis 1 (size 1) and
    # flipud flips axis 0 (the batch, all rows identical). Both are no-ops, so
    # the second channel duplicates the first rather than being mirrored.
    # This is kept as-is because trained checkpoints depend on it.
    pos_x = np.concatenate([pos_x, np.fliplr(pos_x)], axis=1)
    pos_y = np.tile(
        np.linspace(0, 1, 23).astype(np.float32).reshape(1, 23, 1),
        (B, 1, 1, 40))
    pos_y = np.concatenate([pos_y, np.flipud(pos_y)], axis=1)
    pos_x = to_cuda(
        to_variable(pos_x, volatile=inputs.volatile,
                    requires_grad=inputs.requires_grad), Config.GPU)
    pos_y = to_cuda(
        to_variable(pos_y, volatile=inputs.volatile,
                    requires_grad=inputs.requires_grad), Config.GPU)

    # current-frame branch
    x_emb0_curr = inputs  # 3x90x160
    x_emb1_curr = lrelu(
        self.emb_c1_sd_curr(
            self.emb_c1_bn_curr(self.emb_c1_curr(x_emb0_curr))))  # 45x80
    x_emb2_curr = lrelu(
        self.emb_c2_sd_curr(
            self.emb_c2_bn_curr(self.emb_c2_curr(x_emb1_curr))))  # 45x80
    x_emb2_curr = F.pad(x_emb2_curr, (0, 0, 1, 0))  # 45x80 -> 46x80
    x_emb2_curr_pool = maxp(x_emb2_curr)  # 23x40
    x_emb3_curr = lrelu(
        self.emb_c3_sd_curr(
            self.emb_c3_bn_curr(self.emb_c3_curr(x_emb2_curr_pool))))  # 23x40

    # previous-frames branch
    x_emb0_prev = inputs_prev  # 2x45x80
    x_emb1_prev = lrelu(
        self.emb_c1_sd_prev(
            self.emb_c1_bn_prev(self.emb_c1_prev(x_emb0_prev))))  # 45x80
    x_emb1_prev = F.pad(x_emb1_prev, (0, 0, 1, 0))  # 45x80 -> 46x80
    x_emb1_prev = maxp(x_emb1_prev)  # 23x40
    x_emb2_prev = lrelu(
        self.emb_c2_sd_prev(
            self.emb_c2_bn_prev(self.emb_c2_prev(x_emb1_prev))))  # 23x40

    # merged trunk
    x_emb3 = torch.cat([x_emb3_curr, x_emb2_prev, pos_x, pos_y], 1)
    x_emb3 = F.pad(x_emb3, (0, 0, 1, 0))  # 23x40 -> 24x40
    x_emb4 = lrelu(self.emb_c4_sd(self.emb_c4_bn(self.emb_c4(x_emb3))))  # 12x20
    x_emb5 = lrelu(self.emb_c5_sd(self.emb_c5_bn(self.emb_c5(x_emb4))))  # 6x10
    x_emb6 = lrelu(self.emb_c6_sd(self.emb_c6_bn(self.emb_c6(x_emb5))))  # 3x5
    x_emb7 = lrelu(self.emb_c7_sd(self.emb_c7_bn(self.emb_c7(x_emb6))))  # 3x5
    x_emb = x_emb7

    if only_embed:
        return x_emb

    # map outputs: upsample back to input resolution with skip connections
    x_maps = x_emb  # 3x5
    x_maps = up(x_maps, 4)  # 12x20
    x_maps = lrelu(
        self.maps_c1_bn(self.maps_c1(torch.cat([x_maps, x_emb4], 1))))  # 12x20
    x_maps = up(x_maps, 4)  # 48x80
    x_maps = lrelu(
        self.maps_c2_bn(
            self.maps_c2(
                torch.cat([x_maps, F.pad(x_emb2_curr, (0, 0, 1, 1))],
                          1))))  # 48x80 -> 44x80
    x_maps = F.pad(x_maps, (0, 0, 1, 0))  # 45x80
    x_maps = up(x_maps)  # 90x160
    x_maps = F.sigmoid(self.maps_c3(torch.cat([x_maps, inputs], 1)))  # 90x160

    ae_size = 3 + self.nb_previous_images
    x_grids = x_maps[:, 0:8, ...]
    x_ae = x_maps[:, 8:8 + ae_size, ...]
    x_flow = x_maps[:, 8 + ae_size:8 + ae_size + 1, ...]
    x_canny = x_maps[:, 8 + ae_size + 1:8 + ae_size + 2, ...]

    # vector outputs
    x_vec = x_emb
    x_vec = x_vec.view(-1, 512 * 3 * 5)
    # training=self.training: F.dropout would otherwise not follow
    # train()/eval() mode.
    x_vec = F.dropout(x_vec, p=0.5, training=self.training)
    x_vec = F.sigmoid(self.vec_fc1(x_vec))

    atts_size = 10 + 7 + 3 + 5 + 8 + 4 + 4 + 4 + 3
    ma_size = 9 + 9 + 9 + 9
    x_atts = x_vec[:, 0:atts_size]
    x_ma = x_vec[:, atts_size:atts_size + ma_size]
    x_flipped = x_vec[:, atts_size + ma_size:]

    return x_ae, x_grids, x_atts, x_ma, x_flow, x_canny, x_flipped, x_emb