Exemplo n.º 1
0
    #if read_lut_list == LUT_filelist:
    mean_train_losses.append(np.mean(train_losses))
    mean_valid_losses.append(np.mean(valid_losses))
    #mean_test_losses.append(np.mean(test_losses))
    print('len mean train losses ', len(mean_train_losses))

#test_loader
    for batch_idx, (f_plot, r_plot) in enumerate(test_loader):
        outputs = model(f_plot)
        loss = msre(outputs, radiances)
        #if i % 10 == 0:
            #print(i, outputs)
        filename = ('./plot/03_rtm_nn_' + lat + '_alltoz_epoch_' + str(epoch).zfill(5) +
            '_index_' + str(batch_idx).zfill(5))
        if batch_idx == 0:
            test_plot300(real_epoch, batch_idx, f_plot, wav300, r_plot,
                    outputs, filename, lr)
        if real_epoch % 5 == 0:
            print(batch_idx)
            test_plot300(real_epoch, batch_idx, f_plot, wav300, r_plot,
                    outputs, filename, lr)

        outputs = None
        del outputs

    lossesfile = './result/03_rtm_nn_latL_tozall_mean_losses.txt' 
    print('lossesfile')
    if os.path.exists(lossesfile):
        with open(lossesfile, 'a') as f:
            f.write(str(real_epoch).zfill(5) + ',' + str(np.mean(train_losses))
                    + ',' + str(np.mean(valid_losses))+'\n')
    else:
Exemplo n.º 2
0
# plot radiances
        with torch.no_grad():
            for i, (features, radiances) in enumerate(test_loader):
                outputs = model(features)
                loss = msre(outputs, radiances)
                test_losses.append(loss.item())
                if (i * 128) % (128 * 10) == 0:
                    print(f'{i * 128} / ', len(test_loader)*128, time.time() - timestamp,
                            datetime.datetime.now(), 'test')
                    print(loss.item())
                    timestamp = time.time()

        batch_idx = 0
        filename = ('./plot/01_rtm_nn_' + lat + toz + '_epoch_' + str(real_epoch).zfill(5) +
            '_index_' + str(batch_idx).zfill(5))
        test_plot300(real_epoch, batch_idx, features, wav300, radiances,
                outputs, filename, lr)

        print('test plotting done', filename, lat, toz)
        lossesfile = './result/01_rtm_nn_' + lat + toz + '_mean_losses.txt'
        if os.path.exists(lossesfile):
            with open(lossesfile, 'a') as f:
                f.write(str(real_epoch).zfill(5) + ',' + 
                        str(np.mean(train_losses)) + ',' + 
                        str(np.mean(valid_losses)) + ',' +
                        str(np.mean(test_losses)) + '\n')
        else:
            with open(lossesfile, 'w') as f:
                f.write('index,mean_train_losses,mean_valid_losses,mean_test_losses' + '\n')
                f.write(str(real_epoch).zfill(5) + ',' + 
                        str(np.mean(train_losses)) + ',' + 
                        str(np.mean(valid_losses)) + ',' + 
                outputs = model(features)
                loss = msre(outputs, radiances)
                loss.backward()
                optimizer.step()
                #train_losses.append(loss.item())
                train_losses = np.append(train_losses, loss.item())
                #train_losses.append(loss.data)
                if (i * 128) % (128 * 10) == 0:
                    print(f'{i * 128} / ', len(train_loader)*128, time.time() - timestamp,
                            datetime.datetime.now())
                    print(loss.item())
                    timestamp = time.time()
            if epoch_local % 100 == 0:
                filename = ('./plot/09_rtm_nn_' + lat + '_alltoz_epoch_' + str(epoch).zfill(8) +
                    '_index_' + str(i).zfill(8))
                test_plot300(epoch_local, i, features, wav300, radiances,
                        outputs, filename, lr)



            print('model.eval()')
            model.eval()
            correct = 0
            total = 0

            #with torch.no_grad():
                #for i, (features, radiances) in enumerate(valid_loader):
                    #outputs = model(features)
                    #loss = msre(outputs, radiances)
                    
                    ##valid_losses.append(loss.item())
                    #valid_losses = np.append(valid_losses, loss.item())