Example 1
# Training setup for a sequence-to-sequence autoencoder (encoder/decoder
# pair), jointly optimized with Adam.
# NOTE(review): this fragment is incomplete — the loss computation,
# backward() and optimizer.step() calls are cut off after the "loss"
# banner below.
threshold = nn.Threshold(0., 0.0)  # zeroes values <= 0.0; unused in the visible lines

# Optimize encoder and decoder parameters with a single optimizer.
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = optim.Adam(params, lr=0.001)

s = 1

for e in range(100):  # epochs
    for i, v in enumerate(dloader):  # v[0] is the image batch from the DataLoader
        optimizer.zero_grad()
        # Reshape the batch to (4, 32, 1, 28, 28).
        # assumes (seq_len=4, batch=32, channels=1, H=28, W=28) — batch=32
        # matches get_init_states(32) below; TODO confirm axis order.
        # NOTE(review): torch.autograd.Variable is deprecated (no-op since
        # PyTorch 0.4) — tensors can be used directly.
        images = Variable(v[0].cuda()).view(4, 32, 1, 28, 28)

        ########
        #Encoder
        ########
        # Fresh zero/initial hidden state per batch (batch size 32);
        # only the final encoder state is kept for the decoder.
        hidden = encoder.get_init_states(32)
        _, encoder_state = encoder(images.clone(), hidden)

        ########
        #Decoder
        ########
        #prepare inputs for decoder
        # Reverse the sequence along dim 0 (time); the decoder is trained
        # to reconstruct the input sequence in reverse order.
        rev_images = flip_var(images, 0)
        # Decoder input: a zero frame followed by the reversed sequence
        # shifted by one step (standard teacher-forcing shift).
        dec_input = torch.cat((Variable(torch.zeros(
            1, 32, 1, 28, 28).cuda()), rev_images[:-1, :, :, :, :]), 0)

        decoder_out, _ = decoder(dec_input, encoder_state)

        #######
        #loss##
        #######
Example 2
        # NOTE(review): this fragment starts mid-loop — `batch`, `b_size`,
        # `hidden_dim`, `hidden_spt`, `cnn_encoder`, `lstm_encoder`,
        # `lstm_decoder`, `teacher_forcing`, `get_sample_prob` and `i`
        # are defined outside the visible lines.
        seqs = batch
        # Last 10 frames of each sequence are the prediction targets
        # (raw pixels, moved to GPU).
        nextf_raw = seqs[:,10:,:,:,:].cuda()

        #----cnn encoder----

        # First 10 frames are the conditioning input; flattened to
        # (b_size*10, 1, 64, 64) so the CNN encodes each frame independently,
        # then reshaped back to a per-sequence layout.
        # assumes frames are 1-channel 64x64 — TODO confirm.
        prevf_raw = seqs[:,:10,:,:,:].contiguous().view(-1,1,64,64).cuda()
        prevf_enc = cnn_encoder(prevf_raw).view(b_size,10,hidden_dim,hidden_spt,hidden_spt)

        if teacher_forcing:
            # Encode the target frames too, so ground-truth encodings can be
            # fed to the decoder during scheduled sampling.
            cnn_encoder_out = cnn_encoder(seqs[:,10:,:,:,:].contiguous().view(-1,1,64,64).cuda())
            nextf_enc       = cnn_encoder_out.view(b_size,10,hidden_dim,hidden_spt,hidden_spt)

        #----lstm encoder---

        # Run the LSTM over the 10 encoded input frames; keep only the
        # final state as the context for the decoder.
        hidden           = lstm_encoder.get_init_states(b_size)
        _, encoder_state = lstm_encoder(prevf_enc, hidden)

        #----lstm decoder---

        # Scheduled-sampling probability: presumably decays with training
        # step `i` — verify get_sample_prob's schedule. 0 disables
        # teacher forcing entirely.
        sample_prob =  get_sample_prob(i) if teacher_forcing else 0
        decoder_output_list = []
        r_hist = []  # debug: records which steps used the model's own output

        # Autoregressively predict 10 future frames.
        # NOTE(review): the loop body is cut off after the first lines of
        # the else branch — the sampling decision and decoder call for
        # s > 0 are not visible here.
        for s in range(10):
            if s == 0:
                # Seed the decoder with the last encoded input frame.
                decoder_out, decoder_state = lstm_decoder(prevf_enc[:,-1:,:,:,:], encoder_state)
            else:
                r = np.random.rand()
                r_hist.append(int(r > sample_prob)) #debug