parser.add_argument('--n_epochs', type=int, default=500, help='number of epochs run (default: 500)')
    parser.add_argument('--batch_size', type=int, default=1, help='multiple of batch size (default: 1)')
    parser.add_argument('--data_dir', type=str, help='location of data', default="train")
    parser.add_argument('--log_dir', type=str, help='location of logging', default="log")
    parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
    parser.add_argument('--data_parallel', type=bool, help='whether to parallelise based on data (default: False)', default=False)
    args = parser.parse_args()

    # Create model and optimizer
    model = GenerativeQueryNetwork(x_dim=3, v_dim=7, r_dim=256, h_dim=128, z_dim=64, L=12).to(device)
    model = nn.DataParallel(model) if args.data_parallel else model

    optimizer = torch.optim.Adam(model.parameters(), lr=5 * 10 ** (-4))

    # Rate annealing schemes
    sigma_scheme = Annealer(2.0, 0.7, 2 * 10 ** 5)
    mu_scheme = Annealer(5 * 10 ** (-4), 5 * 10 ** (-5), 1.6 * 10 ** 6)

    # Load the dataset
    train_dataset = ShepardMetzler(root_dir=args.data_dir)
    valid_dataset = ShepardMetzler(root_dir=args.data_dir, train=False)

    kwargs = {'num_workers': args.workers, 'pin_memory': True} if cuda else {}
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)

    def step(engine, batch):
        """Per-batch callback taking ``(engine, batch)``.

        As visible here it puts ``model`` into training mode, unpacks ``batch``
        into ``(x, v)`` and moves both onto ``device``. No return statement is
        visible, so as excerpted it returns None.

        NOTE(review): this body appears truncated in this excerpt (no loss
        computation or optimizer update follows); confirm against the full
        source before relying on it.
        """
        # Switch modules to training-mode behaviour before processing the batch.
        model.train()

        # Assumes batch is a 2-element (x, v) pair — presumably images and
        # viewpoints given x_dim=3 / v_dim=7 on the model; TODO confirm.
        x, v = batch
        x, v = x.to(device), v.to(device)
    parser.add_argument('--data_dir', type=str, help='location of data', default="train")
    parser.add_argument('--log_dir', type=str, help='location of logging', default="log")
    parser.add_argument('--fraction', type=float, help='how much of the data to use', default=1.0)
    parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)

    def _parse_bool(value):
        """Convert a command-line string into a bool.

        ``type=bool`` must not be used with argparse: ``bool('False')`` is
        ``True`` because every non-empty string is truthy, so
        ``--data_parallel False`` would silently enable data parallelism.

        Raises:
            ValueError: for unrecognised strings; argparse reports this as a
                usage error instead of guessing.
        """
        text = str(value).strip().lower()
        if text in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if text in ('0', 'false', 'f', 'no', 'n', 'off', ''):
            return False
        raise ValueError(value)

    # nargs='?' with const=True also lets a bare ``--data_parallel`` switch it on,
    # while ``--data_parallel True`` / ``--data_parallel False`` keep working.
    parser.add_argument('--data_parallel', type=_parse_bool, nargs='?', const=True,
                        help='whether to parallelise based on data (default: False)', default=False)
    args = parser.parse_args()

    # Create model and optimizer.
    # Smaller configuration; an earlier setup used h_dim=128, z_dim=64, L=8
    # (same x_dim/v_dim/r_dim).
    model = GenerativeQueryNetwork(x_dim=3, v_dim=7, r_dim=256, h_dim=64, z_dim=32, L=3).to(device)
    # Optionally wrap the model in nn.DataParallel to split batches across devices.
    model = nn.DataParallel(model) if args.data_parallel else model

    optimizer = torch.optim.Adam(model.parameters(), lr=5 * 10 ** (-5))

    # Rate annealing schemes
    # NOTE(review): mu's start and end values are both 5e-6, so if Annealer
    # interpolates between them mu stays constant; it also differs from the
    # optimizer's initial lr of 5e-5 above — confirm this is intended.
    sigma_scheme = Annealer(2.0, 0.7, 80000)
    mu_scheme = Annealer(5 * 10 ** (-6), 5 * 10 ** (-6), 1.6 * 10 ** 5)

    # Load the dataset (optionally only a fraction of it); train=False
    # presumably selects the held-out split.
    train_dataset = ShepardMetzler(root_dir=args.data_dir, fraction=args.fraction)
    valid_dataset = ShepardMetzler(root_dir=args.data_dir, fraction=args.fraction, train=False)

    # Worker processes and pinned host memory are only requested when `cuda` is set.
    kwargs = {'num_workers': args.workers, 'pin_memory': True} if cuda else {}
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)

    def step(engine, batch):
        """Per-batch callback taking ``(engine, batch)``.

        As visible here it puts ``model`` into training mode, unpacks ``batch``
        into ``(x, v)`` and moves both onto ``device``. No return statement is
        visible, so as excerpted it returns None.

        NOTE(review): this body appears truncated in this excerpt (no loss
        computation or optimizer update follows); confirm against the full
        source before relying on it.
        """
        # Switch modules to training-mode behaviour before processing the batch.
        model.train()

        # Assumes batch is a 2-element (x, v) pair — presumably images and
        # viewpoints given x_dim=3 / v_dim=7 on the model; TODO confirm.
        x, v = batch
        x, v = x.to(device), v.to(device)
# Example #3
# 0
    print('Creating GQN Model')
    # Create model and optimizer
    model = GenerativeQueryNetwork(x_dim=3,
                                   v_dim=7,
                                   r_dim=256,
                                   h_dim=128,
                                   z_dim=64,
                                   L=8).to(device)
    # Optionally wrap the model in nn.DataParallel (the flag is parsed
    # outside this excerpt).
    model = nn.DataParallel(model) if args.data_parallel else model

    optimizer = torch.optim.Adam(model.parameters(), lr=5 * 10**(-5))

    # Rate annealing schemes
    # NOTE(review): mu's start and end values are both 5e-6, so if Annealer
    # interpolates between them mu stays constant — confirm this is intended.
    sigma_scheme = Annealer(2.0, 0.7, 80000)
    mu_scheme = Annealer(5 * 10**(-6), 5 * 10**(-6), 1.6 * 10**5)
    print('Creating train dataset')
    # Load the dataset; both splits are built with num_samples=40000 and the
    # same data fraction. train=False presumably selects the held-out split.
    train_dataset = CircularOrbit(root_dir=args.data_dir,
                                  fraction=args.fraction,
                                  num_samples=40000)
    print('Creating test dataset')
    valid_dataset = CircularOrbit(root_dir=args.data_dir,
                                  fraction=args.fraction,
                                  num_samples=40000,
                                  train=False)

    # Worker processes and pinned host memory are only requested when `cuda` is set.
    kwargs = {'num_workers': args.workers, 'pin_memory': True} if cuda else {}
    print('train set:', len(train_dataset))
    train_loader = DataLoader(train_dataset,