Example #1
    def load_model(self, path):
        # torch.load returns a dict containing the arguments saved at training time as well as the model itself
        checkpoint = torch.load(path, map_location='cpu')
        self.generator = self.get_generator(checkpoint)
        # AttrDict wraps the plain dict stored in the checkpoint's args so its entries can be accessed as attributes
        self.args = AttrDict(checkpoint['args'])
        train_path = get_dset_path(self.args.dataset_name, "train")
        test_path = get_dset_path(self.args.dataset_name, "test")
        self.args.batch_size = 1
        _, self.loader = data_loader(self.args, train_path)
        _, self.test_loader = data_loader(self.args, test_path)

        self.metrics_val = checkpoint['metrics_val']
        self.metrics_train = checkpoint['metrics_train']
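In the snippet above, AttrDict wraps the plain dict stored under checkpoint['args'] so its entries can be read as attributes (self.args.dataset_name, self.args.batch_size). A minimal sketch of such a wrapper is shown below for illustration; the class actually imported by these examples (for instance from the attrdict package or the project's own utilities) may differ.

# Minimal AttrDict sketch (assumption): the real class used by these examples
# may be implemented differently.
class AttrDict(dict):
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        self[name] = value


# Usage: attribute-style access to the checkpoint's saved arguments.
args = AttrDict({'dataset_name': 'zara1', 'batch_size': 64})
print(args.dataset_name, args.batch_size)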
Example #2
def main(args):
    if os.path.isdir(args.model_path):
        filenames = os.listdir(args.model_path)
        filenames.sort()
        paths = [
            os.path.join(args.model_path, file_) for file_ in filenames
        ]
    else:
        paths = [args.model_path]

    for path in paths:
        checkpoint = torch.load(path)
        generator = get_generator(checkpoint)
        _args = AttrDict(checkpoint['args'])
        path = get_dset_path(_args.dataset_name, args.dset_type)
        _, loader = data_loader(_args, path)
        ade, fde = evaluate(_args, loader, generator, args.num_samples)
        print('Dataset: {}, Pred Len: {}, ADE: {:.2f}, FDE: {:.2f}'.format(
            _args.dataset_name, _args.pred_len, ade, fde))
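This evaluation main() reads three attributes from its args object: model_path (a checkpoint file or a directory of checkpoints), dset_type (the split passed to get_dset_path) and num_samples (forwarded to evaluate). A minimal argparse front end consistent with those attributes could look like the sketch below; the flag names and defaults are assumptions, not the project's actual CLI.

import argparse

# Hypothetical CLI wrapper for the evaluation main() above. Flag names and
# defaults are assumptions inferred from the attributes the function reads.
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, required=True,
                    help='checkpoint file or directory of checkpoints')
parser.add_argument('--dset_type', type=str, default='test',
                    help='dataset split to evaluate on')
parser.add_argument('--num_samples', type=int, default=20,
                    help='samples drawn per trajectory when computing ADE/FDE')

if __name__ == '__main__':
    main(parser.parse_args())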
Example #3
def main(args):
    global global_step
    global_step = 13086

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    print(args)
    train_path = get_dset_path(args.dataset_name, 'train')
    val_path = get_dset_path(args.dataset_name, 'val')

    long_dtype, float_dtype = get_dtypes(args)

    logger.info("Initializing train dataset")
    train_dset, train_loader = data_loader(args, train_path)
    logger.info("Initializing val dataset")
    _, val_loader = data_loader(args, val_path)

    iterations_per_epoch = len(train_dset) / args.batch_size
    if args.num_epochs:
        args.num_iterations = int(iterations_per_epoch * args.num_epochs)

    logger.info(
        'There are {} iterations per epoch'.format(iterations_per_epoch))

    linear = LinearModel(seq_len=args.pred_len, use_cuda=args.use_gpu)

    linear.apply(init_weights)
    linear.type(float_dtype).train()
    logger.info('Here is the linear:')
    logger.info(linear)

    optimizer = optim.Adam(linear.parameters(), lr=args.g_learning_rate)

    # Maybe restore from checkpoint
    restore_path = None
    if args.checkpoint_start_from is not None:
        restore_path = args.checkpoint_start_from
    elif args.restore_from_checkpoint == 1:
        restore_path = os.path.join(args.output_dir,
                                    '%s_with_model.pt' % args.checkpoint_name)

    if restore_path is not None and os.path.isfile(restore_path):
        logger.info('Restoring from checkpoint {}'.format(restore_path))
        checkpoint = torch.load(restore_path)
        linear.load_state_dict(checkpoint['state'])
        optimizer.load_state_dict(checkpoint['optim_state'])
        t = checkpoint['counters']['t']
        epoch = checkpoint['counters']['epoch']
        checkpoint['restore_ts'].append(t)
    else:
        # Starting from scratch, so initialize checkpoint data structure
        t, epoch = 0, 0
        checkpoint = {
            'args': args.__dict__,
            'losses': defaultdict(list),
            'losses_ts': [],
            'metrics_val': defaultdict(list),
            'metrics_train': defaultdict(list),
            'sample_ts': [],
            'restore_ts': [],
            'norm': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'state': None,
            'optim_state': None,
            'best_state': None,
            'best_t': None,
            'layer1.weight': None,
            'layer2.weight': None,
            'layer1.bias': None,
            'layer2.bias': None,
        }
    t0 = None
    while t < args.num_iterations:
        gc.collect()
        epoch += 1
        logger.info('Starting epoch {}'.format(epoch))
        for batch in train_loader:
            if args.timing == 1:
                torch.cuda.synchronize()
                t1 = time.time()

            # Each batch takes a single optimization step on the linear model
            # (unlike the GAN variant, there is no discriminator here).
            generator_step(args, batch, linear, optimizer)
            # checkpoint['norm_g'].append(
            #     get_total_norm(lstm.parameters())
            # )

            if args.timing == 1:
                if t0 is not None:
                    logger.info('Iteration {} took {}'.format(
                        t - 1,
                        time.time() - t0))
                t0 = time.time()

            # Maybe save loss
            if t % args.print_every == 0:
                logger.info('t = {} / {}'.format(t + 1, args.num_iterations))
            #     for k, v in sorted(losses.items()):
            #         logger.info('  [D] {}: {:.7f}'.format(k, v))
            #         checkpoint['losses'][k].append(v)
            #     checkpoint['losses_ts'].append(t)

            # Maybe save a checkpoint
            if t > 0 and t % args.checkpoint_every == 0:
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint['sample_ts'].append(t)

                # Check stats on the validation set
                logger.info('Checking stats on val ...')
                metrics_val = check_accuracy(args,
                                             val_loader,
                                             linear,
                                             is_train=False)
                logger.info('Checking stats on train ...')
                metrics_train = check_accuracy(args,
                                               train_loader,
                                               linear,
                                               limit=True,
                                               is_train=True)

                for k, v in sorted(metrics_val.items()):
                    logger.info('  [val] {}: {:.7f}'.format(k, v))
                    checkpoint['metrics_val'][k].append(v)
                for k, v in sorted(metrics_train.items()):
                    logger.info('  [train] {}: {:.7f}'.format(k, v))
                    checkpoint['metrics_train'][k].append(v)

                min_ade = min(checkpoint['metrics_val']['ade'])

                if metrics_val['ade'] == min_ade:
                    logger.info('New low for avg_disp_error')
                    checkpoint['best_t'] = t
                    checkpoint['best_state'] = linear.state_dict()

                # Save another checkpoint with model weights and
                # optimizer state
                checkpoint['state'] = linear.state_dict()
                checkpoint['optim_state'] = optimizer.state_dict()

                checkpoint_path = os.path.join(
                    args.output_dir, '%s_with_model.pt' % args.checkpoint_name)
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                torch.save(checkpoint, checkpoint_path)
                logger.info('Done.')

            t += 1
            if t >= args.num_iterations:
                break
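Every args.checkpoint_every iterations the loop above writes the whole checkpoint dict to '<checkpoint_name>_with_model.pt' via torch.save. A short sketch of inspecting such a file afterwards is shown below; the filename is a placeholder, and the keys are the ones initialized when training starts from scratch.

import torch

# Sketch: load a saved training checkpoint and print the recorded validation
# metrics. The path is a placeholder for '<checkpoint_name>_with_model.pt'
# inside args.output_dir.
checkpoint = torch.load('linear_with_model.pt', map_location='cpu')
print('stopped at t = {}, epoch = {}'.format(checkpoint['counters']['t'],
                                             checkpoint['counters']['epoch']))
for t_sample, ade in zip(checkpoint['sample_ts'],
                         checkpoint['metrics_val']['ade']):
    print('t = {}: val ADE = {:.4f}'.format(t_sample, ade))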
Example #4
def main(args):
    global t

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    train_path = get_dset_path(args.dataset_name, 'train')
    val_path = get_dset_path(args.dataset_name, 'val')

    long_dtype, float_dtype = get_dtypes(args)

    logger.info("Initializing train dataset")
    train_dset, train_loader = data_loader(args, train_path)
    logger.info("Initializing val dataset")
    _, val_loader = data_loader(args, val_path)

    iterations_per_epoch = len(train_dset) / args.batch_size / args.d_steps
    if args.num_epochs:
        args.num_iterations = int(iterations_per_epoch * args.num_epochs)

    logger.info(
        'There are {} iterations per epoch'.format(iterations_per_epoch))

    generator = TrajectoryGenerator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        encoder_h_dim=args.encoder_h_dim_g,
        decoder_h_dim=args.decoder_h_dim_g,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        noise_dim=args.noise_dim,
        noise_type=args.noise_type,
        noise_mix_type=args.noise_mix_type,
        pooling_type=args.pooling_type,
        pool_every_timestep=args.pool_every_timestep,
        dropout=args.dropout,
        bottleneck_dim=args.bottleneck_dim,
        neighborhood_size=args.neighborhood_size,
        grid_size=args.grid_size,
        batch_norm=args.batch_norm,
        use_cuda=args.use_gpu)

    generator.apply(init_weights)
    generator.type(float_dtype).train()
    logger.info('Here is the generator:')
    logger.info(generator)

    discriminator = TrajectoryDiscriminator(obs_len=args.obs_len,
                                            pred_len=args.pred_len,
                                            embedding_dim=args.embedding_dim,
                                            h_dim=args.encoder_h_dim_d,
                                            mlp_dim=args.mlp_dim,
                                            num_layers=args.num_layers,
                                            dropout=args.dropout,
                                            batch_norm=args.batch_norm,
                                            d_type=args.d_type,
                                            use_cuda=args.use_gpu)

    discriminator.apply(init_weights)
    discriminator.type(float_dtype).train()
    logger.info('Here is the discriminator:')
    logger.info(discriminator)

    g_loss_fn = gan_g_loss
    d_loss_fn = gan_d_loss

    # optimizer_g = optim.Adam([{'params': generator.parameters(), 'initial_lr': args.g_learning_rate}], lr=args.g_learning_rate)
    # optimizer_d = optim.Adam([{'params': discriminator.parameters(), 'initial_lr': args.d_learning_rate}], lr=args.d_learning_rate)

    optimizer_g = optim.Adam(params=generator.parameters(),
                             lr=args.g_learning_rate)
    optimizer_d = optim.Adam(params=discriminator.parameters(),
                             lr=args.d_learning_rate)

    lr_scheduler_g = ReduceLROnPlateau(optimizer_g,
                                       threshold=1e-4,
                                       patience=100,
                                       factor=8e-1,
                                       min_lr=1e-5,
                                       verbose=True)

    lr_scheduler_d = ReduceLROnPlateau(optimizer_d,
                                       threshold=1e-4,
                                       patience=100,
                                       factor=8e-1,
                                       min_lr=1e-5,
                                       verbose=True)

    # Maybe restore from checkpoint
    restore_path = None
    if args.checkpoint_start_from is not None:
        restore_path = args.checkpoint_start_from
    elif args.restore_from_checkpoint == 1:
        restore_path = os.path.join(args.output_dir,
                                    '%s_with_model.pt' % args.checkpoint_name)

    if restore_path is not None and os.path.isfile(restore_path):
        logger.info('Restoring from checkpoint {}'.format(restore_path))
        checkpoint = torch.load(restore_path)
        generator.load_state_dict(checkpoint['g_state'])
        discriminator.load_state_dict(checkpoint['d_state'])
        optimizer_g.load_state_dict(checkpoint['g_optim_state'])
        optimizer_d.load_state_dict(checkpoint['d_optim_state'])
        lr_scheduler_g.load_state_dict(checkpoint['lr_scheduler_g_state'])
        lr_scheduler_d.load_state_dict(checkpoint['lr_scheduler_d_state'])
        # t = checkpoint['counters']['t']
        # epoch = checkpoint['counters']['epoch']
        t, epoch = 0, 0
        checkpoint['restore_ts'].append(t)
    else:
        # Starting from scratch, so initialize checkpoint data structure
        t, epoch = 0, 0
        checkpoint = {
            'args': args.__dict__,
            'G_losses': defaultdict(list),
            'D_losses': defaultdict(list),
            'losses_ts': [],
            'metrics_val': defaultdict(list),
            'metrics_train': defaultdict(list),
            'sample_ts': [],
            'restore_ts': [],
            'norm_g': [],
            'norm_d': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'g_state': None,
            'g_optim_state': None,
            'd_state': None,
            'd_optim_state': None,
            'g_best_state': None,
            'd_best_state': None,
            'lr_scheduler_g_state': None,
            'lr_scheduler_d_state': None,
            'best_t': None,
            'g_best_nl_state': None,
            'd_best_state_nl': None,
            'best_t_nl': None,
        }

    # scheduler_g = optim.lr_scheduler.StepLR(optimizer_g, step_size=1000, gamma=0.5, last_epoch=(epoch if epoch != 0 else -1))
    # scheduler_d = optim.lr_scheduler.StepLR(optimizer_d, step_size=1000, gamma=0.5, last_epoch=(epoch if epoch != 0 else -1))

    t0 = None
    while t < args.num_iterations:
        gc.collect()
        d_steps_left = args.d_steps if args.discriminator_weight > 0 else 0
        g_steps_left = args.g_steps
        epoch += 1

        logger.info('Starting epoch {}'.format(epoch))
        for batch in train_loader:
            if args.timing == 1:
                torch.cuda.synchronize()
                t1 = time.time()

            # Decide whether to use the batch for stepping on discriminator or
            # generator; an iteration consists of args.d_steps steps on the
            # discriminator followed by args.g_steps steps on the generator.
            if d_steps_left > 0:
                step_type = 'd'
                losses_d = discriminator_step(args, batch, generator,
                                              discriminator, d_loss_fn,
                                              optimizer_d)
                checkpoint['norm_d'].append(
                    get_total_norm(discriminator.parameters()))
                d_steps_left -= 1
            elif g_steps_left > 0:
                step_type = 'g'
                losses_g = generator_step(args, batch, generator,
                                          discriminator, g_loss_fn,
                                          optimizer_g)
                checkpoint['norm_g'].append(
                    get_total_norm(generator.parameters()))
                g_steps_left -= 1

            if args.timing == 1:
                torch.cuda.synchronize()
                t2 = time.time()
                logger.info('{} step took {}'.format(step_type, t2 - t1))

            # Skip the rest if we are not at the end of an iteration
            if d_steps_left > 0 or g_steps_left > 0:
                continue

            if args.timing == 1:
                if t0 is not None:
                    logger.info('Iteration {} took {}'.format(
                        t - 1,
                        time.time() - t0))
                t0 = time.time()

            # Maybe save loss
            if t % args.print_every == 0:
                logger.info('t = {} / {}'.format(t + 1, args.num_iterations))
                if args.discriminator_weight > 0:
                    for k, v in sorted(losses_d.items()):
                        logger.info('  [D] {}: {:.7f}'.format(k, v))
                        checkpoint['D_losses'][k].append(v)
                for k, v in sorted(losses_g.items()):
                    logger.info('  [G] {}: {:.7f}'.format(k, v))
                    checkpoint['G_losses'][k].append(v)
                checkpoint['losses_ts'].append(t)

            # Maybe save a checkpoint
            if t > 0 and t % args.checkpoint_every == 0:
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint['sample_ts'].append(t)

                # Check stats on the validation set
                logger.info('Checking stats on val ...')
                metrics_val = check_accuracy(args,
                                             val_loader,
                                             generator,
                                             discriminator,
                                             d_loss_fn,
                                             lr_scheduler_g,
                                             lr_scheduler_d,
                                             is_train=False)
                logger.info('Checking stats on train ...')
                metrics_train = check_accuracy(args,
                                               train_loader,
                                               generator,
                                               discriminator,
                                               d_loss_fn,
                                               lr_scheduler_g,
                                               lr_scheduler_d,
                                               limit=True,
                                               is_train=True)

                for k, v in sorted(metrics_val.items()):
                    logger.info('  [val] {}: {:.7f}'.format(k, v))
                    checkpoint['metrics_val'][k].append(v)
                for k, v in sorted(metrics_train.items()):
                    logger.info('  [train] {}: {:.7f}'.format(k, v))
                    checkpoint['metrics_train'][k].append(v)

                min_ade = min(checkpoint['metrics_val']['ade'])

                if metrics_val['ade'] == min_ade:
                    logger.info('New low for avg_disp_error')
                    checkpoint['best_t'] = t
                    checkpoint['g_best_state'] = generator.state_dict()
                    checkpoint['d_best_state'] = discriminator.state_dict()

                # Save another checkpoint with model weights and
                # optimizer state
                checkpoint['g_state'] = generator.state_dict()
                checkpoint['g_optim_state'] = optimizer_g.state_dict()
                checkpoint['lr_scheduler_g_state'] = (
                    lr_scheduler_g.state_dict())
                checkpoint['d_state'] = discriminator.state_dict()
                checkpoint['d_optim_state'] = optimizer_d.state_dict()
                checkpoint['lr_scheduler_d_state'] = (
                    lr_scheduler_d.state_dict())
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_with_model.pt' % args.checkpoint_name)
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                torch.save(checkpoint, checkpoint_path)
                logger.info('Done.')

                # Save a checkpoint with no model weights by making a shallow
                # copy of the checkpoint excluding some items
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_no_model.pt' % args.checkpoint_name)
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                key_blacklist = [
                    'g_state', 'g_best_state', 'g_best_nl_state',
                    'g_optim_state'
                ]
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
                logger.info('Done.')

            t += 1
            d_steps_left = args.d_steps
            g_steps_left = args.g_steps
            if t >= args.num_iterations:
                break
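Besides the full checkpoint, the loop above also saves a lightweight '<checkpoint_name>_no_model.pt' file that drops the generator weights listed in key_blacklist but keeps the logged losses and metrics. A sketch of reading that file back, for example to summarize the recorded generator losses, might look like this (the filename is a placeholder):

import torch

# Sketch: inspect the slim checkpoint that excludes generator weights.
small = torch.load('gan_no_model.pt', map_location='cpu')
for name, values in sorted(small['G_losses'].items()):
    print('[G] {}: {} values logged'.format(name, len(values)))
print('last logged iterations:', small['losses_ts'][-5:])
print('best val ADE recorded at t =', small['best_t'])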