Example #1
0
# Training launcher: assemble the trajGRU encoder-forecaster, warm-start it
# from a saved checkpoint, and hand everything to train_and_test.

# --- hyper-parameters -------------------------------------------------------
batch_size = cfg.GLOBAL.BATCH_SZIE  # NOTE(review): "SZIE" mirrors the attribute as spelled in cfg — do not "fix" here without renaming it in cfg too
max_iterations = 50000
test_iteration_interval = 1000
test_and_save_checkpoint_iterations = 1000

base_lr = 5e-5          # initial Adam learning rate
decay_interval = 20000  # iterations between step-decays of the learning rate

device = cfg.GLOBAL.DEVICE

# --- model ------------------------------------------------------------------
encoder = Encoder(encoder_params[0], encoder_params[1]).to(device)
forecaster = Forecaster(forecaster_params[0], forecaster_params[1]).to(device)
encoder_forecaster = EF(encoder, forecaster).to(device)

# Warm-start from a previous run (state dict only).
# NOTE(review): machine-specific absolute path; torch.load is called without
# map_location, so the checkpoint must be loadable on the current device.
checkpoint_path = '/home/hzzone/save/trajGRU_balanced_mse_mae/models/encoder_forecaster_77000.pth'
encoder_forecaster.load_state_dict(torch.load(checkpoint_path))

# --- optimization -----------------------------------------------------------
optimizer = torch.optim.Adam(encoder_forecaster.parameters(), lr=base_lr)
step_scheduler = lr_scheduler.StepLR(optimizer,
                                     step_size=decay_interval,
                                     gamma=0.7)
criterion = Weighted_mse_mae(LAMBDA=0.01).to(device)

# Run name: the directory that contains this script.
folder_name = os.path.basename(os.path.dirname(os.path.abspath(__file__)))

train_and_test(encoder_forecaster, optimizer, criterion, step_scheduler,
               batch_size, max_iterations, test_iteration_interval,
               test_and_save_checkpoint_iterations, folder_name)
Example #2
0
from nowcasting.train_and_test import train_and_test
import numpy as np
from nowcasting.hko.evaluation import *
from experiments.net_params import *
from nowcasting.models.model import Predictor
from nowcasting.helpers.visualization import save_hko_movie

# Evaluation/export script: restore a trained encoder-forecaster from a
# checkpoint, re-save it as a full pickled module, and draw one validation
# sample from the HKO rainy-day test split.
torch.cuda.set_device(0)
# Sequence split: 5 input frames, 20 forecast frames.
in_seq = 5
out_seq = 20
encoder = Encoder(encoder_params[0], encoder_params[1]).to(cfg.GLOBAL.DEVICE)
# NOTE(review): unlike the encoder, the forecaster is not moved to the device
# here — presumably EF(...).to(...) below relocates its submodules; confirm.
forecaster = Forecaster(forecaster_params[0], forecaster_params[1])
encoder_forecaster = EF(encoder, forecaster).to(cfg.GLOBAL.DEVICE)

# Restore trained weights (relative path: state dict only).
encoder_forecaster.load_state_dict(
    torch.load('encoder_forecaster_45000.pth'
               ))  # alternative checkpoint: save/models/encoder_forecaster_100000.pth
# Re-save the whole module (architecture + weights), not just the state dict.
torch.save(encoder_forecaster, 'full_encoder_forecaster_45000.pth')

criterion = Weighted_mse_mae().to(cfg.GLOBAL.DEVICE)
# Random sampler over the rainy-day test set; each sample covers a full
# input+output sequence of in_seq + out_seq frames.
hko_iter = HKOIterator(pd_path=cfg.HKO_PD.RAINY_TEST,
                       sample_mode="random",
                       seq_len=in_seq + out_seq)

# One validation sequence plus its mask and timestamps.
valid_batch, valid_mask, sample_datetimes, _ = \
    hko_iter.sample(batch_size=1)


def filter(img):
    h, w = img.shape
    for i in range(h):
Example #3
0
from nowcasting.models.trajGRU import TrajGRU
from nowcasting.train_and_test import train_and_test
import numpy as np
from nowcasting.hko.evaluation import *
from experiments.net_params import *
from nowcasting.models.model import Predictor
from nowcasting.helpers.visualization import save_hko_movie
# Example usage code; the checkpoint paths below do not exist as-is and need
# to be adjusted before running.

