def train(args):
    """Build the configured encoder-forecaster model and train it on Moving MNIST.

    Creates ``args.save_dir`` (including any missing parent directories) and
    calls ``save_movingmnist_cfg`` on it. It then builds a TrajGRU or ConvLSTM
    encoder-forecaster, selected by ``cfg.MODEL.TYPE``, together with an Adam
    optimizer and a StepLR scheduler, and hands everything to ``train_mnist``.

    Args:
        args: parsed command-line arguments; only ``args.save_dir`` is read.

    Raises:
        NotImplementedError: if ``cfg.MODEL.TYPE`` is neither ``'TrajGRU'``
            nor ``'ConvLSTM'``.
    """
    base_dir = args.save_dir
    # makedirs(exist_ok=True) also creates missing parents and does not fail
    # if another process creates the directory between check and creation.
    os.makedirs(base_dir, exist_ok=True)
    # logging_config(folder=base_dir, name="training")
    save_movingmnist_cfg(base_dir)

    # Training schedule, read from the global config.
    batch_size = cfg.MODEL.TRAIN.BATCH_SIZE
    max_iterations = cfg.MODEL.TRAIN.MAX_ITER
    test_iteration_interval = cfg.MODEL.VALID_ITER
    test_and_save_checkpoint_iterations = cfg.MODEL.SAVE_ITER
    LR_step_size = cfg.MODEL.TRAIN.LR_STEP
    gamma = cfg.MODEL.TRAIN.GAMMA1

    LR = cfg.MODEL.TRAIN.LR

    criterion = Weighted_mse_mae().to(cfg.GLOBAL.DEVICE)

    # Pick the sub-network parameter builders for the requested architecture.
    if cfg.MODEL.TYPE == 'TrajGRU':
        encoder_subnet, encoder_net = encoder_params_mnist()
        forecaster_subnet, forecaster_net = forecaster_params_mnist()
    elif cfg.MODEL.TYPE == 'ConvLSTM':
        encoder_subnet, encoder_net = convlstm_encoder_params_mnist()
        forecaster_subnet, forecaster_net = convlstm_forecaster_params_mnist()
    else:
        raise NotImplementedError('Model is not found.')

    encoder = Encoder(encoder_subnet, encoder_net).to(cfg.GLOBAL.DEVICE)
    forecaster = Forecaster(forecaster_subnet,
                            forecaster_net).to(cfg.GLOBAL.DEVICE)

    encoder_forecaster = EF(encoder, forecaster).to(cfg.GLOBAL.DEVICE)

    optimizer = torch.optim.Adam(encoder_forecaster.parameters(), lr=LR)
    # Multiply the learning rate by `gamma` every `LR_step_size` scheduler steps.
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                           step_size=LR_step_size,
                                           gamma=gamma)
    # Name of the directory that contains this source file.
    folder_name = os.path.split(os.path.dirname(os.path.abspath(__file__)))[-1]

    train_mnist(encoder_forecaster, optimizer, criterion, exp_lr_scheduler,
                batch_size, max_iterations, test_iteration_interval,
                test_and_save_checkpoint_iterations, folder_name, base_dir)
# Example #2
# 0
from torch.optim import lr_scheduler
from nowcasting.models.loss import Weighted_mse_mae
from nowcasting.train_and_test import train_and_test
import joblib
from experiments.net_params import convlstm_encoder_params, convlstm_forecaster_params
import numpy as np
from nowcasting.hko import image
from nowcasting.hko.evaluation import HKOEvaluation
# from nowcasting.helpers.visualization import save_hko_movie

# Build a ConvLSTM encoder-forecaster on the configured device.
convlstm_encoder = Encoder(convlstm_encoder_params[0],
                           convlstm_encoder_params[1]).to(cfg.GLOBAL.DEVICE)
convlstm_forecaster = Forecaster(convlstm_forecaster_params[0],
                                 convlstm_forecaster_params[1]).to(
                                     cfg.GLOBAL.DEVICE)
convlstm_encoder_forecaster = EF(convlstm_encoder,
                                 convlstm_forecaster).to(cfg.GLOBAL.DEVICE)
# convlstm_encoder_forecaster.load_state_dict(joblib.load('convLSTM_balacned_mse_mae.pkl'))
# SECURITY: joblib.load unpickles the file, which can execute arbitrary code.
# Only load trusted, locally produced files. The file name's spelling
# ("balacned") must match the file on disk, so it is left unchanged.
net = joblib.load('convLSTM_balacned_mse_mae.pkl')
print(net)
batch_size = 2
seq_lenth = 10  # (sic) name kept as-is; later code may reference it
height = 100
width = 100
# Zero-filled uint8 buffer of shape (seq_lenth, batch_size, 1, height, width).
train_data = np.zeros((seq_lenth, batch_size, 1, height, width),
                      dtype=np.uint8)

