import os

import torch
from torch.optim import lr_scheduler

# Project-local imports; the module paths below are assumed and may need to be
# adjusted to this repository's actual layout.
from nowcasting.config import cfg
from nowcasting.models.loss import Weighted_mse_mae
from nowcasting.models.encoder import Encoder
from nowcasting.models.forecaster import Forecaster
from nowcasting.models.model import EF
from experiments.net_params import (encoder_params_mnist, forecaster_params_mnist,
                                    convlstm_encoder_params_mnist,
                                    convlstm_forecaster_params_mnist)
from experiments.movingmnist_utils import save_movingmnist_cfg, train_mnist


def train(args):
    base_dir = args.save_dir
    if not os.path.exists(base_dir):
        os.mkdir(base_dir)
    # logging_config(folder=base_dir, name="training")
    save_movingmnist_cfg(base_dir)

    # Training hyperparameters, all read from the Moving MNIST config.
    batch_size = cfg.MODEL.TRAIN.BATCH_SIZE
    max_iterations = cfg.MODEL.TRAIN.MAX_ITER
    test_iteration_interval = cfg.MODEL.VALID_ITER
    test_and_save_checkpoint_iterations = cfg.MODEL.SAVE_ITER
    LR_step_size = cfg.MODEL.TRAIN.LR_STEP
    gamma = cfg.MODEL.TRAIN.GAMMA1
    LR = cfg.MODEL.TRAIN.LR

    # Weighted MSE + MAE loss used for frame prediction.
    criterion = Weighted_mse_mae().to(cfg.GLOBAL.DEVICE)

    # Pick the (subnets, networks) parameter pairs for the configured model type.
    if cfg.MODEL.TYPE == 'TrajGRU':
        encoder_subnet, encoder_net = encoder_params_mnist()
        forecaster_subnet, forecaster_net = forecaster_params_mnist()
    elif cfg.MODEL.TYPE == 'ConvLSTM':
        encoder_subnet, encoder_net = convlstm_encoder_params_mnist()
        forecaster_subnet, forecaster_net = convlstm_forecaster_params_mnist()
    else:
        raise NotImplementedError(f'Unknown model type: {cfg.MODEL.TYPE}')

    # Build the seq2seq encoder-forecaster model.
    encoder = Encoder(encoder_subnet, encoder_net).to(cfg.GLOBAL.DEVICE)
    forecaster = Forecaster(forecaster_subnet, forecaster_net).to(cfg.GLOBAL.DEVICE)
    encoder_forecaster = EF(encoder, forecaster).to(cfg.GLOBAL.DEVICE)

    optimizer = torch.optim.Adam(encoder_forecaster.parameters(), lr=LR)
    # Step-decay schedule: multiply the LR by gamma every LR_step_size iterations.
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=LR_step_size, gamma=gamma)

    # Name of the experiment folder containing this script, used for logging.
    folder_name = os.path.split(os.path.dirname(os.path.abspath(__file__)))[-1]

    train_mnist(encoder_forecaster, optimizer, criterion, exp_lr_scheduler,
                batch_size, max_iterations, test_iteration_interval,
                test_and_save_checkpoint_iterations, folder_name, base_dir)
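
# A minimal CLI entry point for train(), sketched as an assumption: the script's
# original __main__ block is not shown above, and the --save_dir flag name is
# hypothetical (chosen to match the args.save_dir attribute train() reads).
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='Train a TrajGRU/ConvLSTM encoder-forecaster on Moving MNIST')
    parser.add_argument('--save_dir', type=str, default='save',
                        help='Directory for checkpoints and the saved config')
    train(parser.parse_args())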
import os

import torch
from torch.optim import lr_scheduler

# Project-local imports; apart from experiments.net_params (kept from the
# original), the module paths below are assumed and may need adjusting.
from nowcasting.config import cfg
from nowcasting.models.loss import Weighted_mse_mae
from nowcasting.models.encoder import Encoder
from nowcasting.models.forecaster import Forecaster
from nowcasting.models.model import EF
from experiments.net_params import convlstm_encoder_params, convlstm_forecaster_params
from experiments.milan_utils import train_and_test_milan

### Config
batch_size = cfg.GLOBAL.BATCH_SIZE
max_iterations = 100000
test_iteration_interval = 1000
test_and_save_checkpoint_iterations = 1000
LR_step_size = 20000
gamma = 0.7
LR = 1e-4

# Weighted MSE + MAE loss used for frame prediction.
criterion = Weighted_mse_mae().to(cfg.GLOBAL.DEVICE)

# convlstm_*_params are (subnets, networks) pairs consumed by Encoder/Forecaster.
encoder = Encoder(convlstm_encoder_params[0], convlstm_encoder_params[1]).to(cfg.GLOBAL.DEVICE)
forecaster = Forecaster(convlstm_forecaster_params[0], convlstm_forecaster_params[1]).to(cfg.GLOBAL.DEVICE)
encoder_forecaster = EF(encoder, forecaster).to(cfg.GLOBAL.DEVICE)

optimizer = torch.optim.Adam(encoder_forecaster.parameters(), lr=LR)
# Step-decay schedule: multiply the LR by gamma every LR_step_size iterations.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=LR_step_size, gamma=gamma)

# Name of the experiment folder containing this script, used for logging.
folder_name = os.path.split(os.path.dirname(os.path.abspath(__file__)))[-1]

train_and_test_milan(encoder_forecaster, optimizer, criterion, exp_lr_scheduler,
                     batch_size, max_iterations, test_iteration_interval,
                     test_and_save_checkpoint_iterations, folder_name)
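
# For reference, a minimal sketch of what the EF (encoder-forecaster) wrapper
# used in both scripts typically looks like in this kind of seq2seq nowcasting
# code. This is an assumption about its structure, not the repository's actual
# implementation, so it is named EFSketch rather than EF.
import torch
from torch import nn


class EFSketch(nn.Module):
    """Runs the encoder over the input frames, then the forecaster rolls out
    future frames from the encoder's per-layer hidden states."""

    def __init__(self, encoder, forecaster):
        super().__init__()
        self.encoder = encoder
        self.forecaster = forecaster

    def forward(self, input):
        # input: (seq_len, batch, channels, height, width)
        states = self.encoder(input)      # per-layer hidden states
        output = self.forecaster(states)  # predicted future frames
        return output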