# Restore a trained trajGRU encoder-forecaster and prepare a single
# validation sample from the HKO rainy-day test split.
device = cfg.GLOBAL.DEVICE

# Build the model.
# NOTE(review): the forecaster is not moved to the device individually —
# presumably EF(...).to(device) relocates its submodules; confirm.
encoder = Encoder(encoder_params[0], encoder_params[1]).to(device)
forecaster = Forecaster(forecaster_params[0], forecaster_params[1])
encoder_forecaster = EF(encoder, forecaster).to(device)

# Restore trained weights (machine-specific absolute path).
checkpoint_path = '/home/hzzone/save/trajGRU_frame_weighted_mse/models/encoder_forecaster_45000.pth'
encoder_forecaster.load_state_dict(torch.load(checkpoint_path))

# Draw one random sequence of IN_LEN + OUT_LEN frames from the test set.
hko_iter = HKOIterator(pd_path=cfg.HKO_PD.RAINY_TEST,
                       sample_mode="random",
                       seq_len=IN_LEN + OUT_LEN)
valid_batch, valid_mask, sample_datetimes, _ = hko_iter.sample(batch_size=1)

# Scale pixel values from [0, 255] into [0, 1], then split the sequence into
# the model input (first IN_LEN frames) and the forecast target.
valid_batch = valid_batch.astype(np.float32) / 255.0
valid_data = valid_batch[:IN_LEN, ...]
valid_label = valid_batch[IN_LEN:IN_LEN + OUT_LEN, ...]
mask = valid_mask[IN_LEN:IN_LEN + OUT_LEN, ...].astype(int)
torch_valid_data = torch.from_numpy(valid_data).to(device)
Example #4
0
from nowcasting.models.model import Predictor
from experiments.rover_and_last_frame import LastFrame, Rover
import time
import pickle


# Model-comparison setup: build four nowcasting models (two trajGRU
# encoder-forecasters trained with different losses, a conv2d predictor, and
# a convLSTM encoder-forecaster) and restore their trained weights.
encoder = Encoder(encoder_params[0], encoder_params[1]).to(cfg.GLOBAL.DEVICE)
forecaster = Forecaster(forecaster_params[0], forecaster_params[1])
encoder_forecaster1 = EF(encoder, forecaster)
# Independent deep copy so the two checkpoints loaded below do not share (and
# overwrite) the same parameter tensors.
encoder_forecaster2 = copy.deepcopy(encoder_forecaster1)
conv2d_network = Predictor(conv2d_params).to(cfg.GLOBAL.DEVICE)

encoder_forecaster1 = encoder_forecaster1.to(cfg.GLOBAL.DEVICE)
encoder_forecaster2 = encoder_forecaster2.to(cfg.GLOBAL.DEVICE)
# Load the trained weights (machine-specific absolute paths).
encoder_forecaster1.load_state_dict(torch.load('/home/hzzone/save/trajGRU_balanced_mse_mae/models/encoder_forecaster_77000.pth'))
encoder_forecaster2.load_state_dict(torch.load('/home/hzzone/save/trajGRU_frame_weighted_mse/models/encoder_forecaster_45000.pth'))
conv2d_network.load_state_dict(torch.load('/home/hzzone/save/conv2d/models/encoder_forecaster_60000.pth'))

# convLSTM variant, assembled and restored the same way.
convlstm_encoder = Encoder(convlstm_encoder_params[0], convlstm_encoder_params[1]).to(cfg.GLOBAL.DEVICE)

convlstm_forecaster = Forecaster(convlstm_forecaster_params[0], convlstm_forecaster_params[1]).to(cfg.GLOBAL.DEVICE)

convlstm_encoder_forecaster = EF(convlstm_encoder, convlstm_forecaster).to(cfg.GLOBAL.DEVICE)
# NOTE(review): "balacned" is spelled this way in the on-disk directory name
# this path points at — confirm before "correcting" it.
convlstm_encoder_forecaster.load_state_dict(torch.load('/home/hzzone/save/convLSTM_balacned_mse_mae/models/encoder_forecaster_64000.pth'))
models = OrderedDict({
    'convLSTM_balacned_mse_mae': convlstm_encoder_forecaster,
    'conv2d': conv2d_network,
    'trajGRU_balanced_mse_mae': encoder_forecaster1,
    'trajGRU_frame_weighted_mse': encoder_forecaster2,