Example #1
def main(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    long_dtype, float_dtype = get_dtypes(args)

    argoverse_train = Argoverse_Social_Data(
        '../../deep_prediction/data/train/data/')
    argoverse_val = Argoverse_Social_Data(
        '../../deep_prediction/data/val/data')
    argoverse_test = Argoverse_Social_Data(
        '../../deep_prediction/data/test_obs/data')

    train_loader = DataLoader(argoverse_train,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=2,
                              collate_fn=collate_traj_social)
    val_loader = DataLoader(argoverse_val,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=2,
                            collate_fn=collate_traj_social)
    test_loader = DataLoader(argoverse_test,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=2,
                             collate_fn=collate_traj_social)
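    # NOTE: the test split is shuffled here as well; for deterministic
    # evaluation, shuffle=False is the more usual choice.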

    iterations_per_epoch = len(
        argoverse_train) / args.batch_size / args.d_steps
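    # NOTE: each iteration below consumes args.d_steps discriminator batches
    # plus args.g_steps generator batches, so this per-epoch count is only an
    # estimate based on the discriminator steps.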

    if args.num_epochs:
        args.num_iterations = int(iterations_per_epoch * args.num_epochs)

    logger.info(
        'There are {} iterations per epoch'.format(iterations_per_epoch))

    generator = TrajectoryGenerator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        encoder_h_dim=args.encoder_h_dim_g,
        decoder_h_dim=args.decoder_h_dim_g,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        noise_dim=args.noise_dim,
        noise_type=args.noise_type,
        noise_mix_type=args.noise_mix_type,
        pooling_type=args.pooling_type,
        pool_every_timestep=args.pool_every_timestep,
        dropout=args.dropout,
        bottleneck_dim=args.bottleneck_dim,
        neighborhood_size=args.neighborhood_size,
        grid_size=args.grid_size,
        batch_norm=args.batch_norm)

    generator.apply(init_weights)
    generator.type(float_dtype).train()
    logger.info('Here is the generator:')
    logger.info(generator)

    discriminator = TrajectoryDiscriminator(obs_len=args.obs_len,
                                            pred_len=args.pred_len,
                                            embedding_dim=args.embedding_dim,
                                            h_dim=args.encoder_h_dim_d,
                                            mlp_dim=args.mlp_dim,
                                            num_layers=args.num_layers,
                                            dropout=args.dropout,
                                            batch_norm=args.batch_norm,
                                            d_type=args.d_type)

    discriminator.apply(init_weights)
    discriminator.type(float_dtype).train()
    logger.info('Here is the discriminator:')
    logger.info(discriminator)

    g_loss_fn = gan_g_loss
    d_loss_fn = gan_d_loss

    optimizer_g = optim.Adam(generator.parameters(), lr=args.g_learning_rate)
    optimizer_d = optim.Adam(discriminator.parameters(),
                             lr=args.d_learning_rate)

    # Maybe restore from checkpoint
    restore_path = None
    if args.checkpoint_start_from is not None:
        restore_path = args.checkpoint_start_from
    elif args.restore_from_checkpoint == 1:
        restore_path = os.path.join(args.output_dir,
                                    '%s_with_model.pt' % args.checkpoint_name)

    if restore_path is not None and os.path.isfile(restore_path):
        logger.info('Restoring from checkpoint {}'.format(restore_path))
        checkpoint = torch.load(restore_path)
        generator.load_state_dict(checkpoint['g_state'])
        discriminator.load_state_dict(checkpoint['d_state'])
        optimizer_g.load_state_dict(checkpoint['g_optim_state'])
        optimizer_d.load_state_dict(checkpoint['d_optim_state'])
        t = checkpoint['counters']['t']
        epoch = checkpoint['counters']['epoch']
        checkpoint['restore_ts'].append(t)
    else:
        # Starting from scratch, so initialize checkpoint data structure
        t, epoch = 0, 0
        checkpoint = {
            'args': args.__dict__,
            'G_losses': defaultdict(list),
            'D_losses': defaultdict(list),
            'losses_ts': [],
            'metrics_val': defaultdict(list),
            'metrics_train': defaultdict(list),
            'sample_ts': [],
            'restore_ts': [],
            'norm_g': [],
            'norm_d': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'g_state': None,
            'g_optim_state': None,
            'd_state': None,
            'd_optim_state': None,
            'g_best_state': None,
            'd_best_state': None,
            'best_t': None,
            'g_best_nl_state': None,
            'd_best_nl_state': None,
            'best_t_nl': None,
        }
    t0 = None

    writer_iter = 0

    while t < args.num_iterations:
        gc.collect()
        d_steps_left = args.d_steps
        g_steps_left = args.g_steps
        epoch += 1
        logger.info('Starting epoch {}'.format(epoch))

        for batch in train_loader:

            if args.timing == 1:
                torch.cuda.synchronize()
                t1 = time.time()

            # Decide whether to use the batch for stepping on discriminator or
            # generator; an iteration consists of args.d_steps steps on the
            # discriminator followed by args.g_steps steps on the generator.
            if d_steps_left > 0:
                step_type = 'd'

                losses_d = discriminator_step(args, batch, generator,
                                              discriminator, d_loss_fn,
                                              optimizer_d)

                checkpoint['norm_d'].append(
                    get_total_norm(discriminator.parameters()))
                d_steps_left -= 1

            elif g_steps_left > 0:
                step_type = 'g'
                losses_g = generator_step(args, batch, generator,
                                          discriminator, g_loss_fn,
                                          optimizer_g)
                checkpoint['norm_g'].append(
                    get_total_norm(generator.parameters()))
                g_steps_left -= 1

            if args.timing == 1:
                torch.cuda.synchronize()
                t2 = time.time()
                logger.info('{} step took {}'.format(step_type, t2 - t1))

            # Skip the rest if we are not at the end of an iteration
            if d_steps_left > 0 or g_steps_left > 0:
                continue

            if args.timing == 1:
                if t0 is not None:
                    logger.info('Iteration {} took {}'.format(
                        t - 1,
                        time.time() - t0))
                t0 = time.time()

            # Maybe save loss
            if t % args.print_every == 0:
                logger.info('t = {} / {}'.format(t + 1, args.num_iterations))
                for k, v in sorted(losses_d.items()):
                    logger.info('  [D] {}: {:.3f}'.format(k, v))
                    checkpoint['D_losses'][k].append(v)
                for k, v in sorted(losses_g.items()):
                    logger.info('  [G] {}: {:.3f}'.format(k, v))
                    checkpoint['G_losses'][k].append(v)
                checkpoint['losses_ts'].append(t)

                writer.add_scalar('D_data_loss', losses_d['D_data_loss'], t)
                writer.add_scalar('D_total_loss', losses_d['D_total_loss'], t)
                writer.add_scalar('G_discriminator_loss',
                                  losses_g['G_discriminator_loss'], t)
                writer.add_scalar('G_l2_loss_rel', losses_g['G_l2_loss_rel'],
                                  t)
                writer.add_scalar('G_total_loss', losses_g['G_total_loss'], t)

            ### save: D_losses, G_losses

            # Maybe save a checkpoint
            if t > 0 and t % args.checkpoint_every == 0:
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint['sample_ts'].append(t)

                # Check stats on the validation set
                logger.info('Checking stats on val ...')
                metrics_val = check_accuracy(args,
                                             val_loader,
                                             generator,
                                             discriminator,
                                             d_loss_fn,
                                             limit=True)

                logger.info('Checking stats on train ...')
                metrics_train = check_accuracy(args,
                                               train_loader,
                                               generator,
                                               discriminator,
                                               d_loss_fn,
                                               limit=True)

                for k, v in sorted(metrics_val.items()):
                    logger.info('  [val] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_val'][k].append(v)

                writer.add_scalar('val_ade', metrics_val['ade'], t)
                writer.add_scalar('val_ade_l', metrics_val['ade_l'], t)
                writer.add_scalar('val_ade_nl', metrics_val['ade_nl'], t)
                writer.add_scalar('val_fde', metrics_val['fde'], t)
                writer.add_scalar('val_fde_l', metrics_val['fde_l'], t)
                writer.add_scalar('val_fde_nl', metrics_val['fde_nl'], t)

                for k, v in sorted(metrics_train.items()):
                    logger.info('  [train] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_train'][k].append(v)

                writer.add_scalar('train_ade', metrics_train['ade'], t)
                writer.add_scalar('train_ade_l', metrics_train['ade_l'], t)
                writer.add_scalar('train_ade_nl', metrics_train['ade_nl'], t)
                writer.add_scalar('train_fde', metrics_train['fde'], t)
                writer.add_scalar('train_fde_l', metrics_train['fde_l'], t)
                writer.add_scalar('train_fde_nl', metrics_train['fde_nl'], t)

                min_ade = min(checkpoint['metrics_val']['ade'])
                min_ade_nl = min(checkpoint['metrics_val']['ade_nl'])

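                # The current metrics were appended just above, so equality with
                # the running minimum means a new best on the validation set.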
                if metrics_val['ade'] == min_ade:
                    logger.info('New low for avg_disp_error')
                    checkpoint['best_t'] = t
                    checkpoint['g_best_state'] = generator.state_dict()
                    checkpoint['d_best_state'] = discriminator.state_dict()

                if metrics_val['ade_nl'] == min_ade_nl:
                    logger.info('New low for avg_disp_error_nl')
                    checkpoint['best_t_nl'] = t
                    checkpoint['g_best_nl_state'] = generator.state_dict()
                    checkpoint['d_best_nl_state'] = discriminator.state_dict()

                # Save another checkpoint with model weights and
                # optimizer state
                checkpoint['g_state'] = generator.state_dict()
                checkpoint['g_optim_state'] = optimizer_g.state_dict()
                checkpoint['d_state'] = discriminator.state_dict()
                checkpoint['d_optim_state'] = optimizer_d.state_dict()
                checkpoint_path = os.path.join(
                    model_save_dir, '%s_with_model.pt' % args.checkpoint_name)
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                torch.save(checkpoint, checkpoint_path)
                logger.info('Done.')

                # Save a checkpoint with no model weights by making a shallow
                # copy of the checkpoint excluding some items
                checkpoint_path = os.path.join(
                    model_save_dir, '%s_no_model.pt' % args.checkpoint_name)
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                key_blacklist = [
                    'g_state', 'd_state', 'g_best_state', 'g_best_nl_state',
                    'g_optim_state', 'd_optim_state', 'd_best_state',
                    'd_best_nl_state'
                ]
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
                logger.info('Done.')

            t += 1
            d_steps_left = args.d_steps
            g_steps_left = args.g_steps

            #### Writer entries here ####
            #             writer.add_scalar('Loss/train', np.random.random(), n_iter)

            if t >= args.num_iterations:
                writer.close()
                break
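The loop above logs through a module-level logger and a TensorBoard writer that are created outside main() (model_save_dir is likewise assumed to exist). A minimal sketch of that module-level setup, assuming torch.utils.tensorboard is available; the log directory and format string are illustrative only:

import logging

from torch.utils.tensorboard import SummaryWriter

logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

# Module-level TensorBoard writer behind the writer.add_scalar(...) calls above;
# the run directory name is only an example.
writer = SummaryWriter(log_dir='runs/trajectory_gan')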
Example #2
def main(args):
    print(args)
    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    # train_path = get_dset_path(args.dataset_name, 'train')
    # val_path = get_dset_path(args.dataset_name, 'val')
    train_path = os.path.join(args.dataset_dir, args.dataset_name,
                              'train_sample')  # 10 files:0-9
    val_path = os.path.join(args.dataset_dir, args.dataset_name,
                            'val_sample')  # 5 files: 10-14

    long_dtype, float_dtype = get_dtypes(args)

    logger.info("Initializing train dataset")
    train_dset, train_loader = data_loader(args, train_path)
    logger.info("Initializing val dataset")
    _, val_loader = data_loader(args, val_path)

    iterations_per_epoch = len(train_dset) / args.batch_size / args.d_steps
    if args.num_epochs:
        args.num_iterations = int(iterations_per_epoch * args.num_epochs)

    logger.info(
        'There are {} iterations per epoch'.format(iterations_per_epoch))

    generator = TrajectoryGenerator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        encoder_h_dim=args.encoder_h_dim_g,
        decoder_h_dim=args.decoder_h_dim_g,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        noise_dim=args.noise_dim,
        noise_type=args.noise_type,
        noise_mix_type=args.noise_mix_type,
        pooling_type=args.pooling_type,
        pool_every_timestep=args.pool_every_timestep,
        dropout=args.dropout,
        tp_dropout=args.tp_dropout,
        bottleneck_dim=args.bottleneck_dim,
        neighborhood_size=args.neighborhood_size,
        grid_size=args.grid_size,
        batch_norm=args.batch_norm,
        team_embedding_dim=args.team_embedding_dim,
        pos_embedding_dim=args.pos_embedding_dim,
        interaction_activation=args.interaction_activation)

    generator.apply(init_weights)
    generator.type(float_dtype).train()
    generator = generator.cuda()
    logger.info('Here is the generator:')
    logger.info(generator)

    discriminator = TrajectoryDiscriminator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        h_dim=args.encoder_h_dim_d,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        dropout=args.dropout,
        tp_dropout=args.tp_dropout,
        batch_norm=args.batch_norm,
        d_type=args.d_type,
        activation=args.d_activation,  # default: relu
        pos_embedding_dim=args.pos_embedding_dim,
        team_embedding_dim=args.team_embedding_dim,
        interaction_activation=args.interaction_activation)

    discriminator.apply(init_weights)
    discriminator.type(float_dtype).train()
    discriminator = discriminator.cuda()
    logger.info('Here is the discriminator:')
    logger.info(discriminator)

    g_loss_fn = gan_g_loss
    d_loss_fn = gan_d_loss

    optimizer_g = optim.Adam(generator.parameters(), lr=args.g_learning_rate)
    optimizer_d = optim.Adam(discriminator.parameters(),
                             lr=args.d_learning_rate)
    scheduler_g = optim.lr_scheduler.MultiStepLR(optimizer_g,
                                                 milestones=[10, 50],
                                                 gamma=args.g_gamma)
    scheduler_d = optim.lr_scheduler.MultiStepLR(optimizer_d,
                                                 milestones=[10, 50],
                                                 gamma=args.d_gamma)
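    # MultiStepLR multiplies each learning rate by gamma once the scheduler has
    # been stepped 10 times and again after 50 steps (one step per epoch here).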
    # Maybe restore from checkpoint
    restore_path = None
    if args.checkpoint_start_from is not None:
        restore_path = args.checkpoint_start_from
    elif args.restore_from_checkpoint == 1:
        restore_path = os.path.join(args.output_dir,
                                    '%s_with_model.pt' % args.checkpoint_name)

    if restore_path is not None and os.path.isfile(restore_path):
        logger.info('Restoring from checkpoint {}'.format(restore_path))
        checkpoint = torch.load(restore_path)
        generator.load_state_dict(checkpoint['g_state'])
        discriminator.load_state_dict(checkpoint['d_state'])
        optimizer_g.load_state_dict(checkpoint['g_optim_state'])
        optimizer_d.load_state_dict(checkpoint['d_optim_state'])
        t = checkpoint['counters']['t']
        epoch = checkpoint['counters']['epoch']
        checkpoint['restore_ts'].append(t)
    else:
        # Starting from scratch, so initialize checkpoint data structure
        t, epoch = 0, 0
        checkpoint = {
            'args': args.__dict__,
            'G_losses': defaultdict(list),
            'D_losses': defaultdict(list),
            'losses_ts': [],
            'metrics_val': defaultdict(list),
            'metrics_train': defaultdict(list),
            'sample_ts': [],
            'restore_ts': [],
            'norm_g': [],
            'norm_d': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'g_state': None,
            'g_optim_state': None,
            'd_state': None,
            'd_optim_state': None,
            'g_best_state': None,
            'd_best_state': None,
            'best_t': None,
            'g_best_nl_state': None,
            'd_best_nl_state': None,
            'best_t_nl': None,
        }
    t0 = None
    while t < args.num_iterations:
        gc.collect()
        d_steps_left = args.d_steps
        g_steps_left = args.g_steps
        epoch += 1
        logger.info('Starting epoch {}'.format(epoch))
        scheduler_g.step()
        scheduler_d.step()
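        # NOTE: recent PyTorch versions expect optimizer.step() to come before
        # lr_scheduler.step(); stepping the schedulers at the start of the epoch
        # also applies each decay one epoch earlier than stepping at the end.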

        for batch in train_loader:
            if args.timing == 1:
                torch.cuda.synchronize()
                t1 = time.time()

            # Decide whether to use the batch for stepping on discriminator or
            # generator; an iteration consists of args.d_steps steps on the
            # discriminator followed by args.g_steps steps on the generator.
            if d_steps_left > 0:
                step_type = 'd'
                losses_d = discriminator_step(args, batch, generator,
                                              discriminator, d_loss_fn,
                                              optimizer_d)
                checkpoint['norm_d'].append(
                    get_total_norm(discriminator.parameters()))
                d_steps_left -= 1
            elif g_steps_left > 0:
                step_type = 'g'
                losses_g = generator_step(args, batch, generator,
                                          discriminator, g_loss_fn,
                                          optimizer_g)
                checkpoint['norm_g'].append(
                    get_total_norm(generator.parameters()))
                g_steps_left -= 1

            if args.timing == 1:
                torch.cuda.synchronize()
                t2 = time.time()
                logger.info('{} step took {}'.format(step_type, t2 - t1))

            # Skip the rest if we are not at the end of an iteration
            if d_steps_left > 0 or g_steps_left > 0:
                continue

            if args.timing == 1:
                if t0 is not None:
                    logger.info('Iteration {} took {}'.format(
                        t - 1,
                        time.time() - t0))
                t0 = time.time()

            # Maybe save loss
            if t % args.print_every == 0:
                logger.info('t = {} / {}'.format(t + 1, args.num_iterations))
                for k, v in sorted(losses_d.items()):
                    # logger.info('  [D] {}: {:.3f}'.format(k, v))
                    checkpoint['D_losses'][k].append(v)
                for k, v in sorted(losses_g.items()):
                    # logger.info('  [G] {}: {:.3f}'.format(k, v))
                    checkpoint['G_losses'][k].append(v)
                checkpoint['losses_ts'].append(t)

                ## log scalars
                for k, v in sorted(losses_d.items()):
                    writer.add_scalar("loss/{}".format(k), v, t)
                for k, v in sorted(losses_g.items()):
                    writer.add_scalar("loss/{}".format(k), v, t)

            # Maybe save a checkpoint
            if t > 0 and t % args.checkpoint_every == 0:
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint['sample_ts'].append(t)

                # Check stats on the validation set
                logger.info('Checking stats on val ...')
                metrics_val = check_accuracy(args, val_loader, generator,
                                             discriminator, d_loss_fn)
                logger.info('Checking stats on train ...')
                metrics_train = check_accuracy(args,
                                               train_loader,
                                               generator,
                                               discriminator,
                                               d_loss_fn,
                                               limit=True)

                for k, v in sorted(metrics_val.items()):
                    # logger.info('  [val] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_val'][k].append(v)
                for k, v in sorted(metrics_train.items()):
                    # logger.info('  [train] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_train'][k].append(v)

                ## log scalars
                for k, v in sorted(metrics_val.items()):
                    writer.add_scalar("val/{}".format(k), v, t)
                for k, v in sorted(metrics_train.items()):
                    writer.add_scalar("train/{}".format(k), v, t)

                min_ade = min(checkpoint['metrics_val']['ade'])
                min_ade_nl = min(checkpoint['metrics_val']['ade_nl'])

                if metrics_val['ade'] == min_ade:
                    logger.info('New low for avg_disp_error')
                    checkpoint['best_t'] = t
                    checkpoint['g_best_state'] = generator.state_dict()
                    checkpoint['d_best_state'] = discriminator.state_dict()

                if metrics_val['ade_nl'] == min_ade_nl:
                    logger.info('New low for avg_disp_error_nl')
                    checkpoint['best_t_nl'] = t
                    checkpoint['g_best_nl_state'] = generator.state_dict()
                    checkpoint['d_best_nl_state'] = discriminator.state_dict()

                # Save another checkpoint with model weights and
                # optimizer state
                checkpoint['g_state'] = generator.state_dict()
                checkpoint['g_optim_state'] = optimizer_g.state_dict()
                checkpoint['d_state'] = discriminator.state_dict()
                checkpoint['d_optim_state'] = optimizer_d.state_dict()
                # checkpoint_path = os.path.join(
                #     args.output_dir, '{}_with_model_{:06d}.pt'.format(args.checkpoint_name,t)
                # )
                checkpoint_path = os.path.join(
                    args.output_dir,
                    '{}_with_model.pt'.format(args.checkpoint_name))
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                torch.save(checkpoint, checkpoint_path)
                logger.info('Done.')

                # Save a checkpoint with no model weights by making a shallow
                # copy of the checkpoint excluding some items

                # checkpoint_path = os.path.join(
                #     args.output_dir, '{}_no_model_{:06d}.pt' .format(args.checkpoint_name,t))

                checkpoint_path = os.path.join(
                    args.output_dir,
                    '{}_no_model.pt'.format(args.checkpoint_name))
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                key_blacklist = [
                    'g_state', 'd_state', 'g_best_state', 'g_best_nl_state',
                    'g_optim_state', 'd_optim_state', 'd_best_state',
                    'd_best_nl_state'
                ]
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
                logger.info('Done.')

            t += 1
            d_steps_left = args.d_steps
            g_steps_left = args.g_steps
            if t >= args.num_iterations:
                break
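Both of the examples above assign g_loss_fn = gan_g_loss and d_loss_fn = gan_d_loss without showing those functions. A minimal sketch of one common choice, adversarial BCE-with-logits losses over the discriminator scores; this is only an illustration, not necessarily the exact functions these projects use:

import torch
import torch.nn.functional as F

def gan_d_loss_sketch(scores_real, scores_fake):
    # The discriminator is pushed to score real trajectories as 1 and generated ones as 0.
    loss_real = F.binary_cross_entropy_with_logits(
        scores_real, torch.ones_like(scores_real))
    loss_fake = F.binary_cross_entropy_with_logits(
        scores_fake, torch.zeros_like(scores_fake))
    return loss_real + loss_fake

def gan_g_loss_sketch(scores_fake):
    # The generator is pushed to make the discriminator score its samples as real.
    return F.binary_cross_entropy_with_logits(
        scores_fake, torch.ones_like(scores_fake))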
Example #3
def main(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    train_path = get_dset_path(args.dataset_name, 'train')
    val_path = get_dset_path(args.dataset_name, 'val')

    long_dtype, float_dtype = get_dtypes(args)

    logger.info("Initializing train dataset")
    # train_dset: TrajectoryDataset prepares the full dataset for us
    # train_loader: loads the data batch by batch
    train_dset, train_loader = data_loader(args, train_path)
    # print("train_dset.__len__():", train_dset.__len__())
    logger.info("Initializing val dataset")
    _, val_loader = data_loader(args, val_path)

    # args.batch_size = 64
    # args.d_steps = 2
    iterations_per_epoch = len(train_dset) / args.batch_size / args.d_steps
    # args.num_epochs = 200
    if args.num_epochs:
        args.num_iterations = int(iterations_per_epoch * args.num_epochs)

    logger.info('There are {} iterations per epoch'.format(iterations_per_epoch))

    # generator network
    generator = TrajectoryGenerator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        encoder_h_dim=args.encoder_h_dim_g,
        decoder_h_dim=args.decoder_h_dim_g,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        noise_dim=args.noise_dim,
        noise_type=args.noise_type,
        noise_mix_type=args.noise_mix_type,
        pooling_type=args.pooling_type,
        pool_every_timestep=args.pool_every_timestep,
        dropout=args.dropout,
        bottleneck_dim=args.bottleneck_dim,
        neighborhood_size=args.neighborhood_size,
        grid_size=args.grid_size,
        batch_norm=args.batch_norm)

    generator.apply(init_weights)
    generator.type(float_dtype).train()
    logger.info('Here is the generator:')
    logger.info(generator)

    # discriminator network
    discriminator = TrajectoryDiscriminator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        h_dim=args.encoder_h_dim_d,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        dropout=args.dropout,
        batch_norm=args.batch_norm,
        d_type=args.d_type)

    discriminator.apply(init_weights)
    discriminator.type(float_dtype).train()
    logger.info('Here is the discriminator:')
    logger.info(discriminator)

    # loss function from "losses.py"
    g_loss_fn = gan_g_loss
    d_loss_fn = gan_d_loss

    # optimizer function
    optimizer_g = optim.Adam(generator.parameters(), lr=args.g_learning_rate)
    optimizer_d = optim.Adam(discriminator.parameters(), lr=args.d_learning_rate)

    # Maybe restore from checkpoint
    restore_path = None
    if args.checkpoint_start_from is not None:
        restore_path = args.checkpoint_start_from
    elif args.restore_from_checkpoint == 1:
        restore_path = os.path.join(args.output_dir, '%s_with_model.pt' % args.checkpoint_name)

    if restore_path is not None and os.path.isfile(restore_path):
        logger.info('Restoring from checkpoint {}'.format(restore_path))
        checkpoint = torch.load(restore_path)
        generator.load_state_dict(checkpoint['g_state'])
        discriminator.load_state_dict(checkpoint['d_state'])
        optimizer_g.load_state_dict(checkpoint['g_optim_state'])
        optimizer_d.load_state_dict(checkpoint['d_optim_state'])
        t = checkpoint['counters']['t']
        epoch = checkpoint['counters']['epoch']
        checkpoint['restore_ts'].append(t)
    else:
        # Starting from scratch, so initialize checkpoint data structure
        t, epoch = 0, 0
        checkpoint = {
            'args': args.__dict__,
            'G_losses': defaultdict(list),
            'D_losses': defaultdict(list),
            'losses_ts': [],
            'metrics_val': defaultdict(list),
            'metrics_train': defaultdict(list),
            'sample_ts': [],
            'restore_ts': [],
            'norm_g': [],
            'norm_d': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'g_state': None,
            'g_optim_state': None,
            'd_state': None,
            'd_optim_state': None,
            'g_best_state': None,
            'd_best_state': None,
            'best_t': None,
            'g_best_nl_state': None,
            'd_best_nl_state': None,
            'best_t_nl': None,
        }
    t0 = None
    while t < args.num_iterations:
        gc.collect()
        d_steps_left = args.d_steps  # args.d_steps = 2 discriminator steps
        g_steps_left = args.g_steps  # args.g_steps = 1 generator steps
        epoch += 1
        logger.info('Starting epoch {}'.format(epoch))
        for batch in train_loader:  # iterate over the DataLoader, feeding the model one batch at a time
            # print("batch:", batch)
            # (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, obs_traj_rel_v, pred_traj_rel_v, obs_traj_g, pred_traj_g,
            #  non_linear_ped, loss_mask, seq_start_end) = batch
            # print("obs_traj:", obs_traj.shape)
            # print("pred_traj_gt:", pred_traj_gt.shape)
            # print("obs_traj_rel:", obs_traj_rel.shape)
            # print("pred_traj_gt_rel:", pred_traj_gt_rel.shape)
            # print("obs_traj_rel_v:", obs_traj_rel_v.shape)
            # print("pred_traj_rel_v:", pred_traj_rel_v.shape)
            # print("obs_traj_g:", obs_traj_g.shape)
            # print("pred_traj_g:", pred_traj_g.shape)
            # print("non_linear_ped:", non_linear_ped.shape)
            # print("loss_mask:", loss_mask.shape)

            # (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped, loss_mask, seq_start_end) = batch
            # print("obs_traj:", obs_traj.shape)
            # print("pred_traj_gt:", pred_traj_gt.shape)
            # print("obs_traj_rel:", obs_traj_rel.shape)
            # print("pred_traj_gt_rel:", pred_traj_gt_rel.shape)
            # print("non_linear_ped:", non_linear_ped.shape)
            # print("loss_mask:", loss_mask.shape)

            if args.timing == 1:
                torch.cuda.synchronize()
                t1 = time.time()

            # Decide whether to use the batch for stepping on discriminator or
            # generator; an iteration consists of args.d_steps steps on the
            # discriminator followed by args.g_steps steps on the generator.
            if d_steps_left > 0:
                # discriminator step
                step_type = 'd'
                losses_d = discriminator_step(args, batch, generator, discriminator, d_loss_fn, optimizer_d)
                checkpoint['norm_d'].append(get_total_norm(discriminator.parameters()))
                d_steps_left -= 1
            elif g_steps_left > 0:
                # generator step
                step_type = 'g'
                losses_g = generator_step(args, batch, generator, discriminator, g_loss_fn, optimizer_g)
                checkpoint['norm_g'].append(get_total_norm(generator.parameters()))
                g_steps_left -= 1

            if args.timing == 1:
                torch.cuda.synchronize()
                t2 = time.time()
                logger.info('{} step took {}'.format(step_type, t2 - t1))

            # Skip the rest if we are not at the end of an iteration
            if d_steps_left > 0 or g_steps_left > 0:
                continue

            if args.timing == 1:
                if t0 is not None:
                    logger.info('Iteration {} took {}'.format(t - 1, time.time() - t0))
                t0 = time.time()

            # Maybe save loss
            if t % args.print_every == 0:
                logger.info('t = {} / {}'.format(t + 1, args.num_iterations))
                for k, v in sorted(losses_d.items()):
                    logger.info('  [D] {}: {:.3f}'.format(k, v))
                    checkpoint['D_losses'][k].append(v)
                for k, v in sorted(losses_g.items()):
                    logger.info('  [G] {}: {:.3f}'.format(k, v))
                    checkpoint['G_losses'][k].append(v)
                checkpoint['losses_ts'].append(t)

            # Maybe save a checkpoint
            if t > 0 and t % args.checkpoint_every == 0:
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint['sample_ts'].append(t)

                # Check stats on the validation set
                logger.info('Checking stats on val ...')
                metrics_val = check_accuracy(args, val_loader, generator, discriminator, d_loss_fn)
                logger.info('Checking stats on train ...')
                metrics_train = check_accuracy(args, train_loader, generator, discriminator, d_loss_fn, limit=True)

                for k, v in sorted(metrics_val.items()):
                    logger.info('  [val] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_val'][k].append(v)
                for k, v in sorted(metrics_train.items()):
                    logger.info('  [train] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_train'][k].append(v)

                min_ade = min(checkpoint['metrics_val']['ade'])
                min_ade_nl = min(checkpoint['metrics_val']['ade_nl'])

                if metrics_val['ade'] == min_ade:
                    logger.info('New low for avg_disp_error')
                    checkpoint['best_t'] = t
                    checkpoint['g_best_state'] = generator.state_dict()
                    checkpoint['d_best_state'] = discriminator.state_dict()

                if metrics_val['ade_nl'] == min_ade_nl:
                    logger.info('New low for avg_disp_error_nl')
                    checkpoint['best_t_nl'] = t
                    checkpoint['g_best_nl_state'] = generator.state_dict()
                    checkpoint['d_best_nl_state'] = discriminator.state_dict()

                # Save another checkpoint with model weights and
                # optimizer state
                checkpoint['g_state'] = generator.state_dict()
                checkpoint['g_optim_state'] = optimizer_g.state_dict()
                checkpoint['d_state'] = discriminator.state_dict()
                checkpoint['d_optim_state'] = optimizer_d.state_dict()
                checkpoint_path = os.path.join(args.output_dir, '%s_with_model.pt' % args.checkpoint_name)
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                torch.save(checkpoint, checkpoint_path)
                logger.info('Done.')

                # Save a checkpoint with no model weights by making a shallow
                # copy of the checkpoint excluding some items
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_no_model.pt' % args.checkpoint_name)
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                key_blacklist = [
                    'g_state', 'd_state', 'g_best_state', 'g_best_nl_state',
                    'g_optim_state', 'd_optim_state', 'd_best_state',
                    'd_best_nl_state'
                ]
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
                logger.info('Done.')

            t += 1
            d_steps_left = args.d_steps
            g_steps_left = args.g_steps
            if t >= args.num_iterations:
                break
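None of these examples shows the entry point that builds args. A minimal sketch of how such a script is typically driven; only a handful of the expected arguments are listed, with illustrative defaults, and the option names simply mirror the attributes accessed in main():

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', default='zara1', type=str)
parser.add_argument('--output_dir', default='./checkpoints', type=str)
parser.add_argument('--checkpoint_name', default='gan_test', type=str)
parser.add_argument('--batch_size', default=64, type=int)
parser.add_argument('--num_epochs', default=200, type=int)
parser.add_argument('--d_steps', default=2, type=int)
parser.add_argument('--g_steps', default=1, type=int)
parser.add_argument('--gpu_num', default='0', type=str)
# ... many more arguments (model sizes, learning rates, noise settings, ...)

if __name__ == '__main__':
    args = parser.parse_args()
    main(args)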
Example #4
def main(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    train_path = get_dset_path(args.dataset_name, 'train')
    #print(train_path)
    val_path = get_dset_path(args.dataset_name, 'val')

    long_dtype, float_dtype = get_dtypes(args)

    logger.info("Initializing train dataset")
    train_dset, train_loader = data_loader(args, train_path)
    logger.info("Initializing val dataset")
    _, val_loader = data_loader(args, val_path)
    
    #print(len(train_loader)) 46

    #iterations_per_epoch = len(train_dset) / args.batch_size / args.d_steps # len(train_dset) = 2930 useful frames:[(0, 2), (2, 4), (4, 6),...,(33860, 33866)]
    iterations_per_epoch = len(train_dset) // args.batch_size // args.d_steps + 1
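    # Integer division plus one keeps iterations_per_epoch an int, so the
    # per-epoch check below (t % iterations_per_epoch == 0) behaves as
    # intended; it overcounts by one whenever the division is exact.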
    if args.num_epochs:
        args.num_iterations = int(iterations_per_epoch * args.num_epochs)

    logger.info(
        'There are {} iterations per epoch'.format(iterations_per_epoch)
    )

    generator = TrajectoryGenerator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        encoder_h_dim=args.encoder_h_dim_g,
        decoder_h_dim=args.decoder_h_dim_g,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        noise_dim=args.noise_dim,
        noise_type=args.noise_type,
        noise_mix_type=args.noise_mix_type,
        pooling_type=args.pooling_type,
        pool_every_timestep=args.pool_every_timestep,
        dropout=args.dropout,
        bottleneck_dim=args.bottleneck_dim,
        neighborhood_size=args.neighborhood_size,
        grid_size=args.grid_size,
        batch_norm=args.batch_norm)

    generator.apply(init_weights)  # not sure what the m argument of init_weights is
    generator.type(float_dtype).train()
    logger.info('Here is the generator:')
    logger.info(generator)

    discriminator = TrajectoryDiscriminator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        h_dim=args.encoder_h_dim_d,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        dropout=args.dropout,
        batch_norm=args.batch_norm,
        d_type=args.d_type)

    discriminator.apply(init_weights)
    discriminator.type(float_dtype).train()
    logger.info('Here is the discriminator:')
    logger.info(discriminator)

    g_loss_fn = gan_g_loss #??????????
    d_loss_fn = gan_d_loss

    optimizer_g = optim.Adam(generator.parameters(), lr=args.g_learning_rate)
    optimizer_d = optim.Adam(
        discriminator.parameters(), lr=args.d_learning_rate
    )

    # Maybe restore from checkpoint
    restore_path = None
    if args.checkpoint_start_from is not None:
        restore_path = args.checkpoint_start_from
    elif args.restore_from_checkpoint == 1:
        restore_path = os.path.join(args.output_dir,
                                    '%s_with_model.pt' % args.checkpoint_name)

    if restore_path is not None and os.path.isfile(restore_path):
        logger.info('Restoring from checkpoint {}'.format(restore_path))
        checkpoint = torch.load(restore_path)
        generator.load_state_dict(checkpoint['g_state'])
        discriminator.load_state_dict(checkpoint['d_state'])
        optimizer_g.load_state_dict(checkpoint['g_optim_state'])
        optimizer_d.load_state_dict(checkpoint['d_optim_state'])
        t = checkpoint['counters']['t']
        epoch = checkpoint['counters']['epoch']
        checkpoint['restore_ts'].append(t)
        #added by cuimingbo
        args.num_epochs = checkpoint['args']['num_epochs']
        args.num_iterations = int(iterations_per_epoch * args.num_epochs)
    else:
        # Starting from scratch, so initialize checkpoint data structure
        t, epoch = 0, 0
        checkpoint = {
            'args': args.__dict__,
            'G_losses': defaultdict(list),
            'D_losses': defaultdict(list),
            'losses_ts': [],
            'metrics_val': defaultdict(list),
            'metrics_train': defaultdict(list),
            'metrics_val_epoch': defaultdict(list),
            'metrics_train_epoch': defaultdict(list),
            'sample_ts': [],
            'restore_ts': [],
            'norm_g': [],
            'norm_d': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'g_state': None,
            'g_optim_state': None,
            'g_best_state_epoch': None,
            'd_state': None,
            'd_optim_state': None,
            'g_best_state': None,
            'd_best_state': None,
            'best_t': None,
            'best_epoch': None,
            'g_best_nl_state': None,
            'd_best_nl_state': None,
            'best_t_nl': None,
            'best_epoch_nl': None,
        }
    t0 = None
    #while t < args.num_iterations: # args.num_iterations = int(iterations_per_epoch * args.num_epochs)
    args.num_epochs = checkpoint['args']['num_epochs']
    args.num_iterations = int(iterations_per_epoch * args.num_epochs)
    while epoch < args.num_epochs * 8:
        gc.collect()
        d_steps_left = args.d_steps  # 2: train the discriminator twice for each generator step
        g_steps_left = args.g_steps  # 1
        epoch += 1
        logger.info('Starting epoch {}'.format(epoch))
        for batch in train_loader:
            #print(333333333)
            #print(len(batch))
            #print(np.shape(batch[0])) torch.Size([8, 764, 2])
            if args.timing == 1:
                torch.cuda.synchronize()
                t1 = time.time()

            # Decide whether to use the batch for stepping on discriminator or
            # generator; an iteration consists of args.d_steps steps on the
            # discriminator followed by args.g_steps steps on the generator.
            if d_steps_left > 0:
                step_type = 'd'
                losses_d = discriminator_step(args, batch, generator,
                                              discriminator, d_loss_fn,
                                              optimizer_d)
                checkpoint['norm_d'].append(
                    get_total_norm(discriminator.parameters()))
                d_steps_left -= 1
            elif g_steps_left > 0:
                step_type = 'g'
                losses_g = generator_step(args, batch, generator,
                                          discriminator, g_loss_fn,
                                          optimizer_g)
                checkpoint['norm_g'].append(
                    get_total_norm(generator.parameters())
                )
                g_steps_left -= 1

            if args.timing == 1:
                torch.cuda.synchronize()
                t2 = time.time()
                logger.info('{} step took {}'.format(step_type, t2 - t1))

            # Skip the rest if we are not at the end of an iteration
            if d_steps_left > 0 or g_steps_left > 0:
                continue

            if args.timing == 1:
                if t0 is not None:
                    logger.info('Iteration {} took {}'.format(
                        t - 1, time.time() - t0
                    ))
                t0 = time.time()

            # Maybe save loss (choose the maximum loss?)
            if t % args.print_every == 0:
                logger.info('t = {} / {}'.format(t + 1, args.num_iterations))
                for k, v in sorted(losses_d.items()):
                    logger.info('  [D] {}: {:.3f}'.format(k, v))
                    checkpoint['D_losses'][k].append(v)
                for k, v in sorted(losses_g.items()):
                    logger.info('  [G] {}: {:.3f}'.format(k, v))
                    checkpoint['G_losses'][k].append(v)
                checkpoint['losses_ts'].append(t)
                
            #added by cuimingbo
            if t % iterations_per_epoch == 0: # args.num_iterations = int(iterations_per_epoch * args.num_epochs)
                metrics_val_epoch = check_accuracy(
                    args, val_loader, generator, discriminator, d_loss_fn
                )
                #logger.info('Checking stats on train ...')
                metrics_train_epoch = check_accuracy(
                    args, train_loader, generator, discriminator,
                    d_loss_fn, limit=True
                )

                for k, v in sorted(metrics_val_epoch.items()):
                    logger.info('  [val] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_val_epoch'][k].append(v)
                for k, v in sorted(metrics_train_epoch.items()):  # sorted by metric key
                    logger.info('  [train] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_train_epoch'][k].append(v)
                    

                min_ade = min(checkpoint['metrics_val_epoch']['ade'])
                min_ade_nl = min(checkpoint['metrics_val_epoch']['ade_nl'])

                
                if metrics_val_epoch['ade'] == min_ade:
                    #logger.info('New low for avg_disp_error')
                    checkpoint['best_epoch'] = epoch
                    logger.info('optimal epochs for ade: {}'.format(t // iterations_per_epoch))
                    checkpoint['g_best_state_epoch'] = generator.state_dict()  # snapshot of the current weights
                    #checkpoint['d_best_state'] = discriminator.state_dict()

                if metrics_val_epoch['ade_nl'] == min_ade_nl:
                    #logger.info('New low for avg_disp_error_nl')
                    checkpoint['best_epoch_nl'] = epoch
                    logger.info('optimal epochs for ade_nl: {}'.format(t // iterations_per_epoch))
                    #checkpoint['g_best_nl_state'] = generator.state_dict()
                    #checkpoint['d_best_nl_state'] = discriminator.state_dict()

            # Maybe save a checkpoint
            if t > 0 and t % args.checkpoint_every == 0:
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint['sample_ts'].append(t)

                # Check stats on the validation set
                logger.info('Checking stats on val ...')
                metrics_val = check_accuracy(
                    args, val_loader, generator, discriminator, d_loss_fn
                )
                logger.info('Checking stats on train ...')
                metrics_train = check_accuracy(
                    args, train_loader, generator, discriminator,
                    d_loss_fn, limit=True
                )

                for k, v in sorted(metrics_val.items()):
                    logger.info('  [val] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_val'][k].append(v)
                for k, v in sorted(metrics_train.items()):  # sorted by metric key
                    logger.info('  [train] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_train'][k].append(v)

                min_ade = min(checkpoint['metrics_val']['ade'])
                min_ade_nl = min(checkpoint['metrics_val']['ade_nl'])

                if metrics_val['ade'] == min_ade:
                    logger.info('New low for avg_disp_error')
                    checkpoint['best_t'] = t
                    checkpoint['g_best_state'] = generator.state_dict()  # snapshot of the current weights
                    checkpoint['d_best_state'] = discriminator.state_dict()

                if metrics_val['ade_nl'] == min_ade_nl:
                    logger.info('New low for avg_disp_error_nl')
                    checkpoint['best_t_nl'] = t
                    checkpoint['g_best_nl_state'] = generator.state_dict()
                    checkpoint['d_best_nl_state'] = discriminator.state_dict()

                # Save another checkpoint with model weights and
                # optimizer state
                checkpoint['g_state'] = generator.state_dict()
                checkpoint['g_optim_state'] = optimizer_g.state_dict()
                checkpoint['d_state'] = discriminator.state_dict()
                checkpoint['d_optim_state'] = optimizer_d.state_dict()
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_with_model.pt' % args.checkpoint_name
                )
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                torch.save(checkpoint, checkpoint_path)
                logger.info('Done.')

                # Save a checkpoint with no model weights by making a shallow
                # copy of the checkpoint excluding some items
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_no_model.pt' % args.checkpoint_name)
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                key_blacklist = [
                    'g_state', 'd_state', 'g_best_state', 'g_best_nl_state',
                    'g_optim_state', 'd_optim_state', 'd_best_state',
                    'd_best_nl_state'
                ]
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
                logger.info('Done.')

            t += 1
            d_steps_left = args.d_steps
            g_steps_left = args.g_steps
            #print(t)
            if t >= args.num_iterations * 4:
                break
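On the .apply(init_weights) calls (and the "not sure what the m argument is" note in the example above): nn.Module.apply passes every submodule to the given function. A minimal sketch of a weight initializer in that style, assuming Linear layers are the only target:

import torch.nn as nn

def init_weights_sketch(m):
    # m is each submodule visited by Module.apply(); initialize Linear layers only.
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)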
Example #5
File: train.py  Project: ACoTAI/CODE
def main(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    train_path = get_dset_path(args.dataset_name, 'train')
    val_path = get_dset_path(args.dataset_name, 'val')
    long_dtype, float_dtype = get_dtypes(args)

    logger.info("Initializing train dataset")
    train_dset, train_loader = data_loader(args, train_path)
    logger.info("Initializing val dataset")
    _, val_loader = data_loader(args, val_path)
    iterations_per_epoch = len(train_dset) / args.batch_size / args.d_steps
    if args.num_epochs:
        args.num_iterations = int(iterations_per_epoch * args.num_epochs)
    logger.info(
        'There are {} iterations per epoch'.format(iterations_per_epoch)
    )
    generatorSO = TrajectoryGeneratorSO(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        encoder_h_dim=args.encoder_h_dim_g,
        decoder_h_dim=args.decoder_h_dim_g,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        noise_dim=args.noise_dim,
        noise_type=args.noise_type,
        noise_mix_type=args.noise_mix_type,
        pooling_type=args.pooling_type,
        pool_every_timestep=args.pool_every_timestep,
        dropout=args.dropout,
        bottleneck_dim=args.bottleneck_dim,
        neighborhood_size=args.neighborhood_size,
        grid_size=args.grid_size,
        batch_norm=args.batch_norm)

    generatorSO.apply(init_weights)
    generatorSO.type(float_dtype).train()
    logger.info('Here is the generatorSO:')
    logger.info(generatorSO)
    #TODO:generator step two
    generatorST = TrajectoryGeneratorST(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        encoder_h_dim=args.encoder_h_dim_g,
        decoder_h_dim=args.decoder_h_dim_g,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        noise_dim=args.noise_dim,
        noise_type=args.noise_type,
        noise_mix_type=args.noise_mix_type,
        pooling_type=args.pooling_type,
        pool_every_timestep=args.pool_every_timestep,
        dropout=args.dropout,
        bottleneck_dim=args.bottleneck_dim,
        neighborhood_size=args.neighborhood_size,
        grid_size=args.grid_size,
        batch_norm=args.batch_norm)
    generatorST.apply(init_weights)
    generatorST.type(float_dtype).train()
    logger.info('Here is the generatorST:')
    logger.info(generatorST)
    discriminator = TrajectoryDiscriminator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        h_dim=args.encoder_h_dim_d,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        dropout=args.dropout,
        batch_norm=args.batch_norm,
        d_type=args.d_type)
    discriminator.apply(init_weights)
    discriminator.type(float_dtype).train()
    logger.info('Here is the discriminator:')
    logger.info(discriminator)
    netH = StatisticsNetwork(z_dim=2 * args.noise_dim[0] + 4 * args.pred_len, dim=512)
    netH.apply(init_weights)
    netH.type(float_dtype).train()
    logger.info('Here is the netH:')
    logger.info(netH)
    g_loss_fn = gan_g_loss
    d_loss_fn = gan_d_loss
    optimizer_gso = optim.Adam(generatorSO.parameters(), lr=args.g_learning_rate)
    optimizer_gst = optim.Adam(generatorST.parameters(), lr=args.g_learning_rate)
    optimizer_d = optim.Adam(discriminator.parameters(), lr=args.d_learning_rate)
    optimizer_h = optim.Adam(netH.parameters(), lr=args.h_learning_rate)
    restore_path = None
    if args.checkpoint_start_from is not None:
        restore_path = args.checkpoint_start_from
    elif args.restore_from_checkpoint == 1:
        restore_path = os.path.join(args.output_dir,
                                    '%s_with_model.pt' % args.checkpoint_name)

    if restore_path is not None and os.path.isfile(restore_path):
        logger.info('Restoring from checkpoint {}'.format(restore_path))
        checkpoint = torch.load(restore_path)
        generatorSO.load_state_dict(checkpoint['gso_state'])
        generatorST.load_state_dict(checkpoint['gst_state'])
        discriminator.load_state_dict(checkpoint['d_state'])
        #TODO:gso&gst
        optimizer_gso.load_state_dict(checkpoint['gso_optim_state'])
        optimizer_gst.load_state_dict(checkpoint['gst_optim_state'])
        optimizer_d.load_state_dict(checkpoint['d_optim_state'])
        t = checkpoint['counters']['t']
        epoch = checkpoint['counters']['epoch']
        checkpoint['restore_ts'].append(t)
    else:
        t, epoch = 0, 0
        checkpoint = {
            'args': args.__dict__,
            'G_losses': defaultdict(list),
            'D_losses': defaultdict(list),
            'losses_ts': [],
            'metrics_val': defaultdict(list),
            'metrics_train': defaultdict(list),
            'sample_ts': [],
            'restore_ts': [],
            'norm_gso': [],
            'norm_gst': [],
            'norm_d': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            #TODO:gso&gst
            'gso_state': None,
            'gst_state': None,
            'gso_optim_state': None,
            'gst_optim_state': None,
            'd_state': None,
            'd_optim_state': None,
            'gso_best_state': None,
            'gst_best_state': None,
            'd_best_state': None,
            'best_t': None,
            'gso_best_nl_state': None,
            'gst_best_nl_state': None,
            'd_best_nl_state': None,
            'best_t_nl': None,
        }
    t0 = None
    while t < args.num_iterations:
        gc.collect()
        d_steps_left = args.d_steps
        g_steps_left = args.g_steps
        epoch += 1
        logger.info('Starting epoch {}'.format(epoch))
        for batch in train_loader:
            if args.timing == 1:
                torch.cuda.synchronize()
                t1 = time.time()
            if d_steps_left > 0:
                step_type = 'd'
                losses_d = discriminator_step(args, batch, generatorSO, generatorST,
                                              discriminator, d_loss_fn,
                                              optimizer_d)
                checkpoint['norm_d'].append(
                    get_total_norm(discriminator.parameters()))
                d_steps_left -= 1
            elif g_steps_left > 0:
                step_type = 'g'
                losses_g = generator_step(args, batch, generatorSO, generatorST,
                                          discriminator, netH, g_loss_fn,
                                          optimizer_gso, optimizer_gst, optimizer_h)
                checkpoint['norm_gso'].append(
                    get_total_norm(generatorSO.parameters())
                )
                checkpoint['norm_gst'].append(
                    get_total_norm(generatorST.parameters())
                )
                g_steps_left -= 1

            if args.timing == 1:
                torch.cuda.synchronize()
                t2 = time.time()
                logger.info('{} step took {}'.format(step_type, t2 - t1))
            if d_steps_left > 0 or g_steps_left > 0:
                continue

            if args.timing == 1:
                if t0 is not None:
                    logger.info('Iteration {} took {}'.format(
                        t - 1, time.time() - t0
                    ))
                t0 = time.time()
            if t % args.print_every == 0:
                logger.info('t = {} / {}'.format(t + 1, args.num_iterations))
                for k, v in sorted(losses_d.items()):
                    logger.info('  [D] {}: {:.3f}'.format(k, v))
                    checkpoint['D_losses'][k].append(v)
                for k, v in sorted(losses_g.items()):
                    logger.info('  [G] {}: {:.3f}'.format(k, v))
                    checkpoint['G_losses'][k].append(v)
                checkpoint['losses_ts'].append(t)
            if t > 0 and t % args.checkpoint_every == 0:
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint['sample_ts'].append(t)
                logger.info('Checking stats on val ...')
                metrics_val = check_accuracy(
                    args, val_loader, generatorSO, generatorST, discriminator, d_loss_fn
                )
                logger.info('Checking stats on train ...')
                metrics_train = check_accuracy(
                    args, train_loader, generatorSO, generatorST, discriminator,
                    d_loss_fn, limit=True
                )
                for k, v in sorted(metrics_val.items()):
                    logger.info('  [val] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_val'][k].append(v)
                for k, v in sorted(metrics_train.items()):
                    logger.info('  [train] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_train'][k].append(v)
                min_ade = min(checkpoint['metrics_val']['ade'])
                min_ade_nl = min(checkpoint['metrics_val']['ade_nl'])
                if metrics_val['ade'] == min_ade:
                    logger.info('New low for avg_disp_error')
                    checkpoint['best_t'] = t
                    checkpoint['gso_best_state'] = generatorSO.state_dict()
                    checkpoint['gst_best_state'] = generatorST.state_dict()
                    checkpoint['d_best_state'] = discriminator.state_dict()
                if metrics_val['ade_nl'] == min_ade_nl:
                    logger.info('New low for avg_disp_error_nl')
                    checkpoint['best_t_nl'] = t
                    checkpoint['gso_best_nl_state'] = generatorSO.state_dict()
                    checkpoint['gst_best_nl_state'] = generatorST.state_dict()
                    checkpoint['d_best_nl_state'] = discriminator.state_dict()
                checkpoint['gso_state'] = generatorSO.state_dict()
                checkpoint['gst_state'] = generatorST.state_dict()

                checkpoint['gso_optim_state'] = optimizer_gso.state_dict()
                checkpoint['gst_optim_state'] = optimizer_gst.state_dict()
                checkpoint['d_state'] = discriminator.state_dict()
                checkpoint['d_optim_state'] = optimizer_d.state_dict()
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_with_model.pt' % args.checkpoint_name
                )
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                torch.save(checkpoint, checkpoint_path)
                logger.info('Done.')
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_no_model.pt' % args.checkpoint_name)
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                key_blacklist = [
                    'gso_state', 'gst_state', 'd_state',
                    'gso_best_state', 'gst_best_state',
                    'gso_best_nl_state', 'gst_best_nl_state',
                    'gso_optim_state', 'gst_optim_state', 'd_optim_state',
                    'd_best_state', 'd_best_nl_state'
                ]
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
                logger.info('Done.')

            t += 1
            d_steps_left = args.d_steps
            g_steps_left = args.g_steps
            if t >= args.num_iterations:
                break
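
The checkpoint block above writes two files: a full '%s_with_model.pt' and a slim '%s_no_model.pt' that drops the heavy state dicts. A minimal self-contained sketch of that slim-copy idea (all names and values below are illustrative, not taken from the script):

import torch

# Toy checkpoint: cheap bookkeeping plus heavy model/optimizer state.
checkpoint = {
    'losses_ts': [0, 100, 200],            # small metadata: keep
    'gso_state': {'w': torch.zeros(4)},    # model weights: drop
    'gso_optim_state': {'step': 200},      # optimizer state: drop
}
key_blacklist = ['gso_state', 'gso_optim_state']

# Shallow copy that skips the blacklisted (large) entries.
small_checkpoint = {k: v for k, v in checkpoint.items() if k not in key_blacklist}
torch.save(small_checkpoint, 'example_no_model.pt')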
Example #6
def main(args):
    print(args)
    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    # train_path = get_dset_path(args.dataset_name, 'train')
    # val_path = get_dset_path(args.dataset_name, 'val')
    train_path = os.path.join(args.dataset_dir, args.dataset_name,
                              'train_sample')  # 10 files:0-9
    val_path = os.path.join(args.dataset_dir, args.dataset_name,
                            'val_sample')  # 5 files: 10-14

    long_dtype, float_dtype = get_dtypes(args)

    logger.info("Initializing train dataset")
    train_dset, train_loader = data_loader(args, train_path)
    logger.info("Initializing val dataset")
    _, val_loader = data_loader(args, val_path)

    iterations_per_epoch = len(train_dset) / args.batch_size / args.d_steps
    if args.num_epochs:
        args.num_iterations = int(iterations_per_epoch * args.num_epochs)

    logger.info(
        'There are {} iterations per epoch'.format(iterations_per_epoch))

    regressor = TrajectoryLinearRegressor(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        mlp_dim=args.mlp_dim,
        dropout=args.dropout,
        batch_norm=args.batch_norm,
    )

    regressor.apply(init_weights)
    regressor.type(float_dtype).train()
    logger.info('Here is the regressor:')
    logger.info(regressor)

    optimizer_r = optim.Adam(regressor.parameters(), lr=args.g_learning_rate)

    # Maybe restore from checkpoint
    restore_path = None
    if args.checkpoint_start_from is not None:
        restore_path = args.checkpoint_start_from
    elif args.restore_from_checkpoint == 1:
        restore_path = os.path.join(args.output_dir,
                                    '%s_with_model.pt' % args.checkpoint_name)

    if restore_path is not None and os.path.isfile(restore_path):
        logger.info('Restoring from checkpoint {}'.format(restore_path))
        checkpoint = torch.load(restore_path)
        regressor.load_state_dict(checkpoint['r_state'])
        optimizer_r.load_state_dict(checkpoint['r_optim_state'])
        t = checkpoint['counters']['t']
        epoch = checkpoint['counters']['epoch']
        checkpoint['restore_ts'].append(t)
    else:
        # Starting from scratch, so initialize checkpoint data structure
        t, epoch = 0, 0
        checkpoint = {
            'args': args.__dict__,
            'R_losses': defaultdict(list),
            'losses_ts': [],
            'metrics_val': defaultdict(list),
            'metrics_train': defaultdict(list),
            'sample_ts': [],
            'restore_ts': [],
            'norm_r': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'r_state': None,
            'r_optim_state': None,
            'r_best_state': None,
            'best_t': None,
            'r_best_nl_state': None,
            'best_t_nl': None,
        }
    t0 = None
    while t < args.num_iterations:
        gc.collect()
        epoch += 1
        logger.info('Starting epoch {}'.format(epoch))
        for batch in train_loader:
            if args.timing == 1:
                torch.cuda.synchronize()
                t1 = time.time()

            # Every batch is used for a regressor step; this variant has no
            # discriminator/generator alternation.

            step_type = 'r'
            losses_r = regressor_step(args, batch, regressor, optimizer_r)
            checkpoint['norm_r'].append(get_total_norm(regressor.parameters()))

            if args.timing == 1:
                torch.cuda.synchronize()
                t2 = time.time()
                logger.info('{} step took {}'.format(step_type, t2 - t1))

            if args.timing == 1:
                if t0 is not None:
                    logger.info('Iteration {} took {}'.format(
                        t - 1,
                        time.time() - t0))
                t0 = time.time()

            # Maybe save loss
            if t % args.print_every == 0:
                logger.info('t = {} / {}'.format(t + 1, args.num_iterations))
                for k, v in sorted(losses_r.items()):
                    logger.info('  [R] {}: {:.3f}'.format(k, v))
                    checkpoint['R_losses'][k].append(v)
                checkpoint['losses_ts'].append(t)

                ## log scalars
                for k, v in sorted(losses_r.items()):
                    writer.add_scalar("loss/{}".format(k), v, t)

            # Maybe save a checkpoint
            if t > 0 and t % args.checkpoint_every == 0:
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint['sample_ts'].append(t)

                # Check stats on the validation set
                logger.info('Checking stats on val ...')
                metrics_val = check_accuracy(args, val_loader, regressor)
                logger.info('Checking stats on train ...')
                metrics_train = check_accuracy(args,
                                               train_loader,
                                               regressor,
                                               limit=True)

                for k, v in sorted(metrics_val.items()):
                    logger.info('  [val] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_val'][k].append(v)
                for k, v in sorted(metrics_train.items()):
                    logger.info('  [train] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_train'][k].append(v)

                ## log scalars
                for k, v in sorted(metrics_val.items()):
                    writer.add_scalar("val/{}".format(k), v, t)
                for k, v in sorted(metrics_train.items()):
                    writer.add_scalar("train/{}".format(k), v, t)

                min_ade = min(checkpoint['metrics_val']['ade'])
                min_ade_nl = min(checkpoint['metrics_val']['ade_nl'])

                if metrics_val['ade'] == min_ade:
                    logger.info('New low for avg_disp_error')
                    checkpoint['best_t'] = t
                    checkpoint['r_best_state'] = regressor.state_dict()

                if metrics_val['ade_nl'] == min_ade_nl:
                    logger.info('New low for avg_disp_error_nl')
                    checkpoint['best_t_nl'] = t
                    checkpoint['r_best_nl_state'] = regressor.state_dict()

                # Save another checkpoint with model weights and
                # optimizer state
                checkpoint['r_state'] = regressor.state_dict()
                checkpoint['r_optim_state'] = optimizer_r.state_dict()

                # checkpoint_path = os.path.join(
                #     args.output_dir, '{}_with_model_{:06d}.pt'.format(args.checkpoint_name,t)
                # )
                checkpoint_path = os.path.join(
                    args.output_dir,
                    '{}_with_model.pt'.format(args.checkpoint_name))
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                torch.save(checkpoint, checkpoint_path)
                logger.info('Done.')

                # Save a checkpoint with no model weights by making a shallow
                # copy of the checkpoint excluding some items

                # checkpoint_path = os.path.join(
                #     args.output_dir, '{}_no_model_{:06d}.pt' .format(args.checkpoint_name,t))

                checkpoint_path = os.path.join(
                    args.output_dir,
                    '{}_no_model.pt'.format(args.checkpoint_name))
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                key_blacklist = [
                    'r_state',
                    'r_best_state',
                    'r_best_nl_state',
                    'r_optim_state',
                ]
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
                logger.info('Done.')

            t += 1
            if t >= args.num_iterations:
                break
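
Example #6 logs scalars through a `writer` that is assumed to be a module-level torch.utils.tensorboard SummaryWriter (it is never created inside this `main`). A minimal standalone sketch of that logging call, with illustrative tags and values:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/regressor_demo')   # illustrative log directory
for t, loss in enumerate([1.0, 0.7, 0.5]):
    writer.add_scalar('loss/r_total', loss, t)  # tag, scalar value, global step
writer.flush()
writer.close()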
Example #7
def main(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    train_path = get_dset_path(args.dataset_name, 'train')
    val_path = get_dset_path(args.dataset_name, 'val')

    long_dtype, float_dtype = get_dtypes(args)

    logger.info("Initializing train dataset")
    train_dset, train_loader = data_loader(args, train_path)
    logger.info("Initializing val dataset")
    _, val_loader = data_loader(args, val_path)

    iterations_per_epoch = len(train_dset) / args.batch_size / args.d_steps
    # if args.num_epochs:
    #     args.num_iterations = int(iterations_per_epoch * args.num_epochs)

    logger.info(
        'There are {} iterations per epoch'.format(iterations_per_epoch))

    global TrajectoryGenerator, TrajectoryDiscriminator
    if args.GAN_type == 'rnn':
        print("Default Social GAN")
        from sgan.models import TrajectoryGenerator, TrajectoryDiscriminator
    elif args.GAN_type == 'simple_rnn':
        print("Default Social GAN")
        from sgan.rnn_models import TrajectoryGenerator, TrajectoryDiscriminator
    else:
        print("Feedforward GAN")
        from sgan.ffd_models import TrajectoryGenerator, TrajectoryDiscriminator

    if args.GAN_type == 'ff':
        generator = TrajectoryGenerator(
            obs_len=args.obs_len,
            pred_len=args.pred_len,
            embedding_dim=args.embedding_dim,
            encoder_h_dim=args.encoder_h_dim_g,
            decoder_h_dim=args.decoder_h_dim_g,
            rep_dim=args.rep_dim,
            mlp_dim=args.mlp_dim,
            encoder_num_layers=args.encoder_num_layers,
            decoder_num_layers=args.decoder_num_layers,
            noise_dim=args.noise_dim,
            noise_type=args.noise_type,
            noise_mix_type=args.noise_mix_type,
            pooling_type=args.pooling_type,
            pool_every_timestep=args.pool_every_timestep,
            dropout=args.dropout,
            bottleneck_dim=args.bottleneck_dim,
            neighborhood_size=args.neighborhood_size,
            grid_size=args.grid_size,
            batch_norm=args.batch_norm,
            pos_embed=args.pos_embed,
            pos_embed_freq=args.pos_embed_freq,
        )

        discriminator = TrajectoryDiscriminator(
            obs_len=args.obs_len,
            pred_len=args.pred_len,
            embedding_dim=args.embedding_dim,
            h_dim=args.encoder_h_dim_d,
            mlp_dim=args.mlp_dim,
            num_layers=args.discrim_num_layers,
            dropout=args.dropout,
            batch_norm=args.batch_norm,
            d_type=args.d_type)

    else:
        generator = TrajectoryGenerator(
            obs_len=args.obs_len,
            pred_len=args.pred_len,
            embedding_dim=args.embedding_dim,
            encoder_h_dim=args.encoder_h_dim_g,
            decoder_h_dim=args.decoder_h_dim_g,
            mlp_dim=args.mlp_dim,
            num_layers=args.num_layers,
            noise_dim=args.noise_dim,
            noise_type=args.noise_type,
            noise_mix_type=args.noise_mix_type,
            pooling_type=args.pooling_type,
            pool_every_timestep=args.pool_every_timestep,
            dropout=args.dropout,
            bottleneck_dim=args.bottleneck_dim,
            neighborhood_size=args.neighborhood_size,
            grid_size=args.grid_size,
            batch_norm=args.batch_norm)

        discriminator = TrajectoryDiscriminator(
            obs_len=args.obs_len,
            pred_len=args.pred_len,
            embedding_dim=args.embedding_dim,
            h_dim=args.encoder_h_dim_d,
            mlp_dim=args.mlp_dim,
            num_layers=args.num_layers,
            dropout=args.dropout,
            batch_norm=args.batch_norm,
            d_type=args.d_type)

    generator.apply(init_weights)
    generator.type(float_dtype).train()
    logger.info('Here is the generator:')
    logger.info(generator)

    discriminator.apply(init_weights)
    discriminator.type(float_dtype).train()
    logger.info('Here is the discriminator:')
    logger.info(discriminator)

    g_loss_fn = gan_g_loss
    d_loss_fn = gan_d_loss

    optimizer_g = optim.Adam(generator.parameters(), lr=args.g_learning_rate)
    optimizer_d = optim.Adam(discriminator.parameters(),
                             lr=args.d_learning_rate)

    # Maybe restore from checkpoint
    restore_path = None
    if args.checkpoint_start_from is not None:
        restore_path = args.checkpoint_start_from
    elif args.restore_from_checkpoint == 1:
        restore_path = os.path.join(args.output_dir,
                                    '%s_with_model.pt' % args.checkpoint_name)

    if restore_path is not None and os.path.isfile(restore_path):
        logger.info('Restoring from checkpoint {}'.format(restore_path))
        checkpoint = torch.load(restore_path)
        generator.load_state_dict(checkpoint['g_state'])
        discriminator.load_state_dict(checkpoint['d_state'])
        optimizer_g.load_state_dict(checkpoint['g_optim_state'])
        optimizer_d.load_state_dict(checkpoint['d_optim_state'])
        t = checkpoint['counters']['t']
        epoch = checkpoint['counters']['epoch']
        checkpoint['restore_ts'].append(t)
    else:
        # Starting from scratch, so initialize checkpoint data structure
        t, epoch = 0, 0
        checkpoint = {
            'args': args.__dict__,
            'G_losses': defaultdict(list),
            'D_losses': defaultdict(list),
            'losses_ts': [],
            'metrics_val': defaultdict(list),
            'metrics_train': defaultdict(list),
            'sample_ts': [],
            'restore_ts': [],
            'norm_g': [],
            'norm_d': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'g_state': None,
            'g_optim_state': None,
            'd_state': None,
            'd_optim_state': None,
            'g_best_state': None,
            'd_best_state': None,
            'best_t': None,
            'g_best_nl_state': None,
            'd_best_nl_state': None,
            'best_t_nl': None,
        }
    t0 = None
    fig = plt.figure()
    ax = fig.add_axes([0.1, 0.1, 0.75, 0.75])

    while t < args.num_iterations:
        gc.collect()
        d_steps_left = args.d_steps
        g_steps_left = args.g_steps
        epoch += 1
        logger.info('Starting epoch {}'.format(epoch))
        for batch in train_loader:

            if args.timing == 1:
                torch.cuda.synchronize()
                t1 = time.time()

            # Decide whether to use the batch for stepping on discriminator or
            # generator; an iteration consists of args.d_steps steps on the
            # discriminator followed by args.g_steps steps on the generator.
            if d_steps_left > 0:
                step_type = 'd'
                losses_d = discriminator_step(args, batch, generator,
                                              discriminator, d_loss_fn,
                                              optimizer_d)
                checkpoint['norm_d'].append(
                    get_total_norm(discriminator.parameters()))
                d_steps_left -= 1
            elif g_steps_left > 0:
                step_type = 'g'
                losses_g = generator_step(args, batch, generator,
                                          discriminator, g_loss_fn,
                                          optimizer_g)
                checkpoint['norm_g'].append(
                    get_total_norm(generator.parameters()))
                g_steps_left -= 1

            if args.timing == 1:
                torch.cuda.synchronize()
                t2 = time.time()
                logger.info('{} step took {}'.format(step_type, t2 - t1))

            # Skip the rest if we are not at the end of an iteration
            if d_steps_left > 0 or g_steps_left > 0:
                continue

            if args.timing == 1:
                if t0 is not None:
                    logger.info('Iteration {} took {}'.format(
                        t - 1,
                        time.time() - t0))
                t0 = time.time()

            # Maybe save loss
            if t % args.print_every == 0:
                logger.info('t = {} / {}'.format(t + 1, args.num_iterations))
                for k, v in sorted(losses_d.items()):
                    logger.info('  [D] {}: {:.3f}'.format(k, v))
                    checkpoint['D_losses'][k].append(v)
                for k, v in sorted(losses_g.items()):
                    logger.info('  [G] {}: {:.3f}'.format(k, v))
                    checkpoint['G_losses'][k].append(v)
                checkpoint['losses_ts'].append(t)

            if args.controlled_expt:
                if t % 10 == 0:
                    save = False
                    # if t == 160:
                    # save = True
                    # print(t)
                    plot_trajectory(fig, ax, args, val_loader, generator, save)

            # Maybe save a checkpoint
            if t > 0 and t % args.checkpoint_every == 0:
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint['sample_ts'].append(t)

                # Check stats on the validation set
                logger.info('Checking stats on val ...')
                metrics_val = check_accuracy(args, val_loader, generator,
                                             discriminator, d_loss_fn)
                logger.info('Checking stats on train ...')
                metrics_train = check_accuracy(args,
                                               train_loader,
                                               generator,
                                               discriminator,
                                               d_loss_fn,
                                               limit=True)

                for k, v in sorted(metrics_val.items()):
                    logger.info('  [val] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_val'][k].append(v)
                for k, v in sorted(metrics_train.items()):
                    logger.info('  [train] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_train'][k].append(v)

                min_ade = min(checkpoint['metrics_val']['ade'])
                min_ade_nl = min(checkpoint['metrics_val']['ade_nl'])

                if metrics_val['ade'] == min_ade:
                    logger.info('New low for avg_disp_error')
                    checkpoint['best_t'] = t
                    checkpoint['g_best_state'] = generator.state_dict()
                    checkpoint['d_best_state'] = discriminator.state_dict()

                if metrics_val['ade_nl'] == min_ade_nl:
                    logger.info('New low for avg_disp_error_nl')
                    checkpoint['best_t_nl'] = t
                    checkpoint['g_best_nl_state'] = generator.state_dict()
                    checkpoint['d_best_nl_state'] = discriminator.state_dict()

                if metrics_val['ade'] == min_ade:
                    # Save another checkpoint with model weights and
                    # optimizer state
                    checkpoint['g_state'] = generator.state_dict()
                    checkpoint['g_optim_state'] = optimizer_g.state_dict()
                    checkpoint['d_state'] = discriminator.state_dict()
                    checkpoint['d_optim_state'] = optimizer_d.state_dict()
                    checkpoint_path = os.path.join(
                        args.output_dir,
                        '%s_with_model.pt' % args.checkpoint_name)
                    logger.info(
                        'Saving checkpoint to {}'.format(checkpoint_path))
                    torch.save(checkpoint, checkpoint_path)
                    logger.info('Done.')

                    # Save a checkpoint with no model weights by making a shallow
                    # copy of the checkpoint excluding some items
                    checkpoint_path = os.path.join(
                        args.output_dir,
                        '%s_no_model.pt' % args.checkpoint_name)
                    logger.info(
                        'Saving checkpoint to {}'.format(checkpoint_path))
                    key_blacklist = [
                        'g_state', 'd_state', 'g_best_state',
                        'g_best_nl_state', 'g_optim_state', 'd_optim_state',
                        'd_best_state', 'd_best_nl_state'
                    ]
                    small_checkpoint = {}
                    for k, v in checkpoint.items():
                        if k not in key_blacklist:
                            small_checkpoint[k] = v
                    torch.save(small_checkpoint, checkpoint_path)
                    logger.info('Done.')

            t += 1
            d_steps_left = args.d_steps
            g_steps_left = args.g_steps
            if t >= args.num_iterations:
                break
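
The inner loop above schedules optimizer updates so that one "iteration" consumes args.d_steps batches for the discriminator and then args.g_steps batches for the generator before the counter t advances. A stripped-down sketch of just that scheduling, with dummy batches and no models:

def count_iterations(batches, d_steps=2, g_steps=1):
    d_left, g_left, t = d_steps, g_steps, 0
    for _ in batches:
        if d_left > 0:
            d_left -= 1          # a discriminator update would happen here
        elif g_left > 0:
            g_left -= 1          # a generator update would happen here
        if d_left > 0 or g_left > 0:
            continue             # iteration not yet complete
        t += 1                   # one full iteration done
        d_left, g_left = d_steps, g_steps
    return t

print(count_iterations(range(9)))  # 9 batches / (2 + 1) batches per iteration -> 3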
Example #8
def main(args):
    logdir = "tensorboard/" + args.dataset_name + "/" + str(
        args.num_epochs) + "_epoch_" + str(args.g_learning_rate) + "_lr"
    if os.path.exists(logdir):
        for file in os.listdir(logdir):
            os.unlink(os.path.join(logdir, file))

    writer = SummaryWriter(logdir)

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    train_path = get_dset_path(args.dataset_name, 'train')
    val_path = get_dset_path(args.dataset_name, 'val')
    checkpoint_path = os.path.join(args.output_dir,
                                   '%s_with_model.pt' % args.checkpoint_name)
    pathlib.Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)

    long_dtype, float_dtype = get_dtypes(args)

    if args.moving_threshold:
        generator = TrajEstimatorThreshold(obs_len=args.obs_len,
                                           pred_len=args.pred_len,
                                           embedding_dim=args.embedding_dim,
                                           encoder_h_dim=args.encoder_h_dim_g,
                                           num_layers=args.num_layers,
                                           dropout=args.dropout)
    else:
        generator = TrajEstimator(obs_len=args.obs_len,
                                  pred_len=args.pred_len,
                                  embedding_dim=args.embedding_dim,
                                  encoder_h_dim=args.encoder_h_dim_g,
                                  num_layers=args.num_layers,
                                  dropout=args.dropout)

    generator.apply(init_weights)
    generator.type(float_dtype).train()
    generator.train()

    logger.info('Here is the generator:')
    logger.info(generator)

    logger.info("Initializing train dataset")
    train_dset, train_loader = data_loader(args, train_path)
    logger.info("Initializing val dataset")
    _, val_loader = data_loader(args, val_path)

    iterations_per_epoch = len(train_dset) / args.batch_size
    args.num_iterations = int(iterations_per_epoch * args.num_epochs)
    # log 100 points
    log_tensorboard_every = int(args.num_iterations * 0.01) - 1
    if log_tensorboard_every <= 0:
        # fewer than 100 iterations; log roughly four times overall instead
        log_tensorboard_every = max(1, args.num_iterations // 4)

    logger.info('There are {} iterations per epoch'.format(
        int(iterations_per_epoch)))

    optimizer_g = optim.Adam(generator.parameters(), lr=args.g_learning_rate)

    # Maybe restore from checkpoint
    restore_path = None
    if args.checkpoint_start_from is not None:
        restore_path = args.checkpoint_start_from
    elif args.restore_from_checkpoint == 1:
        restore_path = os.path.join(args.output_dir,
                                    '%s_with_model.pt' % args.checkpoint_name)

    if restore_path is not None and os.path.isfile(restore_path):
        logger.info('Restoring from checkpoint {}'.format(restore_path))
        checkpoint = torch.load(restore_path)
        generator.load_state_dict(checkpoint['g_state'])
        optimizer_g.load_state_dict(checkpoint['g_optim_state'])
        t = checkpoint['counters']['t']
        epoch = checkpoint['counters']['epoch']
        checkpoint['restore_ts'].append(t)
    else:
        # Starting from scratch, so initialize checkpoint data structure
        t, epoch = 0, 0
        checkpoint = {
            'args': args.__dict__,
            'G_losses': defaultdict(list),
            'losses_ts': [],
            'metrics_val': defaultdict(list),
            'metrics_train': defaultdict(list),
            'sample_ts': [],
            'restore_ts': [],
            'norm_g': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'g_state': None,
            'g_optim_state': None,
            'g_best_state': None,
            'best_t': None,
            'g_best_nl_state': None,
            'best_t_nl': None,
        }
    while t < args.num_iterations:
        gc.collect()

        epoch += 1
        logger.info('Starting epoch {}'.format(epoch))
        for batch in train_loader:
            losses_g = generator_step(args, batch, generator, optimizer_g,
                                      epoch)
            checkpoint['norm_g'].append(get_total_norm(generator.parameters()))

            # Maybe save loss
            if t % args.print_every == 0:
                logger.info('t = {} / {}'.format(t + 1, args.num_iterations))
                for k, v in sorted(losses_g.items()):
                    logger.info('  [G] {}: {:.3f}'.format(k, v))
                    checkpoint['G_losses'][k].append(v)
                checkpoint['losses_ts'].append(t)

            # Maybe save values for tensorboard
            if t % log_tensorboard_every == 0:
                for k, v in sorted(losses_g.items()):
                    writer.add_scalar(k, v, t)

                metrics_val = check_accuracy(args, val_loader, generator,
                                             epoch)
                metrics_train = check_accuracy(args,
                                               train_loader,
                                               generator,
                                               epoch,
                                               limit=True)
                to_keep = ["g_l2_loss_rel", "ade", "fde"]
                for k, v in sorted(metrics_val.items()):
                    if k in to_keep:
                        writer.add_scalar("val_" + k, v, t)

                for k, v in sorted(metrics_train.items()):
                    if k in to_keep:
                        writer.add_scalar("train_" + k, v, t)

            # Maybe save a checkpoint
            if t % args.checkpoint_every == 0 and t > 0:
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint['sample_ts'].append(t)

                # Check stats on the validation set
                logger.info('Checking stats on val ...')
                metrics_val = check_accuracy(args, val_loader, generator,
                                             epoch)
                logger.info('Checking stats on train ...')
                metrics_train = check_accuracy(args,
                                               train_loader,
                                               generator,
                                               epoch,
                                               limit=True)

                for k, v in sorted(metrics_val.items()):
                    logger.info('  [val] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_val'][k].append(v)
                for k, v in sorted(metrics_train.items()):
                    logger.info('  [train] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_train'][k].append(v)

                min_ade = min(checkpoint['metrics_val']['ade'])
                min_ade_nl = min(checkpoint['metrics_val']['ade_nl'])

                if metrics_val['ade'] == min_ade:
                    logger.info('New low for avg_disp_error')
                    checkpoint['best_t'] = t
                    checkpoint['g_best_state'] = generator.state_dict()

                if metrics_val['ade_nl'] == min_ade_nl:
                    logger.info('New low for avg_disp_error_nl')
                    checkpoint['best_t_nl'] = t
                    checkpoint['g_best_nl_state'] = generator.state_dict()

                # Save another checkpoint with model weights and
                # optimizer state
                checkpoint['g_state'] = generator.state_dict()
                checkpoint['g_optim_state'] = optimizer_g.state_dict()
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                torch.save(checkpoint, checkpoint_path)
                logger.info('Done.')

            t += 1
            if t >= args.num_iterations:
                if args.moving_threshold:
                    logger.info(
                        "Non-moving trajectories : {}%, threshold : {}".format(
                            round(
                                (float(generator.total_trajs_under_threshold) /
                                 generator.total_trajs) * 100),
                            generator.threshold))
                break
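
The checkpoint blocks above keep running "best" snapshots by appending the latest validation ADE and comparing it against the minimum seen so far. The same bookkeeping in isolation, with made-up metric values:

from collections import defaultdict

metrics_val = defaultdict(list)
best_t = None
for t, ade in [(300, 0.92), (600, 0.71), (900, 0.78)]:
    metrics_val['ade'].append(ade)
    if ade == min(metrics_val['ade']):
        best_t = t  # the real script also snapshots generator.state_dict() here
print(best_t)       # 600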
Example #9
def main(args):
    print(args)
    args.checkpoint_name = args.checkpoint_name + "_" + random_str
    tensorboard_name = "_".join(
        [args.checkpoint_name, args.dataset_name, time_str, hostname])
    dataset_runs_dir = os.path.join(args.output_dir, "runs", args.dataset_name)
    dataset_ckpt_dir = os.path.join(args.output_dir, "checkpoints",
                                    args.dataset_name)
    # create parent directories too; a plain os.mkdir fails if args.output_dir
    # does not exist yet
    os.makedirs(dataset_runs_dir, exist_ok=True)
    os.makedirs(dataset_ckpt_dir, exist_ok=True)

    tensorboard_path = os.path.join(dataset_runs_dir, tensorboard_name)
    writer = SummaryWriter(tensorboard_path)
    log_path = "{}/config.txt".format(tensorboard_path)
    with open(log_path, "a") as f:
        json.dump(args.__dict__, f, indent=2)
    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)
    schema_path = "../sgan/data/configs/{}.json".format(args.schema)
    schema = load_schema(schema_path)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    # train_path = get_dset_path(args.dataset_name, 'train')
    # val_path = get_dset_path(args.dataset_name, 'val')
    train_path = os.path.join(args.dataset_dir, args.dataset_name,
                              'train_sample')  # 10 files:0-9
    val_path = os.path.join(args.dataset_dir, args.dataset_name,
                            'val_sample')  # 5 files: 10-14

    long_dtype, float_dtype = get_dtypes(args)

    logger.info("Initializing train dataset")
    train_dset, train_loader = data_loader(args, train_path, schema)
    logger.info("Initializing val dataset")
    _, val_loader = data_loader(args, val_path, schema)

    iterations_per_epoch = len(train_dset) / args.batch_size / args.d_steps
    if args.num_epochs:
        args.num_iterations = int(iterations_per_epoch * args.num_epochs)

    logger.info(
        'There are {} iterations per epoch'.format(iterations_per_epoch))

    generator, discriminator = build_models(args, schema, args.model)

    generator.type(float_dtype).train()
    logger.info('Here is the generator:')
    logger.info(generator)

    discriminator.type(float_dtype).train()
    logger.info('Here is the discriminator:')
    logger.info(discriminator)

    g_loss_fn = gan_g_loss
    d_loss_fn = gan_d_loss

    optimizer_g, optimizer_d = build_optimizers(args, generator, discriminator)
    scheduler_g, scheduler_d = build_schedulers(args, optimizer_g, optimizer_d)

    # Maybe restore from checkpoint
    restore_path = None
    if args.checkpoint_start_from is not None:
        restore_path = args.checkpoint_start_from
    elif args.restore_from_checkpoint == 1:
        restore_path = os.path.join(dataset_ckpt_dir,
                                    '%s_with_model.pt' % args.checkpoint_name)

    if restore_path is not None and os.path.isfile(restore_path):
        logger.info('Restoring from checkpoint {}'.format(restore_path))
        checkpoint = torch.load(restore_path)
        generator, discriminator, optimizer_g, optimizer_d, t, epoch = \
            restore_from_checkpoint(checkpoint, generator, discriminator, optimizer_g, optimizer_d)
        checkpoint['restore_ts'].append(t)
    else:
        # Starting from scratch, so initialize checkpoint data structure
        t, epoch = 0, 0
        checkpoint = initialize_checkpoint(args)
    t0 = None
    while t < args.num_iterations:
        gc.collect()
        d_steps_left = args.d_steps
        g_steps_left = args.g_steps
        epoch += 1
        logger.info('Starting epoch {}'.format(epoch))
        scheduler_g.step()
        scheduler_d.step()
        for batch in train_loader:
            if args.timing == 1:
                torch.cuda.synchronize()
                t1 = time.time()

            # Decide whether to use the batch for stepping on discriminator or
            # generator; an iteration consists of args.d_steps steps on the
            # discriminator followed by args.g_steps steps on the generator.
            if d_steps_left > 0:
                step_type = 'd'
                losses_d = discriminator_step(args, batch, generator,
                                              discriminator, d_loss_fn,
                                              optimizer_d)
                checkpoint['norm_d'].append(
                    get_total_norm(discriminator.parameters()))
                d_steps_left -= 1
            elif g_steps_left > 0:
                step_type = 'g'
                losses_g = generator_step(args, batch, generator,
                                          discriminator, g_loss_fn,
                                          optimizer_g)
                checkpoint['norm_g'].append(
                    get_total_norm(generator.parameters()))
                g_steps_left -= 1

            if args.timing == 1:
                torch.cuda.synchronize()
                t2 = time.time()
                logger.info('{} step took {}'.format(step_type, t2 - t1))

            # Skip the rest if we are not at the end of an iteration
            if d_steps_left > 0 or g_steps_left > 0:
                continue

            if args.timing == 1:
                if t0 is not None:
                    logger.info('Iteration {} took {}'.format(
                        t - 1,
                        time.time() - t0))
                t0 = time.time()

            # Maybe save loss
            if t % args.print_every == 0:
                logger.info('t = {} / {}'.format(t + 1, args.num_iterations))
                for k, v in sorted(losses_d.items()):
                    # logger.info('  [D] {}: {:.3f}'.format(k, v))
                    checkpoint['D_losses'][k].append(v)
                for k, v in sorted(losses_g.items()):
                    # logger.info('  [G] {}: {:.3f}'.format(k, v))
                    checkpoint['G_losses'][k].append(v)
                checkpoint['losses_ts'].append(t)

                ## log scalars
                for k, v in sorted(losses_d.items()):
                    writer.add_scalar("loss/{}".format(k), v, t)
                for k, v in sorted(losses_g.items()):
                    writer.add_scalar("loss/{}".format(k), v, t)

            # Maybe save a checkpoint
            if t > 0 and t % args.checkpoint_every == 0:
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint['sample_ts'].append(t)

                # Check stats on the validation set
                logger.info('Checking stats on val ...')
                metrics_val = check_accuracy(args, val_loader, generator,
                                             discriminator, d_loss_fn)
                logger.info('Checking stats on train ...')
                metrics_train = check_accuracy(args,
                                               train_loader,
                                               generator,
                                               discriminator,
                                               d_loss_fn,
                                               limit=True)

                for k, v in sorted(metrics_val.items()):
                    # logger.info('  [val] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_val'][k].append(v)
                for k, v in sorted(metrics_train.items()):
                    # logger.info('  [train] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_train'][k].append(v)

                ## log scalars
                for k, v in sorted(metrics_val.items()):
                    writer.add_scalar("val/{}".format(k), v, t)
                for k, v in sorted(metrics_train.items()):
                    writer.add_scalar("train/{}".format(k), v, t)

                min_ade = min(checkpoint['metrics_val']['ade'])
                min_ade_nl = min(checkpoint['metrics_val']['ade_nl'])

                if metrics_val['ade'] == min_ade:
                    logger.info('New low for avg_disp_error')
                    checkpoint['best_t'] = t
                    checkpoint['g_best_state'] = generator.state_dict()
                    checkpoint['d_best_state'] = discriminator.state_dict()

                if metrics_val['ade_nl'] == min_ade_nl:
                    logger.info('New low for avg_disp_error_nl')
                    checkpoint['best_t_nl'] = t
                    checkpoint['g_best_nl_state'] = generator.state_dict()
                    checkpoint['d_best_nl_state'] = discriminator.state_dict()

                # Save another checkpoint with model weights and
                # optimizer state
                checkpoint['g_state'] = generator.state_dict()
                checkpoint['g_optim_state'] = optimizer_g.state_dict()
                checkpoint['d_state'] = discriminator.state_dict()
                checkpoint['d_optim_state'] = optimizer_d.state_dict()
                # checkpoint_path = os.path.join(
                #     args.output_dir, '{}_with_model_{:06d}.pt'.format(args.checkpoint_name,t)
                # )
                checkpoint_path = os.path.join(
                    dataset_ckpt_dir,
                    '{}_with_model.pt'.format(args.checkpoint_name))
                backup_checkpoint_path = os.path.join(
                    dataset_ckpt_dir,
                    '{}_with_model_backup.pt'.format(args.checkpoint_name))
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                torch.save(checkpoint, checkpoint_path)
                torch.save(checkpoint, backup_checkpoint_path)
                logger.info('Done.')

                # Save a checkpoint with no model weights by making a shallow
                # copy of the checkpoint excluding some items

                # checkpoint_path = os.path.join(
                #     args.output_dir, '{}_no_model_{:06d}.pt' .format(args.checkpoint_name,t))

                checkpoint_path = os.path.join(
                    dataset_ckpt_dir,
                    '{}_no_model.pt'.format(args.checkpoint_name))
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                key_blacklist = [
                    'g_state', 'd_state', 'g_best_state', 'g_best_nl_state',
                    'g_optim_state', 'd_optim_state', 'd_best_state',
                    'd_best_nl_state'
                ]
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
                logger.info('Done.')

            t += 1
            d_steps_left = args.d_steps
            g_steps_left = args.g_steps
            if t >= args.num_iterations:
                break

        ## log scalars
        # for k, v in sorted(losses_d.items()):
        #     writer.add_scalar("train/{}".format(k),v,epoch)
        # for k, v in sorted(losses_g.items()):
        #     writer.add_scalar("train/{}".format(k),v,epoch)
    writer.flush()
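
Example #9 steps its learning-rate schedulers once per epoch via whatever build_schedulers returns; the sketch below uses torch.optim.lr_scheduler.StepLR as an assumed stand-in. Note that PyTorch 1.1+ expects optimizer.step() to run before scheduler.step(), so the sketch steps the scheduler at the end of each epoch rather than at the start as the example does:

import torch
from torch import nn, optim

model = nn.Linear(2, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

for epoch in range(10):
    # ... the inner batch loop would call optimizer.step() per batch ...
    optimizer.step()       # placeholder for the real parameter updates
    scheduler.step()       # decay the learning rate once per epoch
    print(epoch, optimizer.param_groups[0]['lr'])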
def main(args):
    if args.mode == 'training':
        args.checkpoint_every = 100
        args.teacher_name = "default"
        args.restore_from_checkpoint = 0
        #args.l2_loss_weight = 0.0
        args.rollout_steps = 1
        args.rollout_rate = 1
        args.rollout_method = 'sgd'
        #print("HHHH"+str(args.l2_loss_weight))

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    train_path = get_dset_path(args.dataset_name, 'train')
    val_path = get_dset_path(args.dataset_name, 'val')

    long_dtype, float_dtype = get_dtypes(args)

    logger.info("Initializing train dataset")
    train_dset, train_loader = data_loader(args, train_path)
    logger.info("Initializing val dataset")
    _, val_loader = data_loader(args, val_path)

    iterations_per_epoch = len(train_dset) / args.batch_size / args.d_steps
    if args.num_epochs:
        args.num_iterations = int(iterations_per_epoch * args.num_epochs)

    logger.info(
        'There are {} iterations per epoch'.format(iterations_per_epoch))

    global TrajectoryGenerator, TrajectoryDiscriminator
    if args.GAN_type == 'rnn':
        print("Default Social GAN")
        from sgan.models import TrajectoryGenerator, TrajectoryDiscriminator
    elif args.GAN_type == 'simple_rnn':
        print("Default Social GAN")
        from sgan.rnn_models import TrajectoryGenerator, TrajectoryDiscriminator
    else:
        print("Feedforward GAN")
        if (args.Encoder_type == 'MLP' and args.Decoder_type == 'MLP'):
            from sgan.cgs_integrated_model.cgs_ffd_models_E_MLP_D_MLP import TrajectoryGenerator, TrajectoryDiscriminator
        if (args.Encoder_type == 'MLP' and args.Decoder_type == 'LSTM'):
            from sgan.cgs_integrated_model.cgs_ffd_models_E_MLP_D_LSTM import TrajectoryGenerator, TrajectoryDiscriminator
        if (args.Encoder_type == 'LSTM' and args.Decoder_type == 'MLP'):
            from sgan.cgs_integrated_model.cgs_ffd_models_E_LSTM_D_MLP import TrajectoryGenerator, TrajectoryDiscriminator
        if (args.Encoder_type == 'LSTM' and args.Decoder_type == 'LSTM'):
            from sgan.cgs_integrated_model.cgs_ffd_models_E_LSTM_D_LSTM import TrajectoryGenerator, TrajectoryDiscriminator

    #image_dir = 'images/' + 'curve_5_traj_l2_0.5'
    #image_dir = 'images/5trajectory/' + 'havingplots'+ '2-layers-EN-' + args.Encoder_type +  '-DE-20-layers-' + args.Decoder_type + '-L2_' + str(args.l2_loss_weight)

    image_dir = 'images/' + str(args.dataset_name) + \
                '_EN_' + args.Encoder_type + '(' + str(*[args.mlp_encoder_layers if args.Encoder_type == 'MLP' else 1]) + ')' + \
                '_DE_' + args.Decoder_type + '(' + str(*[args.mlp_decoder_layers if args.Decoder_type == 'MLP' else 1]) + ')' + \
                '_DIS_' + args.GAN_type.upper() + '(' + str(args.mlp_discriminator_layers) + ')' + \
                '_L2_Weight' + '(' + str(args.l2_loss_weight) + ')'

    print("Image Dir: ", image_dir)
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)

    generator = TrajectoryGenerator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        encoder_h_dim=args.encoder_h_dim_g,
        decoder_h_dim=args.decoder_h_dim_g,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        noise_dim=args.noise_dim,
        noise_type=args.noise_type,
        noise_mix_type=args.noise_mix_type,
        pooling_type=args.pooling_type,
        pool_every_timestep=args.pool_every_timestep,
        dropout=args.dropout,
        bottleneck_dim=args.bottleneck_dim,
        neighborhood_size=args.neighborhood_size,
        grid_size=args.grid_size,
        batch_norm=args.batch_norm,
        num_mlp_decoder_layers=args.mlp_decoder_layers,
        num_mlp_encoder_layers=args.mlp_encoder_layers)

    generator.apply(init_weights)
    generator.type(float_dtype).train()
    logger.info('Here is the generator:')
    logger.info(generator)

    discriminator = TrajectoryDiscriminator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        h_dim=args.encoder_h_dim_d,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        dropout=args.dropout,
        batch_norm=args.batch_norm,
        d_type=args.d_type,
        mlp_discriminator_layers=args.mlp_discriminator_layers,
        num_mlp_encoder_layers=args.mlp_encoder_layers)

    discriminator.apply(init_weights)
    discriminator.type(float_dtype).train()
    logger.info('Here is the discriminator:')
    logger.info(discriminator)

    # build teacher
    print("[!] teacher_name: ", args.teacher_name)

    if args.teacher_name == 'default':
        teacher = None
    elif args.teacher_name == 'gpurollout':
        from teacher_gpu_rollout_torch import TeacherGPURollout
        teacher = TeacherGPURollout(args)
        teacher.set_env(discriminator, generator)
        print("GPU Rollout Teacher")
    else:
        raise NotImplementedError

    g_loss_fn = gan_g_loss
    d_loss_fn = gan_d_loss

    optimizer_g = optim.Adam(generator.parameters(), lr=args.g_learning_rate)
    optimizer_d = optim.Adam(discriminator.parameters(),
                             lr=args.d_learning_rate)

    # # Create D optimizer.
    # self.d_optim = tf.train.AdamOptimizer(self.disc_LR*config.D_LR, beta1=config.beta1)
    # # Compute the gradients for a list of variables.
    # self.grads_d_and_vars = self.d_optim.compute_gradients(self.d_loss, var_list=self.d_vars)
    # self.grad_default_real = self.d_optim.compute_gradients(self.d_loss_real, var_list=inputs)
    # # Ask the optimizer to apply the capped gradients.
    # self.update_d = self.d_optim.apply_gradients(self.grads_d_and_vars)
    # ## Get Saliency Map - Teacher
    # self.saliency_map = tf.gradients(self.d_loss, self.inputs)[0]

    # ###### G Optimizer ######
    # # Create G optimizer.
    # self.g_optim = tf.train.AdamOptimizer(config.learning_rate*config.G_LR, beta1=config.beta1)

    # # Compute the gradients for a list of variables.
    # ## With respect to Generator Weights - AutoLoss
    # self.grad_default = self.g_optim.compute_gradients(self.g_loss, var_list=[self.G, self.g_vars])
    # ## With Respect to Images given to D - Teacher
    # # self.grad_default = g_optim.compute_gradients(self.g_loss, var_list=)
    # if config.teacher_name == 'default':
    # self.optimal_grad = self.grad_default[0][0]
    # self.optimal_batch = self.G - self.optimal_grad
    # else:
    # self.optimal_grad, self.optimal_batch = self.teacher.build_teacher(self.G, self.D_, self.grad_default[0][0], self.inputs)

    # # Ask the optimizer to apply the manipulated gradients.
    # grads_collected = tf.gradients(self.G, self.g_vars, self.optimal_grad)
    # grads_and_vars_collected = list(zip(grads_collected, self.g_vars))

    # self.g_teach = self.g_optim.apply_gradients(grads_and_vars_collected)

    # Maybe restore from checkpoint
    restore_path = None
    if args.checkpoint_start_from is not None:
        restore_path = args.checkpoint_start_from
    elif args.restore_from_checkpoint == 1:
        restore_path = os.path.join(args.output_dir,
                                    '%s_with_model.pt' % args.checkpoint_name)

    if restore_path is not None and os.path.isfile(restore_path):
        logger.info('Restoring from checkpoint {}'.format(restore_path))
        checkpoint = torch.load(restore_path)
        generator.load_state_dict(checkpoint['g_state'])
        discriminator.load_state_dict(checkpoint['d_state'])
        optimizer_g.load_state_dict(checkpoint['g_optim_state'])
        optimizer_d.load_state_dict(checkpoint['d_optim_state'])
        t = checkpoint['counters']['t']
        epoch = checkpoint['counters']['epoch']
        checkpoint['restore_ts'].append(t)
    else:
        # Starting from scratch, so initialize checkpoint data structure
        t, epoch = 0, 0
        checkpoint = {
            'args': args.__dict__,
            'G_losses': defaultdict(list),
            'D_losses': defaultdict(list),
            'losses_ts': [],
            'metrics_val': defaultdict(list),
            'metrics_train': defaultdict(list),
            'sample_ts': [],
            'restore_ts': [],
            'norm_g': [],
            'norm_d': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'g_state': None,
            'g_optim_state': None,
            'd_state': None,
            'd_optim_state': None,
            'g_best_state': None,
            'd_best_state': None,
            'best_t': None,
            'g_best_nl_state': None,
            'd_best_nl_state': None,
            'best_t_nl': None,
        }
    t0 = None
    fig = plt.figure()
    ax = fig.add_axes([0.1, 0.1, 0.75, 0.75])

    while t < args.num_iterations:
        gc.collect()
        d_steps_left = args.d_steps
        g_steps_left = args.g_steps
        epoch += 1
        logger.info('Starting epoch {}'.format(epoch))
        for batch in train_loader:

            if args.timing == 1:
                torch.cuda.synchronize()
                t1 = time.time()

            # Decide whether to use the batch for stepping on discriminator or
            # generator; an iteration consists of args.d_steps steps on the
            # discriminator followed by args.g_steps steps on the generator.
            if d_steps_left > 0:

                if args.mode != 'testing':
                    step_type = 'd'
                    losses_d = discriminator_step(args, batch, generator,
                                                  discriminator, d_loss_fn,
                                                  optimizer_d, teacher,
                                                  args.mode)
                    checkpoint['norm_d'].append(
                        get_total_norm(discriminator.parameters()))

                d_steps_left -= 1

            elif g_steps_left > 0:

                if args.mode != 'testing':
                    step_type = 'g'
                    losses_g = generator_step(args, batch, generator,
                                              discriminator, g_loss_fn,
                                              optimizer_g, args.mode)
                    checkpoint['norm_g'].append(
                        get_total_norm(generator.parameters()))

                g_steps_left -= 1

            if args.timing == 1:
                torch.cuda.synchronize()
                t2 = time.time()
                logger.info('{} step took {}'.format(step_type, t2 - t1))

            # Skip the rest if we are not at the end of an iteration
            if d_steps_left > 0 or g_steps_left > 0:
                continue

            if args.timing == 1:
                if t0 is not None:
                    logger.info('Iteration {} took {}'.format(
                        t - 1,
                        time.time() - t0))
                t0 = time.time()

            # Maybe save loss
            if t % args.print_every == 0 and args.mode != 'testing':
                logger.info('t = {} / {}'.format(t + 1, args.num_iterations))
                for k, v in sorted(losses_d.items()):
                    logger.info('  [D] {}: {:.3f}'.format(k, v))
                    checkpoint['D_losses'][k].append(v)
                for k, v in sorted(losses_g.items()):
                    logger.info('  [G] {}: {:.3f}'.format(k, v))
                    checkpoint['G_losses'][k].append(v)
                checkpoint['losses_ts'].append(t)

                # # Check stats on the validation set
                # logger.info('Checking stats on val ...')
                # metrics_val = check_accuracy(
                #     args, val_loader, generator, discriminator, d_loss_fn
                # )
                # logger.info('Checking stats on train ...')
                # metrics_train = check_accuracy(
                #     args, train_loader, generator, discriminator,
                #     d_loss_fn, limit=True
                # )

                # for k, v in sorted(metrics_val.items()):
                #     logger.info('  [val] {}: {:.3f}'.format(k, v))
                #     checkpoint['metrics_val'][k].append(v)
                # for k, v in sorted(metrics_train.items()):
                #     logger.info('  [train] {}: {:.3f}'.format(k, v))
                #     checkpoint['metrics_train'][k].append(v)

                # min_ade = min(checkpoint['metrics_val']['ade'])
                # min_ade_nl = min(checkpoint['metrics_val']['ade_nl'])

                # if metrics_val['ade'] == min_ade:
                #     logger.info('New low for avg_disp_error')
                #     checkpoint['best_t'] = t
                #     checkpoint['g_best_state'] = generator.state_dict()
                #     checkpoint['d_best_state'] = discriminator.state_dict()

                # if metrics_val['ade_nl'] == min_ade_nl:
                #     logger.info('New low for avg_disp_error_nl')
                #     checkpoint['best_t_nl'] = t
                #     checkpoint['g_best_nl_state'] = generator.state_dict()
                #     checkpoint['d_best_nl_state'] = discriminator.state_dict()

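            # Every 50 iterations, plot trajectories from the validation set
            # with the current generator and save the figures to image_dir.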
            if t % 50 == 0:
                plot_trajectory(fig,
                                ax,
                                args,
                                val_loader,
                                generator,
                                teacher,
                                args.mode,
                                t,
                                save=True,
                                image_dir=image_dir)

            # Maybe save a checkpoint
            if t > 0 and t % args.checkpoint_every == 0:
                print("Iteration: ", t)
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint['sample_ts'].append(t)

                # Save another checkpoint with model weights and
                # optimizer state
                checkpoint['g_state'] = generator.state_dict()
                checkpoint['g_optim_state'] = optimizer_g.state_dict()
                checkpoint['d_state'] = discriminator.state_dict()
                checkpoint['d_optim_state'] = optimizer_d.state_dict()
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_with_model.pt' % args.checkpoint_name)
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                torch.save(checkpoint, checkpoint_path)
                logger.info('Done.')

                # Save a checkpoint with no model weights by making a shallow
                # copy of the checkpoint excluding some items
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_no_model.pt' % args.checkpoint_name)
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                key_blacklist = [
                    'g_state', 'd_state', 'g_best_state', 'g_best_nl_state',
                    'g_optim_state', 'd_optim_state', 'd_best_state',
                    'd_best_nl_state'
                ]
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
                logger.info('Done.')

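            # A full iteration (all D steps then all G steps) is done: advance
            # the counter and reset the per-iteration step budgets.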
            t += 1
            d_steps_left = args.d_steps
            g_steps_left = args.g_steps
            if t >= args.num_iterations:
                break