# Every (time step, batch index) pair in time-major order:
# [[0, 0], [0, 1], [1, 0], [1, 1], ...].
hit_inds = [[i, j] for i in range(seq_lenth) for j in range(batch_size)]
paths = []
import os
from experiments.net_params import convlstm_encoder_params, convlstm_forecaster_params



# ---------------------------------------------------------------------------
# Config: hyper-parameters for training the ConvLSTM encoder-forecaster.
# ---------------------------------------------------------------------------
# NOTE(review): the 'SZIE' spelling presumably matches the project's cfg key;
# confirm before renaming.
batch_size = cfg.GLOBAL.BATCH_SZIE
max_iterations = 100000
test_iteration_interval = 1000
test_and_save_checkpoint_iterations = 1000
LR_step_size = 20000
gamma = 0.7
LR = 1e-4

# Training loss, placed on the configured device.
criterion = Weighted_mse_mae().to(cfg.GLOBAL.DEVICE)

# ConvLSTM encoder and forecaster, wrapped together in EF.
encoder = Encoder(convlstm_encoder_params[0],
                  convlstm_encoder_params[1]).to(cfg.GLOBAL.DEVICE)
forecaster = Forecaster(convlstm_forecaster_params[0],
                        convlstm_forecaster_params[1]).to(cfg.GLOBAL.DEVICE)
encoder_forecaster = EF(encoder, forecaster).to(cfg.GLOBAL.DEVICE)

# Adam, with the learning rate multiplied by `gamma` every `LR_step_size` steps.
optimizer = torch.optim.Adam(encoder_forecaster.parameters(), lr=LR)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                       step_size=LR_step_size,
                                       gamma=gamma)

# Name of the directory that contains this script.
folder_name = os.path.split(os.path.dirname(os.path.abspath(__file__)))[-1]

train_and_test_milan(
    encoder_forecaster,
    optimizer,
    criterion,
    exp_lr_scheduler,
    batch_size,
    max_iterations,
    test_iteration_interval,
    test_and_save_checkpoint_iterations,
    folder_name,
)
# Example #4
# 0
### Config
# Fine-tune a TrajGRU encoder-forecaster from a saved checkpoint.

# NOTE(review): the 'SZIE' spelling presumably matches the project's cfg key.
batch_size = cfg.GLOBAL.BATCH_SZIE
max_iterations = 50000
test_iteration_interval = 1000
test_and_save_checkpoint_iterations = 1000

LR = 5e-5
LR_step_size = 20000
gamma = 0.7  # StepLR decay factor

encoder = Encoder(encoder_params[0], encoder_params[1]).to(cfg.GLOBAL.DEVICE)

forecaster = Forecaster(forecaster_params[0],
                        forecaster_params[1]).to(cfg.GLOBAL.DEVICE)

encoder_forecaster = EF(encoder, forecaster).to(cfg.GLOBAL.DEVICE)
# map_location remaps the saved tensors onto the configured device, so a
# checkpoint written on one device (e.g. a GPU) also loads on another.
encoder_forecaster.load_state_dict(
    torch.load(
        '/home/hzzone/save/trajGRU_balanced_mse_mae/models/encoder_forecaster_77000.pth',
        map_location=cfg.GLOBAL.DEVICE,
    ))

optimizer = torch.optim.Adam(encoder_forecaster.parameters(), lr=LR)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                       step_size=LR_step_size,
                                       gamma=gamma)
# Alternative schedule:
# mult_step_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50000, 80000], gamma=0.1)

criterion = Weighted_mse_mae(LAMBDA=0.01).to(cfg.GLOBAL.DEVICE)

# Name of the directory that contains this script.
folder_name = os.path.split(os.path.dirname(os.path.abspath(__file__)))[-1]
# Example #5
# 0
from nowcasting.models.model import EF
from torch.optim import lr_scheduler
from nowcasting.models.loss import Weighted_mse_mae
from nowcasting.models.trajGRU import TrajGRU
from nowcasting.train_and_test import train_and_test
import numpy as np
from nowcasting.hko.evaluation import *
from experiments.net_params import *
from nowcasting.models.model import Predictor
from nowcasting.helpers.visualization import save_hko_movie
# Example usage code: the checkpoint path below does not exist here and must
# be adjusted before running.

# Load the model.
encoder = Encoder(encoder_params[0], encoder_params[1]).to(cfg.GLOBAL.DEVICE)
forecaster = Forecaster(forecaster_params[0], forecaster_params[1])
# NOTE(review): the forecaster is presumably moved along with EF(...).to(),
# assuming EF registers it as a submodule; confirm.
encoder_forecaster = EF(encoder, forecaster).to(cfg.GLOBAL.DEVICE)

# map_location remaps the saved tensors onto the configured device.
encoder_forecaster.load_state_dict(
    torch.load(
        '/home/hzzone/save/trajGRU_frame_weighted_mse/models/encoder_forecaster_45000.pth',
        map_location=cfg.GLOBAL.DEVICE,
    ))

