Example #1
def objective(trial):

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    train_path = get_dset_path(args.dataset_name, 'train')
    val_path = get_dset_path(args.dataset_name, 'val')

    long_dtype, float_dtype = get_dtypes(args)

    discriminator_weight = trial.suggest_categorical('discriminator_weight',
                                                     [0, 1])
    optim_name = trial.suggest_categorical('optim_name',
                                           ['Adam', 'Adamax', 'RMSprop'])

    # args.batch_size = trial.suggest_categorical('batch_size', [32, 64, 128])
    args.dropout = trial.suggest_categorical('dropout', [0, 0.2, 0.5])
    args.batch_norm = trial.suggest_categorical('batch_norm', [0, 1])

    N_TRAIN_EXAMPLES = args.batch_size * 30
    N_VALID_EXAMPLES = args.batch_size * 10

    logger.info("Initializing train dataset")
    train_dset, train_loader = data_loader(args, train_path)
    logger.info("Initializing val dataset")
    _, val_loader = data_loader(args, val_path)

    generator = TrajectoryGenerator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        encoder_h_dim=args.encoder_h_dim_g,
        decoder_h_dim=args.decoder_h_dim_g,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        noise_dim=args.noise_dim,
        noise_type=args.noise_type,
        noise_mix_type=args.noise_mix_type,
        pooling_type=args.pooling_type,
        pool_every_timestep=args.pool_every_timestep,
        dropout=args.dropout,
        bottleneck_dim=args.bottleneck_dim,
        neighborhood_size=args.neighborhood_size,
        grid_size=args.grid_size,
        batch_norm=args.batch_norm,
        use_cuda=args.use_gpu)

    generator.apply(init_weights)
    generator.type(float_dtype).train()
    logger.info('Here is the generator:')
    logger.info(generator)

    discriminator = TrajectoryDiscriminator(obs_len=args.obs_len,
                                            pred_len=args.pred_len,
                                            embedding_dim=args.embedding_dim,
                                            h_dim=args.encoder_h_dim_d,
                                            mlp_dim=args.mlp_dim,
                                            num_layers=args.num_layers,
                                            dropout=args.dropout,
                                            batch_norm=args.batch_norm,
                                            d_type=args.d_type,
                                            use_cuda=args.use_gpu)

    discriminator.apply(init_weights)
    discriminator.type(float_dtype).train()
    logger.info('Here is the discriminator:')
    logger.info(discriminator)

    g_loss_fn = gan_g_loss
    d_loss_fn = gan_d_loss

    if optim_name == 'Adam':
        optimizer_g = optim.Adam([{
            'params': generator.parameters(),
            'initial_lr': args.g_learning_rate
        }],
                                 lr=args.g_learning_rate)
        optimizer_d = optim.Adam([{
            'params': discriminator.parameters(),
            'initial_lr': args.d_learning_rate
        }],
                                 lr=args.d_learning_rate)

    elif optim_name == 'Adamax':
        optimizer_g = optim.Adamax([{
            'params': generator.parameters(),
            'initial_lr': args.g_learning_rate
        }],
                                   lr=args.g_learning_rate)
        optimizer_d = optim.Adamax([{
            'params': discriminator.parameters(),
            'initial_lr': args.d_learning_rate
        }],
                                   lr=args.d_learning_rate)
    else:
        optimizer_g = optim.RMSprop([{
            'params': generator.parameters(),
            'initial_lr': args.g_learning_rate
        }],
                                    lr=args.g_learning_rate)
        optimizer_d = optim.RMSprop([{
            'params': discriminator.parameters(),
            'initial_lr': args.d_learning_rate
        }],
                                    lr=args.d_learning_rate)

    scheduler_g = optim.lr_scheduler.StepLR(optimizer_g,
                                            step_size=100,
                                            gamma=0.5,
                                            last_epoch=-1)
    scheduler_d = optim.lr_scheduler.StepLR(optimizer_d,
                                            step_size=100,
                                            gamma=0.5,
                                            last_epoch=-1)

    t, epoch = 0, 0

    while t < 50:
        gc.collect()
        d_steps_left = args.d_steps
        g_steps_left = args.g_steps

        for batch_idx, batch in enumerate(train_loader):

            # Limiting training data for faster epochs.
            if batch_idx * args.batch_size >= N_TRAIN_EXAMPLES:
                break

            # Decide whether to use the batch for stepping on discriminator or
            # generator; an iteration consists of args.d_steps steps on the
            # discriminator followed by args.g_steps steps on the generator.
            if d_steps_left > 0:
                step_type = 'd'
                losses_d = discriminator_step(args, batch, generator,
                                              discriminator, d_loss_fn,
                                              optimizer_d)

                d_steps_left -= 1
            elif g_steps_left > 0:
                step_type = 'g'
                losses_g = generator_step(args, batch, generator,
                                          discriminator, g_loss_fn,
                                          optimizer_g, discriminator_weight)

                g_steps_left -= 1

            # Skip the rest if we are not at the end of an iteration
            if d_steps_left > 0 or g_steps_left > 0:
                continue

            t += 1
            d_steps_left = args.d_steps
            g_steps_left = args.g_steps
            if t >= args.num_iterations:
                break

        scheduler_g.step()
        scheduler_d.step()

        metrics_val = check_accuracy(args, val_loader, generator,
                                     discriminator, d_loss_fn,
                                     N_VALID_EXAMPLES)

        ade = metrics_val['ade']

        trial.report(ade, t)

    return ade
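
The objective above is written against Optuna's trial API (trial.suggest_categorical, trial.report). A minimal, hypothetical sketch of how such an objective is typically driven, assuming optuna is installed and args has already been parsed:

import optuna

# Minimal sketch: create a study that minimizes the validation ADE returned
# by objective() and run a handful of trials. (Pruning would additionally
# require the objective to check trial.should_prune() after trial.report().)
study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=20)

print('Best ADE: {}'.format(study.best_value))
print('Best hyperparameters: {}'.format(study.best_params))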
Example #2
def main(args):
    global t

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    train_path = get_dset_path(args.dataset_name, 'train')
    val_path = get_dset_path(args.dataset_name, 'val')

    long_dtype, float_dtype = get_dtypes(args)

    logger.info("Initializing train dataset")
    train_dset, train_loader = data_loader(args, train_path)
    logger.info("Initializing val dataset")
    _, val_loader = data_loader(args, val_path)

    iterations_per_epoch = len(train_dset) / args.batch_size / args.d_steps
    if args.num_epochs:
        args.num_iterations = int(iterations_per_epoch * args.num_epochs)

    logger.info(
        'There are {} iterations per epoch'.format(iterations_per_epoch))

    generator = TrajectoryGenerator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        encoder_h_dim=args.encoder_h_dim_g,
        decoder_h_dim=args.decoder_h_dim_g,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        noise_dim=args.noise_dim,
        noise_type=args.noise_type,
        noise_mix_type=args.noise_mix_type,
        pooling_type=args.pooling_type,
        pool_every_timestep=args.pool_every_timestep,
        dropout=args.dropout,
        bottleneck_dim=args.bottleneck_dim,
        neighborhood_size=args.neighborhood_size,
        grid_size=args.grid_size,
        batch_norm=args.batch_norm,
        use_cuda=args.use_gpu)

    generator.apply(init_weights)
    generator.type(float_dtype).train()
    logger.info('Here is the generator:')
    logger.info(generator)

    discriminator = TrajectoryDiscriminator(obs_len=args.obs_len,
                                            pred_len=args.pred_len,
                                            embedding_dim=args.embedding_dim,
                                            h_dim=args.encoder_h_dim_d,
                                            mlp_dim=args.mlp_dim,
                                            num_layers=args.num_layers,
                                            dropout=args.dropout,
                                            batch_norm=args.batch_norm,
                                            d_type=args.d_type,
                                            use_cuda=args.use_gpu)

    discriminator.apply(init_weights)
    discriminator.type(float_dtype).train()
    logger.info('Here is the discriminator:')
    logger.info(discriminator)

    g_loss_fn = gan_g_loss
    d_loss_fn = gan_d_loss

    # optimizer_g = optim.Adam([{'params': generator.parameters(), 'initial_lr': args.g_learning_rate}], lr=args.g_learning_rate)
    # optimizer_d = optim.Adam([{'params': discriminator.parameters(), 'initial_lr': args.d_learning_rate}], lr=args.d_learning_rate)

    optimizer_g = optim.Adam(params=generator.parameters(),
                             lr=args.g_learning_rate)
    optimizer_d = optim.Adam(params=discriminator.parameters(),
                             lr=args.d_learning_rate)

    lr_scheduler_g = ReduceLROnPlateau(optimizer_g,
                                       threshold=1e-4,
                                       patience=100,
                                       factor=8e-1,
                                       min_lr=1e-5,
                                       verbose=True)

    lr_scheduler_d = ReduceLROnPlateau(optimizer_d,
                                       threshold=1e-4,
                                       patience=100,
                                       factor=8e-1,
                                       min_lr=1e-5,
                                       verbose=True)

    # Maybe restore from checkpoint
    restore_path = None
    if args.checkpoint_start_from is not None:
        restore_path = args.checkpoint_start_from
    elif args.restore_from_checkpoint == 1:
        restore_path = os.path.join(args.output_dir,
                                    '%s_with_model.pt' % args.checkpoint_name)

    if restore_path is not None and os.path.isfile(restore_path):
        logger.info('Restoring from checkpoint {}'.format(restore_path))
        checkpoint = torch.load(restore_path)
        generator.load_state_dict(checkpoint['g_state'])
        discriminator.load_state_dict(checkpoint['d_state'])
        optimizer_g.load_state_dict(checkpoint['g_optim_state'])
        optimizer_d.load_state_dict(checkpoint['d_optim_state'])
        lr_scheduler_g.load_state_dict(checkpoint['lr_scheduler_g_state'])
        lr_scheduler_d.load_state_dict(checkpoint['lr_scheduler_d_state'])
        # t = checkpoint['counters']['t']
        # epoch = checkpoint['counters']['epoch']
        t, epoch = 0, 0
        checkpoint['restore_ts'].append(t)
    else:
        # Starting from scratch, so initialize checkpoint dataset structure
        t, epoch = 0, 0
        checkpoint = {
            'args': args.__dict__,
            'G_losses': defaultdict(list),
            'D_losses': defaultdict(list),
            'losses_ts': [],
            'metrics_val': defaultdict(list),
            'metrics_train': defaultdict(list),
            'sample_ts': [],
            'restore_ts': [],
            'norm_g': [],
            'norm_d': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'g_state': None,
            'g_optim_state': None,
            'd_state': None,
            'd_optim_state': None,
            'g_best_state': None,
            'd_best_state': None,
            'lr_scheduler_g_state': None,
            'lr_scheduler_d_state': None,
            'best_t': None,
            'g_best_nl_state': None,
            'd_best_state_nl': None,
            'best_t_nl': None,
        }

    # scheduler_g = optim.lr_scheduler.StepLR(optimizer_g, step_size=1000, gamma=0.5, last_epoch=(epoch if epoch != 0 else -1))
    # scheduler_d = optim.lr_scheduler.StepLR(optimizer_d, step_size=1000, gamma=0.5, last_epoch=(epoch if epoch != 0 else -1))

    t0 = None
    while t < args.num_iterations:
        gc.collect()
        d_steps_left = args.d_steps if args.discriminator_weight > 0 else 0
        g_steps_left = args.g_steps
        epoch += 1

        logger.info('Starting epoch {}'.format(epoch))
        for batch in train_loader:
            if args.timing == 1:
                torch.cuda.synchronize()
                t1 = time.time()

            # Decide whether to use the batch for stepping on discriminator or
            # generator; an iteration consists of args.d_steps steps on the
            # discriminator followed by args.g_steps steps on the generator.
            if d_steps_left > 0:
                step_type = 'd'
                losses_d = discriminator_step(args, batch, generator,
                                              discriminator, d_loss_fn,
                                              optimizer_d)
                checkpoint['norm_d'].append(
                    get_total_norm(discriminator.parameters()))
                d_steps_left -= 1
            elif g_steps_left > 0:
                step_type = 'g'
                losses_g = generator_step(args, batch, generator,
                                          discriminator, g_loss_fn,
                                          optimizer_g)
                checkpoint['norm_g'].append(
                    get_total_norm(generator.parameters()))
                g_steps_left -= 1

            if args.timing == 1:
                torch.cuda.synchronize()
                t2 = time.time()
                logger.info('{} step took {}'.format(step_type, t2 - t1))

            # Skip the rest if we are not at the end of an iteration
            if d_steps_left > 0 or g_steps_left > 0:
                continue

            if args.timing == 1:
                if t0 is not None:
                    logger.info('Iteration {} took {}'.format(
                        t - 1,
                        time.time() - t0))
                t0 = time.time()

            # Maybe save loss
            if t % args.print_every == 0:
                logger.info('t = {} / {}'.format(t + 1, args.num_iterations))
                if args.discriminator_weight > 0:
                    for k, v in sorted(losses_d.items()):
                        logger.info('  [D] {}: {:.7f}'.format(k, v))
                        checkpoint['D_losses'][k].append(v)
                for k, v in sorted(losses_g.items()):
                    logger.info('  [G] {}: {:.7f}'.format(k, v))
                    checkpoint['G_losses'][k].append(v)
                checkpoint['losses_ts'].append(t)

            # Maybe save a checkpoint
            if t > 0 and t % args.checkpoint_every == 0:
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint['sample_ts'].append(t)

                # Check stats on the validation set
                logger.info('Checking stats on val ...')
                metrics_val = check_accuracy(args,
                                             val_loader,
                                             generator,
                                             discriminator,
                                             d_loss_fn,
                                             lr_scheduler_g,
                                             lr_scheduler_d,
                                             is_train=False)
                logger.info('Checking stats on train ...')
                metrics_train = check_accuracy(args,
                                               train_loader,
                                               generator,
                                               discriminator,
                                               d_loss_fn,
                                               lr_scheduler_g,
                                               lr_scheduler_d,
                                               limit=True,
                                               is_train=True)

                for k, v in sorted(metrics_val.items()):
                    logger.info('  [val] {}: {:.7f}'.format(k, v))
                    checkpoint['metrics_val'][k].append(v)
                for k, v in sorted(metrics_train.items()):
                    logger.info('  [train] {}: {:.7f}'.format(k, v))
                    checkpoint['metrics_train'][k].append(v)

                min_ade = min(checkpoint['metrics_val']['ade'])

                if metrics_val['ade'] == min_ade:
                    logger.info('New low for avg_disp_error')
                    checkpoint['best_t'] = t
                    checkpoint['g_best_state'] = generator.state_dict()
                    checkpoint['d_best_state'] = discriminator.state_dict()

                # Save another checkpoint with model weights and
                # optimizer state
                checkpoint['g_state'] = generator.state_dict()
                checkpoint['g_optim_state'] = optimizer_g.state_dict()
                checkpoint['lr_scheduler_g_state'] = lr_scheduler_g.state_dict()
                checkpoint['d_state'] = discriminator.state_dict()
                checkpoint['d_optim_state'] = optimizer_d.state_dict()
                checkpoint['lr_scheduler_d_state'] = lr_scheduler_d.state_dict()
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_with_model.pt' % args.checkpoint_name)
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                torch.save(checkpoint, checkpoint_path)
                logger.info('Done.')

                # Save a checkpoint with no model weights by making a shallow
                # copy of the checkpoint excluding some items
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_no_model.pt' % args.checkpoint_name)
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                key_blacklist = [
                    'g_state', 'g_best_state', 'g_best_nl_state',
                    'g_optim_state'
                ]
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
                logger.info('Done.')

            t += 1
            d_steps_left = args.d_steps
            g_steps_left = args.g_steps
            if t >= args.num_iterations:
                break
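
Unlike the StepLR schedule in Example #1, this version hands two ReduceLROnPlateau schedulers to check_accuracy, which is presumably where they are stepped on the monitored metric. For reference, a self-contained sketch of how torch.optim.lr_scheduler.ReduceLROnPlateau behaves on its own (model and metric values are hypothetical):

from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Stand-alone sketch: the scheduler multiplies the learning rate by `factor`
# once the metric passed to step() has stopped improving (within `threshold`)
# for `patience` consecutive calls, and never drops the rate below `min_lr`.
model = nn.Linear(2, 2)                       # hypothetical stand-in model
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, threshold=1e-4, patience=100,
                              factor=8e-1, min_lr=1e-5)

for val_ade in [0.9, 0.8, 0.8, 0.8]:          # hypothetical validation ADEs
    scheduler.step(val_ade)                   # call once per validation pass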
Example #3
def main():
    train_path = get_dset_path(DATASET_NAME, 'train')
    val_path = get_dset_path(DATASET_NAME, 'val')
    long_dtype, float_dtype = get_dtypes()

    print("Initializing train dataset")
    train_dset, train_loader = data_loader(train_path)
    print("Initializing val dataset")
    _, val_loader = data_loader(val_path)

    iterations_per_epoch = len(train_dset) / D_STEPS
    NUM_ITERATIONS = int(iterations_per_epoch * NUM_EPOCHS)
    print('There are {} iterations per epoch'.format(iterations_per_epoch))

    generator = TrajectoryGenerator()
    generator.apply(init_weights)
    generator.type(float_dtype).train()
    print('Here is the generator:')
    print(generator)

    discriminator = TrajectoryDiscriminator()
    discriminator.apply(init_weights)
    discriminator.type(float_dtype).train()
    print('Here is the discriminator:')
    print(discriminator)

    optimizer_g = optim.Adam(generator.parameters(), lr=G_LR)
    optimizer_d = optim.Adam(discriminator.parameters(), lr=D_LR)

    t, epoch = 0, 0
    t0 = None
    min_ade = None
    while t < NUM_ITERATIONS:
        gc.collect()
        d_steps_left = D_STEPS
        g_steps_left = G_STEPS
        epoch += 1
        print('Starting epoch {}'.format(epoch))
        for batch in train_loader:

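            # Decide whether to use the batch for stepping on the
            # discriminator or the generator; an iteration consists of
            # D_STEPS steps on the discriminator followed by G_STEPS steps
            # on the generator.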
            if d_steps_left > 0:
                losses_d = discriminator_step(batch, generator, discriminator,
                                              gan_d_loss, optimizer_d)
                d_steps_left -= 1
            elif g_steps_left > 0:
                losses_g = generator_step(batch, generator, discriminator,
                                          gan_g_loss, optimizer_g)
                g_steps_left -= 1

            if d_steps_left > 0 or g_steps_left > 0:
                continue

            if t % PRINT_EVERY == 0:
                print('t = {} / {}'.format(t + 1, NUM_ITERATIONS))
                for k, v in sorted(losses_d.items()):
                    print('  [D] {}: {:.3f}'.format(k, v))
                for k, v in sorted(losses_g.items()):
                    print('  [G] {}: {:.3f}'.format(k, v))

                print('Checking stats on val ...')
                metrics_val = check_accuracy(val_loader, generator,
                                             discriminator, gan_d_loss)

                print('Checking stats on train ...')
                metrics_train = check_accuracy(train_loader,
                                               generator,
                                               discriminator,
                                               gan_d_loss,
                                               limit=True)

                for k, v in sorted(metrics_val.items()):
                    print('  [val] {}: {:.3f}'.format(k, v))
                for k, v in sorted(metrics_train.items()):
                    print('  [train] {}: {:.3f}'.format(k, v))

                if min_ade is None or metrics_val['ade'] < min_ade:
                    min_ade = metrics_val['ade']
                    checkpoint = {
                        't': t,
                        'g': generator.state_dict(),
                        'd': discriminator.state_dict(),
                        'g_optim': optimizer_g.state_dict(),
                        'd_optim': optimizer_d.state_dict()
                    }
                    print("Saving checkpoint to model.pt")
                    torch.save(checkpoint, "model.pt")
                    print("Done.")

            t += 1
            d_steps_left = D_STEPS
            g_steps_left = G_STEPS
            if t >= NUM_ITERATIONS:
                break
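
Example #3 keeps the best model (lowest validation ADE) in model.pt with the keys shown above. A minimal restore sketch, assuming the same TrajectoryGenerator and TrajectoryDiscriminator definitions with default constructor arguments:

import torch

# Hypothetical restore of the checkpoint written in Example #3.
checkpoint = torch.load("model.pt")

generator = TrajectoryGenerator()
generator.load_state_dict(checkpoint['g'])
generator.eval()                              # evaluation mode for inference

discriminator = TrajectoryDiscriminator()
discriminator.load_state_dict(checkpoint['d'])
discriminator.eval()

print('Restored weights from iteration t = {}'.format(checkpoint['t']))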
Example #4
def main():
    train_metric = 0
    print("Process Started")
    print("Initializing train dataset")
    train_dset, train_loader = data_loader(TRAIN_DATASET_PATH, train_metric)
    print("Initializing val dataset")
    _, val_loader = data_loader(VAL_DATASET_PATH, train_metric)

    iterations_per_epoch = len(train_dset) / BATCH / D_STEPS
    if NUM_EPOCHS:
        NUM_ITERATIONS = int(iterations_per_epoch * NUM_EPOCHS)

    generator = TrajectoryGenerator()

    generator.apply(init_weights)
    if USE_GPU == 0:
        generator.type(torch.FloatTensor).train()
    else:
        generator.type(torch.cuda.FloatTensor).train()
    print('Here is the generator:')
    print(generator)

    discriminator = TrajectoryDiscriminator()

    discriminator.apply(init_weights)
    if USE_GPU == 0:
        discriminator.type(torch.FloatTensor).train()
    else:
        discriminator.type(torch.cuda.FloatTensor).train()
    print('Here is the discriminator:')
    print(discriminator)

    g_loss_fn = gan_g_loss
    d_loss_fn = gan_d_loss

    optimizer_g = optim.Adam(generator.parameters(), lr=G_LEARNING_RATE)
    optimizer_d = optim.Adam(discriminator.parameters(), lr=D_LEARNING_RATE)

    t, epoch = 0, 0
    checkpoint = {
        'G_losses': defaultdict(list),
        'D_losses': defaultdict(list),
        'g_state': None,
        'g_optim_state': None,
        'd_state': None,
        'd_optim_state': None,
        'g_best_state': None,
        'd_best_state': None
    }
    ade_list, fde_list, avg_speed_error, f_speed_error = [], [], [], []
    while epoch < NUM_EPOCHS:
        gc.collect()
        d_steps_left, g_steps_left = D_STEPS, G_STEPS
        epoch += 1
        print('Starting epoch {}'.format(epoch))
        for batch in train_loader:
            if d_steps_left > 0:
                losses_d = discriminator_step(batch, generator, discriminator,
                                              d_loss_fn, optimizer_d)
                d_steps_left -= 1
            elif g_steps_left > 0:
                losses_g = generator_step(batch, generator, discriminator,
                                          g_loss_fn, optimizer_g)
                g_steps_left -= 1

            if d_steps_left > 0 or g_steps_left > 0:
                continue

            if t > 0 and t % CHECKPOINT_EVERY == 0:

                print('t = {} / {}'.format(t + 1, NUM_ITERATIONS))
                for k, v in sorted(losses_d.items()):
                    print('  [D] {}: {:.3f}'.format(k, v))
                for k, v in sorted(losses_g.items()):
                    print('  [G] {}: {:.3f}'.format(k, v))

                print('Checking stats on val ...')
                metrics_val = check_accuracy(val_loader, generator,
                                             discriminator, d_loss_fn)
                print('Checking stats on train ...')
                metrics_train = check_accuracy(train_loader, generator,
                                               discriminator, d_loss_fn)

                for k, v in sorted(metrics_val.items()):
                    print('  [val] {}: {:.3f}'.format(k, v))
                for k, v in sorted(metrics_train.items()):
                    print('  [train] {}: {:.3f}'.format(k, v))

                ade_list.append(metrics_val['ade'])
                fde_list.append(metrics_val['fde'])
                avg_speed_error.append(metrics_val['msae'])
                f_speed_error.append(metrics_val['fse'])

                if metrics_val['ade'] == min(ade_list):
                    print('New low for avg_disp_error')
                if metrics_val['fde'] == min(fde_list):
                    print('New low for final_disp_error')
                if metrics_val['msae'] == min(avg_speed_error):
                    print('New low for avg_speed_error')
                if metrics_val['fse'] == min(f_speed_error):
                    print('New low for final_speed_error')

                checkpoint['g_state'] = generator.state_dict()
                checkpoint['g_optim_state'] = optimizer_g.state_dict()
                checkpoint['d_state'] = discriminator.state_dict()
                checkpoint['d_optim_state'] = optimizer_d.state_dict()
                torch.save(checkpoint, CHECKPOINT_NAME)
                print('Done.')

            t += 1
            d_steps_left = D_STEPS
            g_steps_left = G_STEPS
            if t >= NUM_ITERATIONS:
                break
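
The iteration budget in Examples #2 and #4 is derived from the dataset size. A short worked example with assumed numbers only:

# Worked example (all values hypothetical) of the budget computed above:
# iterations per epoch are estimated as batches per epoch divided by the
# number of discriminator steps per iteration.
len_train_dset = 8000          # hypothetical number of training sequences
BATCH, D_STEPS = 64, 2         # hypothetical batch size and d-steps
NUM_EPOCHS = 200               # hypothetical epoch budget

iterations_per_epoch = len_train_dset / BATCH / D_STEPS   # 8000 / 64 / 2 = 62.5
NUM_ITERATIONS = int(iterations_per_epoch * NUM_EPOCHS)   # 12500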
Example #5
def main():
    train_metric = 0
    if MULTI_CONDITIONAL_MODEL == SINGLE_CONDITIONAL_MODEL:
        raise ValueError("Please select exactly one of the multi conditional and single conditional model flags in constants.py")
    print("Process Started")
    if SINGLE_CONDITIONAL_MODEL:
        train_path = SINGLE_TRAIN_DATASET_PATH
        val_path = SINGLE_VAL_DATASET_PATH
    else:
        train_path = MULTI_TRAIN_DATASET_PATH
        val_path = MULTI_VAL_DATASET_PATH
    print("Initializing train dataset")
    train_dset, train_loader = data_loader(train_path, train_metric, 'train')
    print("Initializing val dataset")
    _, val_loader = data_loader(val_path, train_metric, 'val')

    if MULTI_CONDITIONAL_MODEL:
        iterations_per_epoch = len(train_dset) / BATCH_MULTI_CONDITION / D_STEPS
        NUM_ITERATIONS = int(iterations_per_epoch * NUM_EPOCHS_MULTI_CONDITION)
        generator = TrajectoryGenerator(mlp_dim=MLP_INPUT_DIM_MULTI_CONDITION,
                                        h_dim=H_DIM_GENERATOR_MULTI_CONDITION)
        discriminator = TrajectoryDiscriminator(mlp_dim=MLP_INPUT_DIM_MULTI_CONDITION,
                                                h_dim=H_DIM_DISCRIMINATOR_MULTI_CONDITION)
        speed_regressor = SpeedEncoderDecoder(h_dim=H_DIM_GENERATOR_MULTI_CONDITION)
        required_epoch = NUM_EPOCHS_MULTI_CONDITION

    elif SINGLE_CONDITIONAL_MODEL:
        iterations_per_epoch = len(train_dset) / BATCH_SINGLE_CONDITION / D_STEPS

        NUM_ITERATIONS = int(iterations_per_epoch * NUM_EPOCHS_SINGLE_CONDITION)
        generator = TrajectoryGenerator(mlp_dim=MLP_INPUT_DIM_SINGLE_CONDITION,
                                        h_dim=H_DIM_GENERATOR_SINGLE_CONDITION)
        discriminator = TrajectoryDiscriminator(mlp_dim=MLP_INPUT_DIM_SINGLE_CONDITION,
                                                h_dim=H_DIM_DISCRIMINATOR_SINGLE_CONDITION)
        speed_regressor = SpeedEncoderDecoder(h_dim=H_DIM_GENERATOR_SINGLE_CONDITION)
        required_epoch = NUM_EPOCHS_SINGLE_CONDITION

    print(iterations_per_epoch)
    generator.apply(init_weights)
    if USE_GPU:
        generator.type(torch.cuda.FloatTensor).train()
    else:
        generator.type(torch.FloatTensor).train()
    print('Here is the generator:')
    print(generator)

    discriminator.apply(init_weights)
    if USE_GPU:
        discriminator.type(torch.cuda.FloatTensor).train()
    else:
        discriminator.type(torch.FloatTensor).train()
    print('Here is the discriminator:')
    print(discriminator)

    speed_regressor.apply(init_weights)
    if USE_GPU:
        speed_regressor.type(torch.cuda.FloatTensor).train()
    else:
        speed_regressor.type(torch.FloatTensor).train()
    print('Here is the Speed Regressor:')
    print(speed_regressor)

    g_loss_fn = gan_g_loss
    d_loss_fn = gan_d_loss

    optimizer_g = optim.Adam(generator.parameters(), lr=G_LEARNING_RATE)
    optimizer_d = optim.Adam(discriminator.parameters(), lr=D_LEARNING_RATE)
    optimizer_speed_regressor = optim.Adam(speed_regressor.parameters(), lr=D_LEARNING_RATE)

    t, epoch = 0, 0
    checkpoint = {
        'G_losses': defaultdict(list),
        'D_losses': defaultdict(list),
        'g_state': None,
        'g_optim_state': None,
        'd_state': None,
        'd_optim_state': None,
        'g_best_state': None,
        'd_best_state': None,
        'best_regressor_state': None,
        'regressor_state': None
    }
    val_ade_list, val_fde_list, train_ade, train_fde, train_avg_speed_error, val_avg_speed_error, val_msae_list = [], [], [], [], [], [], []
    train_ade_list, train_fde_list = [], []

    while epoch < required_epoch:
        gc.collect()
        d_steps_left, g_steps_left, speed_regression_steps_left = D_STEPS, G_STEPS, SR_STEPS
        epoch += 1
        print('Starting epoch {}'.format(epoch))
        disc_loss, gent_loss, sr_loss = [], [], []
        for batch in train_loader:
            if d_steps_left > 0:
                losses_d = discriminator_step(batch, generator, discriminator, d_loss_fn, optimizer_d)
                disc_loss.append(losses_d['D_total_loss'])
                d_steps_left -= 1
            elif g_steps_left > 0:
                losses_g = generator_step(batch, generator, discriminator, g_loss_fn, optimizer_g)
                speed_regression_loss = speed_regressor_step(batch, generator, speed_regressor, optimizer_speed_regressor)
                losses_g['Speed_Regression_Loss'] = speed_regression_loss['Speed_Regression_Loss']
                sr_loss.append(speed_regression_loss['Speed_Regression_Loss'])
                gent_loss.append(losses_g['G_discriminator_loss'])
                g_steps_left -= 1

            if d_steps_left > 0 or g_steps_left > 0:
                continue

            if t > 0 and t % CHECKPOINT_EVERY == 0:

                print('t = {} / {}'.format(t + 1, NUM_ITERATIONS))
                for k, v in sorted(losses_d.items()):
                    print('  [D] {}: {:.3f}'.format(k, v))
                for k, v in sorted(losses_g.items()):
                    print('  [G] {}: {:.3f}'.format(k, v))

                print('Checking stats on val ...')
                metrics_val = check_accuracy(val_loader, generator, discriminator, d_loss_fn, speed_regressor)
                print('Checking stats on train ...')
                metrics_train = check_accuracy(train_loader, generator, discriminator, d_loss_fn, speed_regressor)

                for k, v in sorted(metrics_val.items()):
                    print('  [val] {}: {:.3f}'.format(k, v))
                for k, v in sorted(metrics_train.items()):
                    print('  [train] {}: {:.3f}'.format(k, v))

                val_ade_list.append(metrics_val['ade'])
                val_fde_list.append(metrics_val['fde'])

                train_ade_list.append(metrics_train['ade'])
                train_fde_list.append(metrics_train['fde'])

                # The current ADE/FDE were just appended, so equality with the
                # running minimum marks a new best model.
                if metrics_val['ade'] == min(val_ade_list) or metrics_val['fde'] == min(val_fde_list):
                    checkpoint['g_best_state'] = generator.state_dict()
                if metrics_val['ade'] == min(val_ade_list):
                    print('New low for avg_disp_error')
                if metrics_val['fde'] == min(val_fde_list):
                    print('New low for final_disp_error')

                checkpoint['g_state'] = generator.state_dict()
                checkpoint['g_optim_state'] = optimizer_g.state_dict()
                checkpoint['d_state'] = discriminator.state_dict()
                checkpoint['d_optim_state'] = optimizer_d.state_dict()
                checkpoint['regressor_state'] = speed_regressor.state_dict()
                torch.save(checkpoint, CHECKPOINT_NAME)
                print('Done.')

            t += 1
            d_steps_left = D_STEPS
            g_steps_left = G_STEPS
            if t >= NUM_ITERATIONS:
                break