예제 #1
0
            out.view(3, x_fra, y_fra).unsqueeze(0).unsqueeze(2),
            expectedOut[:, N_FRAME, :].view(3, x_fra,
                                            y_fra).unsqueeze(0).unsqueeze(2),
            1, '')

        # lossR = loss_mse(output[:, FRA, :], expectedOut[:, FRA, :])
        loss = loss_mse(output[:, FRA, :], expectedOut[:, FRA, :]) \
               + loss_mse(output[:, 0:FRA, :], expectedOut[:, 0:FRA, :])/(2*FRA) # probar amb expectedout com a GT

        loss.backward()
        optimizerrgb.step()

        loss_value.append(loss.data.item())
        # loss_valueR.append(lossR.data.item())

    loss_val = np.mean(np.array(loss_value))
    # loss_valR = np.mean(np.array(loss_valueR))

    if epoch % saveEvery == 0 and savecheckpt:
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': modelrgb.state_dict(),
                'optimizer': optimizerrgb.state_dict(),
            }, checkptname_rgb + str(epoch) + '.pth')

    if epoch % 30 == 0:
        print(modelrgb.state_dict()['l1.rr'])
        print(modelrgb.state_dict()['l1.theta'])
    print('Epoch: ', epoch, '| train loss: %.4f' %
          loss_val)  #, '| train loss pred: %.4f' % loss_valR)
예제 #2
0
            #
            # tmp2 = np.zeros([64, 64, 3], dtype=np.float16)
            # tmp2[:, :, 0] = eo[0, FRA, :].reshape(x_fra, y_fra)
            # tmp2[:, :, 1] = eo[1, FRA, :].reshape(x_fra, y_fra)
            #
            # scipy.misc.imsave('expected_outputOF.png', tmp2)

            # torchvision.utils.save_image(output[:, FRA].view(2,x_fra,y_fra), 'predicted_output.png',)
            # torchvision.utils.save_image(expectedOut[:, FRA].view(2,x_fra,y_fra), 'expected_output.png', )

            # Compute loss

        loss = loss_mse(output[:,FRA], expectedOut[:,FRA]) # if Kitti: loss = loss_mse(output, expectedOut)
        loss.backward()
        optimizer.step()
        loss_value.append(loss.data.item())

    loss_val = np.mean(np.array(loss_value))


    if epoch % saveEvery ==0 :
        save_checkpoint({	'epoch': epoch + 1,
                            'state_dict': model.state_dict(),
                            'optimizer' : optimizer.state_dict(),
                            },checkptname+str(epoch)+'.pth')

    # if epoch % 4 == 0:
        # print(model.state_dict()['l1.rr'])
        # print(model.state_dict()['l1.theta'])
        # loss_val = float(loss_val/i_batch)
    print('Epoch: ', epoch, '| train loss: %.4f' % loss_val, '| time per epoch: %.4f' % (time.time()-t0_epoch))