# Load data: draw one random sequence of IN_LEN + OUT_LEN frames.
hko_iter = HKOIterator(pd_path=cfg.HKO_PD.RAINY_TEST,
                       sample_mode="random",
                       seq_len=IN_LEN + OUT_LEN)

valid_batch, valid_mask, sample_datetimes, _ = \
    hko_iter.sample(batch_size=1)

# Scale to float32 in [0, 1] — assumes raw pixel values lie in 0..255; confirm.
valid_batch = valid_batch.astype(np.float32) / 255.0
# Example #6
# 0
from torch.optim import lr_scheduler
from nowcasting.models.loss import Weighted_mse_mae
from nowcasting.models.trajGRU import TrajGRU
from nowcasting.train_and_test import train_and_test
import numpy as np
from nowcasting.hko.evaluation import *
from experiments.net_params import *
from nowcasting.models.model import Predictor
from nowcasting.helpers.visualization import save_hko_movie

# Make GPU 0 the current CUDA device. This is skipped on CPU-only machines,
# where torch.cuda.set_device would raise.
if torch.cuda.is_available():
    torch.cuda.set_device(0)
in_seq = 5
out_seq = 20
encoder = Encoder(encoder_params[0], encoder_params[1]).to(cfg.GLOBAL.DEVICE)
forecaster = Forecaster(forecaster_params[0], forecaster_params[1])
encoder_forecaster = EF(encoder, forecaster).to(cfg.GLOBAL.DEVICE)

# Load the state dict, remapped onto the configured device.
# (Previously: save/models/encoder_forecaster_100000.pth)
encoder_forecaster.load_state_dict(
    torch.load('encoder_forecaster_45000.pth',
               map_location=cfg.GLOBAL.DEVICE))
# Save the whole module object (not just its state dict).
torch.save(encoder_forecaster, 'full_encoder_forecaster_45000.pth')

criterion = Weighted_mse_mae().to(cfg.GLOBAL.DEVICE)
# Draw one random sequence of in_seq + out_seq frames.
hko_iter = HKOIterator(pd_path=cfg.HKO_PD.RAINY_TEST,
                       sample_mode="random",
                       seq_len=in_seq + out_seq)

valid_batch, valid_mask, sample_datetimes, _ = \
    hko_iter.sample(batch_size=1)

# Example #7
# 0
from nowcasting.models.loss import Weighted_mse_mae
from nowcasting.models.trajGRU import TrajGRU
from nowcasting.train_and_test import train_and_test
import numpy as np
from nowcasting.hko.evaluation import *
import copy
from experiments.net_params import *
from nowcasting.models.model import Predictor
from experiments.rover_and_last_frame import LastFrame, Rover
import time
import pickle


# Two TrajGRU encoder-forecasters with identical architecture (the second is
# a deep copy of the first) plus a 2D-conv predictor.
encoder = Encoder(encoder_params[0], encoder_params[1]).to(cfg.GLOBAL.DEVICE)
forecaster = Forecaster(forecaster_params[0], forecaster_params[1])
encoder_forecaster1 = EF(encoder, forecaster)
encoder_forecaster2 = copy.deepcopy(encoder_forecaster1)
conv2d_network = Predictor(conv2d_params).to(cfg.GLOBAL.DEVICE)

encoder_forecaster1 = encoder_forecaster1.to(cfg.GLOBAL.DEVICE)
encoder_forecaster2 = encoder_forecaster2.to(cfg.GLOBAL.DEVICE)
# Load the models. map_location remaps the saved tensors onto the configured
# device, so checkpoints written on a GPU also load elsewhere.
encoder_forecaster1.load_state_dict(torch.load(
    '/home/hzzone/save/trajGRU_balanced_mse_mae/models/encoder_forecaster_77000.pth',
    map_location=cfg.GLOBAL.DEVICE))
encoder_forecaster2.load_state_dict(torch.load(
    '/home/hzzone/save/trajGRU_frame_weighted_mse/models/encoder_forecaster_45000.pth',
    map_location=cfg.GLOBAL.DEVICE))
conv2d_network.load_state_dict(torch.load(
    '/home/hzzone/save/conv2d/models/encoder_forecaster_60000.pth',
    map_location=cfg.GLOBAL.DEVICE))

# ConvLSTM encoder-forecaster on the configured device.
convlstm_encoder = Encoder(convlstm_encoder_params[0], convlstm_encoder_params[1]).to(cfg.GLOBAL.DEVICE)

convlstm_forecaster = Forecaster(convlstm_forecaster_params[0], convlstm_forecaster_params[1]).to(cfg.GLOBAL.DEVICE)

convlstm_encoder_forecaster = EF(convlstm_encoder, convlstm_forecaster).to(cfg.GLOBAL.DEVICE)