def train(args):
    """Build the encoder-forecaster selected by ``cfg.MODEL.TYPE`` and train it.

    Ensures ``args.save_dir`` exists, saves the configuration there through
    ``save_movingmnist_cfg``, builds the loss, the model, an Adam optimizer
    and a ``StepLR`` scheduler from values read from the global ``cfg``, and
    hands everything to ``train_mnist``.

    Args:
        args: Object exposing a ``save_dir`` attribute that names the output
            directory.

    Raises:
        NotImplementedError: If ``cfg.MODEL.TYPE`` is neither ``'TrajGRU'``
            nor ``'ConvLSTM'``.
    """
    base_dir = args.save_dir
    # makedirs(exist_ok=True) creates missing parent directories as well and
    # avoids the race between a separate existence check and mkdir.
    os.makedirs(base_dir, exist_ok=True)
    # logging_config(folder=base_dir, name="training")
    save_movingmnist_cfg(base_dir)

    # Training schedule and optimizer hyper-parameters from the global config.
    batch_size = cfg.MODEL.TRAIN.BATCH_SIZE
    max_iterations = cfg.MODEL.TRAIN.MAX_ITER
    test_iteration_interval = cfg.MODEL.VALID_ITER
    test_and_save_checkpoint_iterations = cfg.MODEL.SAVE_ITER
    lr_step_size = cfg.MODEL.TRAIN.LR_STEP
    gamma = cfg.MODEL.TRAIN.GAMMA1
    learning_rate = cfg.MODEL.TRAIN.LR

    device = cfg.GLOBAL.DEVICE

    criterion = Weighted_mse_mae().to(device)

    # Each *_params_mnist() helper returns a (subnet, net) pair that is passed
    # straight to the Encoder / Forecaster constructors.
    if cfg.MODEL.TYPE == 'TrajGRU':
        encoder_subnet, encoder_net = encoder_params_mnist()
        forecaster_subnet, forecaster_net = forecaster_params_mnist()
    elif cfg.MODEL.TYPE == 'ConvLSTM':
        encoder_subnet, encoder_net = convlstm_encoder_params_mnist()
        forecaster_subnet, forecaster_net = convlstm_forecaster_params_mnist()
    else:
        raise NotImplementedError('Model is not found.')

    encoder = Encoder(encoder_subnet, encoder_net).to(device)
    forecaster = Forecaster(forecaster_subnet, forecaster_net).to(device)
    encoder_forecaster = EF(encoder, forecaster).to(device)

    optimizer = torch.optim.Adam(encoder_forecaster.parameters(),
                                 lr=learning_rate)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                           step_size=lr_step_size,
                                           gamma=gamma)

    # Name of the directory that contains this file (last path component).
    folder_name = os.path.split(os.path.dirname(os.path.abspath(__file__)))[-1]

    train_mnist(encoder_forecaster, optimizer, criterion, exp_lr_scheduler,
                batch_size, max_iterations, test_iteration_interval,
                test_and_save_checkpoint_iterations, folder_name, base_dir)
import os
from experiments.net_params import convlstm_encoder_params, convlstm_forecaster_params



### Config
# Module-level training script: everything below runs at import/execution
# time. It builds a ConvLSTM-parameterised encoder-forecaster and passes it,
# with its optimizer, loss and LR scheduler, to train_and_test_milan.


# NOTE(review): the attribute is spelled `BATCH_SZIE`; presumably this matches
# the spelling used where cfg is defined — confirm before renaming.
batch_size = cfg.GLOBAL.BATCH_SZIE
# Schedule constants forwarded unchanged to train_and_test_milan below.
max_iterations = 100000
test_iteration_interval = 1000
test_and_save_checkpoint_iterations = 1000
# StepLR arguments: step_size and multiplicative decay factor.
LR_step_size = 20000
gamma = 0.7

# Initial learning rate given to Adam.
LR = 1e-4

# Loss module, moved to the configured device.
criterion = Weighted_mse_mae().to(cfg.GLOBAL.DEVICE)

# The first two entries of each params object are the constructor arguments.
encoder = Encoder(convlstm_encoder_params[0], convlstm_encoder_params[1]).to(cfg.GLOBAL.DEVICE)

forecaster = Forecaster(convlstm_forecaster_params[0], convlstm_forecaster_params[1]).to(cfg.GLOBAL.DEVICE)

# Combined model wrapping the encoder and the forecaster.
encoder_forecaster = EF(encoder, forecaster).to(cfg.GLOBAL.DEVICE)

# The optimizer covers all parameters of the combined model; the scheduler
# wraps that optimizer, so it must be created after it.
optimizer = torch.optim.Adam(encoder_forecaster.parameters(), lr=LR)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=LR_step_size, gamma=gamma)

# Name of the directory that contains this file (last path component).
folder_name = os.path.split(os.path.dirname(os.path.abspath(__file__)))[-1]

# Starts the run immediately; train_and_test_milan is defined outside this view.
train_and_test_milan(encoder_forecaster, optimizer, criterion, exp_lr_scheduler, batch_size, max_iterations, test_iteration_interval, test_and_save_checkpoint_iterations, folder_name)