Example 1
def main(exp):

    # Number of examples per batch
    batch_size: Argument & int = default(256)

    # Dataset to load
    dataset: Argument

    torch_settings = init_torch()
    dataset = exp.get_dataset(dataset)

    loader = torch.utils.data.DataLoader(
        dataset.train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=torch_settings.workers,
        pin_memory=True
    )

    wrapper = iteration_wrapper(exp, sync=None)

    # Warm up the data loader and the host-to-device transfer on a few batches
    # (the results of .to() are intentionally discarded)
    for _, batch in zip(range(10), loader):
        for item in batch:
            item.to(torch_settings.device)

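    # Measurement loop: transfer each batch to the device and log progress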
    for it, batch in dataloop(loader, wrapper=wrapper):
        it.set_count(batch_size)
        it.log(eta=True)
        batch = [item.to(torch_settings.device) for item in batch]
        if torch_settings.sync:
            torch_settings.sync()
Example 2
def main(exp):

    # Dataset to use
    dataset: Argument

    # super resolution upscale factor
    upscale_factor: Argument & int = default(2)

    # # testing batch size (default: 10)
    # test_batch_size: Argument & int = default(10)

    # Learning rate (default: 0.1)
    lr: Argument & float = default(0.1)

    # Batch size (default: 64)
    batch_size: Argument & int = default(64)

    torch_settings = init_torch()
    device = torch_settings.device

    print('===> Loading datasets')
    # dataset_instance = exp.resolve_dataset("milabench.presets:bsds500")
    # folder = dataset_instance["environment"]["root"]
    sets = get_dataset(exp, dataset, upscale_factor)
    train_set = sets.train
    # train_set = get_dataset(os.path.join(folder, "bsds500/BSR/BSDS500/data/images/train"), upscale_factor)
    # test_set = get_dataset(os.path.join(folder, "bsds500/BSR/BSDS500/data/images/test"), upscale_factor)

    training_data_loader = DataLoader(dataset=train_set,
                                      num_workers=torch_settings.workers,
                                      batch_size=batch_size,
                                      shuffle=True)
    # testing_data_loader = DataLoader(
    #     dataset=test_set,
    #     num_workers=torch_settings.workers,
    #     batch_size=test_batch_size,
    #     shuffle=False
    # )

    print('===> Building model')
    model = Net(upscale_factor=upscale_factor).to(device)
    model.train()
    criterion = nn.MSELoss()

    optimizer = optim.Adam(model.parameters(), lr=lr)

    wrapper = iteration_wrapper(exp, sync=torch_settings.sync)
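    # Training loop: forward pass, MSE loss against the target image, backward, Adam step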
    for it, (input, target) in dataloop(training_data_loader, wrapper=wrapper):
        it.set_count(batch_size)

        input = input.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        loss = criterion(model(input), target)
        it.log(loss=loss.item())
        loss.backward()
        optimizer.step()
Example 3
def main(exp):
    # Model float type
    dtype: Argument & str = default("float32")

    # Number of samples
    samples: Argument & int = default(100)

    torch_settings = init_torch()
    device = torch_settings.device

    data = generate_wave_data(20, 1000, samples)

    _dtype = to_type[dtype]

    input = torch.from_numpy(data[3:, :-1]).to(device=device, dtype=_dtype)
    target = torch.from_numpy(data[3:, 1:]).to(device=device, dtype=_dtype)

    test_input = torch.from_numpy(data[:3, :-1]).to(device=device,
                                                    dtype=_dtype)
    test_target = torch.from_numpy(data[:3, 1:]).to(device=device,
                                                    dtype=_dtype)

    # build the model
    seq = Sequence().to(device=device, dtype=_dtype)
    criterion = nn.MSELoss().to(device=device, dtype=_dtype)

    optimizer = optim.SGD(seq.parameters(), lr=0.01)

    total_time = 0

    seq.train()

    wrapper = iteration_wrapper(exp, sync=torch_settings.sync)

    for it, _ in dataloop(count(), wrapper=wrapper):
        it.set_count(samples)

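        # The closure re-evaluates the model and the loss so the optimizer can call it on each step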
        def closure():
            optimizer.zero_grad()
            out = seq(input.to(device=device, dtype=_dtype))
            loss = criterion(out, target)
            loss.backward()
            it.log(loss=loss.item())
            return loss

        optimizer.step(closure)
Example 4
def main(exp):

    # dataset to use
    dataset: Argument & str

    # Number of examples per batch
    batch_size: Argument & int = default(64)

    # path to style-image
    style_image: Argument & str = default(
        os.path.join(repo_base, "neural-style-images/style-images/candy.jpg"))

    # size of training images, default is 256 X 256
    image_size: Argument & int = default(256)

    # size of style-image, default is the original size of style image
    style_size: Argument & int = default(None)

    # weight for content-loss, default is 1e5
    content_weight: Argument & float = default(1e5)

    # weight for style-loss, default is 1e10
    style_weight: Argument & float = default(1e10)

    # learning rate, default is 1e-3
    lr: Argument & float = default(1e-3)

    torch_settings = init_torch()
    device = torch_settings.device

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = exp.get_dataset(dataset, transform).train
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              num_workers=torch_settings.workers)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    print(
        memory_size(vgg,
                    batch_size=batch_size,
                    input_size=(3, image_size, image_size)) * 4)

    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    style = utils.load_image(style_image, size=style_size)
    style = style_transform(style)
    style = style.repeat(batch_size, 1, 1, 1).to(device)

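    # Precompute the style targets: Gram matrices of the style image's VGG features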
    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    wrapper = iteration_wrapper(exp, sync=torch_settings.sync)

    transformer.train()

    for it, (x, _) in dataloop(train_loader, wrapper=wrapper):
        it.set_count(len(x))

        n_batch = len(x)

        x = x.to(device)
        y = transformer(x)

        y = utils.normalize_batch(y)
        x = utils.normalize_batch(x)

        optimizer.zero_grad()

        features_y = vgg(y)
        features_x = vgg(x)

        content_loss = content_weight * mse_loss(features_y.relu2_2,
                                                 features_x.relu2_2)

        style_loss = 0.
        for ft_y, gm_s in zip(features_y, gram_style):
            gm_y = utils.gram_matrix(ft_y)
            style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
        style_loss *= style_weight

        total_loss = content_loss + style_loss
        total_loss.backward()

        it.log(loss=total_loss.item())
        optimizer.step()
Example 5
def train300_mlperf_coco(exp, args):

    torch.backends.cudnn.benchmark = True

    device = args.torch_settings.device

    dboxes = dboxes300_coco()
    encoder = Encoder(dboxes)

    input_size = 300
    train_trans = SSDTransformer(dboxes, (input_size, input_size), val=False)
    val_trans = SSDTransformer(dboxes, (input_size, input_size), val=True)
    # mlperf_log.ssd_print(key=# mlperf_log.INPUT_SIZE, value=input_size)

    # val_annotate = os.path.join(args.data, "annotations/instances_val2017.json")
    # val_coco_root = os.path.join(args.data, "val2017")
    # train_annotate = os.path.join(args.data, "annotations/instances_train2017.json")
    # train_coco_root = os.path.join(args.data, "train2017")

    # cocoGt = COCO(annotation_file=val_annotate)
    # val_coco = COCODetection(val_coco_root, val_annotate, val_trans)
    # train_coco = COCODetection(train_coco_root, train_annotate, train_trans)
    coco_dataset = exp.get_dataset(
        args.dataset,
        train_transform=train_trans,
        val_transform=val_trans
    )
    cocoGt = coco_dataset.coco
    val_coco = coco_dataset.val
    train_coco = coco_dataset.train

    #print("Number of labels: {}".format(train_coco.labelnum))
    train_dataloader = DataLoader(train_coco,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=4)
    # set shuffle=True in DataLoader
    # mlperf_log.ssd_print(key=# mlperf_log.INPUT_SHARD, value=None)
    # mlperf_log.ssd_print(key=# mlperf_log.INPUT_ORDER)
    # mlperf_log.ssd_print(key=# mlperf_log.INPUT_BATCH_SIZE, value=args.batch_size)

    ssd300 = SSD300(train_coco.labelnum)
    if args.checkpoint is not None:
        print("loading model checkpoint", args.checkpoint)
        od = torch.load(args.checkpoint)
        ssd300.load_state_dict(od["model"])

    ssd300.train()
    ssd300 = ssd300.to(device)
    loss_func = Loss(dboxes).to(device)

    current_lr = 1e-3
    current_momentum = 0.9
    current_weight_decay = 5e-4

    optim = torch.optim.SGD(
        ssd300.parameters(),
        lr=current_lr,
        momentum=current_momentum,
        weight_decay=current_weight_decay
    )

    # mlperf_log.ssd_print(key=# mlperf_log.OPT_NAME, value="SGD")
    # mlperf_log.ssd_print(key=# mlperf_log.OPT_LR, value=current_lr)
    # mlperf_log.ssd_print(key=# mlperf_log.OPT_MOMENTUM, value=current_momentum)
    # mlperf_log.ssd_print(key=# mlperf_log.OPT_WEIGHT_DECAY,  value=current_weight_decay)

    avg_loss = 0.0
    inv_map = {v: k for k, v in val_coco.label_map.items()}

    # mlperf_log.ssd_print(key=# mlperf_log.TRAIN_LOOP)

    train_loss = 0
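    # Training loop: SSD300 forward pass, multibox loss, SGD step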
    for it, (img, img_size, bbox, label) in dataloop(train_dataloader, wrapper=args.wrapper):
        it.set_count(args.batch_size)

        img = Variable(img.to(device), requires_grad=True)

        ploc, plabel = ssd300(img)

        trans_bbox = bbox.transpose(1, 2).contiguous()

        trans_bbox = trans_bbox.to(device)
        label = label.to(device)

        gloc = Variable(trans_bbox, requires_grad=False)
        glabel = Variable(label, requires_grad=False)

        loss = loss_func(ploc, plabel, gloc, glabel)

        if not np.isinf(loss.item()):
            avg_loss = 0.999 * avg_loss + 0.001 * loss.item()

        it.log(loss=loss.item())

        optim.zero_grad()
        loss.backward()
        optim.step()
Example 6
def main(exp):
    # dataset to use
    dataset: Argument & str

    # batch size
    batch_size: Argument & int = default(128)

    # number of predictive factors
    # [alias: -f]
    factors: Argument & int = default(8)

    # size of hidden layers for MLP
    layers: Argument = default("64,32,16,8")

    # number of negative examples per interaction
    # [alias: -n]
    negative_samples: Argument & int = default(4)

    # learning rate for optimizer
    # [alias: -l]
    learning_rate: Argument & float = default(0.001)

    # rank for test examples to be considered a hit
    # [alias: -k]
    topk: Argument & int = default(10)

    layer_sizes = [int(x) for x in layers.split(",")]

    torch_settings = init_torch()
    device = torch_settings.device

    # Load Data
    # ------------------------------------------------------------------------------------------------------------------
    print('Loading data')
    with exp.time('loading_data'):
        t1 = time.time()

        train_dataset = exp.get_dataset(dataset, nb_neg=negative_samples).train

        # mlperf_log.ncf_print(key=# mlperf_log.INPUT_BATCH_SIZE, value=batch_size)
        # mlperf_log.ncf_print(key=# mlperf_log.INPUT_ORDER)  # set shuffle=True in DataLoader
        train_dataloader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=torch_settings.workers,
            pin_memory=True)

        nb_users, nb_items = train_dataset.nb_users, train_dataset.nb_items

        print('Load data done [%.1f s]. #user=%d, #item=%d, #train=%d' %
              (time.time() - t1, nb_users, nb_items, train_dataset.mat.nnz))
    # ------------------------------------------------------------------------------------------------------------------

    # Create model
    model = NeuMF(nb_users,
                  nb_items,
                  mf_dim=factors,
                  mf_reg=0.,
                  mlp_layer_sizes=layer_sizes,
                  mlp_layer_regs=[0. for i in layer_sizes]).to(device)
    print(model)
    print("{} parameters".format(utils.count_parameters(model)))

    # Save model text description
    run_dir = exp.results_directory()
    with open(os.path.join(run_dir, 'model.txt'), 'w') as file:
        file.write(str(model))

    # Add optimizer and loss to graph
    # mlperf_log.ncf_print(key=# mlperf_log.OPT_LR, value=learning_rate)
    beta1, beta2, epsilon = 0.9, 0.999, 1e-8

    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=(beta1, beta2),
                                 lr=learning_rate,
                                 eps=epsilon)

    # mlperf_log.ncf_print(key=# mlperf_log.MODEL_HP_LOSS_FN, value=# mlperf_log.BCE)
    criterion = nn.BCEWithLogitsLoss().to(device)

    model.train()

    wrapper = iteration_wrapper(exp, sync=None)

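    # Training loop: score each (user, item) pair, compute BCE-with-logits loss, Adam step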
    for it, (user, item, label) in dataloop(train_dataloader, wrapper=wrapper):
        it.set_count(batch_size)

        user = torch.autograd.Variable(user, requires_grad=False).to(device)
        item = torch.autograd.Variable(item, requires_grad=False).to(device)
        label = torch.autograd.Variable(label, requires_grad=False).to(device)

        outputs = model(user, item)
        loss = criterion(outputs, label)
        it.log(loss=loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Example 7
def main(exp):
    # discount factor (default: 0.99)
    gamma: Argument & float = default(0.99)

    # render the environment
    render: Argument & bool = default(False)

    # seed for the environment
    seed: Argument & int = default(1234)

    # length of one episode
    episode_length: Argument & int = default(500)

    torch_settings = init_torch()
    device = torch_settings.device

    env = gym.make('CartPole-v0')
    env.seed(seed)

    policy = Policy()
    optimizer = optim.Adam(policy.parameters(), lr=1e-2)
    eps = np.finfo(np.float32).eps.item()

    print(torch_settings)

    def select_action(state):
        state = torch.from_numpy(state).float().unsqueeze(0)
        probs = policy(state)
        m = Categorical(probs)
        action = m.sample()
        policy.saved_log_probs.append(m.log_prob(action))
        return action.item()

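    # REINFORCE update: accumulate discounted returns, normalize them, and take one policy-gradient step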
    def finish_episode():
        R = 0
        policy_loss = []
        returns = []

        for r in policy.rewards[::-1]:
            R = r + gamma * R
            returns.insert(0, R)

        returns = torch.tensor(returns)
        returns = (returns - returns.mean()) / (returns.std() + eps)

        for log_prob, R in zip(policy.saved_log_probs, returns):
            policy_loss.append(-log_prob * R)

        optimizer.zero_grad()
        policy_loss = torch.cat(policy_loss).sum()
        policy_loss.backward()
        optimizer.step()

        del policy.rewards[:]
        del policy.saved_log_probs[:]

    running_reward = 10

    wrapper = iteration_wrapper(exp, sync=torch_settings.sync)

    for it, _ in dataloop(count(), wrapper=wrapper):
        it.set_count(episode_length)

        state, ep_reward = env.reset(), 0

        for t in range(episode_length):

            action = select_action(state)

            state, reward, done, _ = env.step(action)
            policy.rewards.append(reward)
            ep_reward += reward

            # we do not actually care about solving the task; just reset and keep stepping
            if done:
                state, ep_reward = env.reset(), 0

        running_reward = 0.05 * ep_reward + (1 - 0.05) * running_reward
        it.log(reward=running_reward)
        finish_episode()
Example 8
def main(exp):

    # Degree of the polynomial
    poly_degree: Argument & int = default(4)

    # Number of examples per batch
    batch_size: Argument & int = default(64)

    torch_settings = init_torch()
    device = torch_settings.device

    W_target = torch.randn(poly_degree, 1) * 5
    b_target = torch.randn(1) * 5

    def make_features(x):
        """Builds features i.e. a matrix with columns [x, x^2, x^3, x^4]."""
        x = x.unsqueeze(1)
        return torch.cat([x**i for i in range(1, poly_degree + 1)], 1)

    def f(x):
        """Approximated function."""
        return x.mm(W_target) + b_target.item()

    def poly_desc(W, b):
        """Creates a string description of a polynomial."""
        result = 'y = '
        for i, w in enumerate(W):
            result += '{:+.2f} x^{} '.format(w, len(W) - i)
        result += '{:+.2f}'.format(b[0])
        return result

    def get_batch():
        """Builds a batch i.e. (x, f(x)) pair."""
        random = torch.randn(batch_size)
        x = make_features(random)
        y = f(x)
        return x, y

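    # Infinite generator of freshly sampled (x, f(x)) batches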
    def dataset():
        while True:
            yield get_batch()

    # Define model
    fc = torch.nn.Linear(W_target.size(0), 1)
    fc.to(device)

    wrapper = iteration_wrapper(exp, sync=torch_settings.sync)

    for it, (batch_x, batch_y) in dataloop(dataset(), wrapper=wrapper):
        it.set_count(batch_size)

        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        # Reset gradients
        fc.zero_grad()

        # Forward pass
        output = F.smooth_l1_loss(fc(batch_x), batch_y)
        loss = output.item()

        it.log(loss=loss)

        # Backward pass
        output.backward()

        # Apply gradients
        for param in fc.parameters():
            param.data.add_(-0.01 * param.grad.data)

    print('==> Learned function:\t', poly_desc(fc.weight.view(-1), fc.bias))
    print('==> Actual function:\t', poly_desc(W_target.view(-1), b_target))
Example 9
def main(exp):

    # Batch size
    batch_size: Argument & int = default(256)

    # Dataset to use
    dataset: Argument

    torch_settings = init_torch()
    device = torch_settings.device
    dataset = exp.get_dataset(dataset)

    kwargs = {
        'num_workers': 1,
        'pin_memory': True
    } if torch_settings.cuda else {}
    train_loader = torch.utils.data.DataLoader(
        dataset.train,
        batch_size=batch_size,
        shuffle=True,
        **kwargs,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset.test,
        batch_size=batch_size,
        shuffle=True,
        **kwargs,
    )

    model = VAE().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Reconstruction + KL divergence losses summed over all elements and batch
    def loss_function(recon_x, x, mu, logvar):
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')

        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return BCE + KLD

    def test(epoch):
        # Not tested
        model.eval()
        test_loss = 0

        with torch.no_grad():

            for i, (data, _) in enumerate(test_loader):
                data = data.to(device)
                recon_batch, mu, logvar = model(data)
                test_loss += loss_function(recon_batch, data, mu,
                                           logvar).item()

                if i == 0:
                    n = min(data.size(0), 8)
                    comparison = torch.cat([
                        data[:n],
                        recon_batch.view(batch_size, 1, 28, 28)[:n]
                    ])
                    save_image(comparison.cpu(),
                               'results/reconstruction_' + str(epoch) + '.png',
                               nrow=n)

        test_loss /= len(test_loader.dataset)
        print('====> Test set loss: {:.4f}'.format(test_loss))

    model.train()

    wrapper = iteration_wrapper(exp, sync=torch_settings.sync)

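    # Training loop: encode/decode each batch and minimize reconstruction + KL divergence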
    for it, (data, target) in dataloop(train_loader, wrapper=wrapper):
        it.set_count(len(data))

        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        it.log(loss=loss.item())
        optimizer.step()
Example 10
def main(exp):

    # Algorithm to use: a2c | ppo | acktr
    algorithm: Argument = default("a2c")

    # Gail epochs (default: 5)
    gail_epoch: Argument & int = default(5)

    # Learning rate (default: 7e-4)
    lr: Argument & float = default(7e-4)

    # Directory that contains expert demonstrations for gail
    gail_experts_dir: Argument = default("./gail_experts")

    # Gail batch size (default: 128)
    gail_batch_size: Argument & int = default(128)

    # Do imitation learning with gail
    gail: Argument & bool = default(False)

    # RMSprop optimizer epsilon (default: 1e-5)
    eps: Argument & float = default(1e-5)

    # RMSprop optimizer alpha (default: 0.99)
    alpha: Argument & float = default(0.99)

    # discount factor for rewards (default: 0.99)
    gamma: Argument & float = default(0.99)

    # use generalized advantage estimation
    use_gae: Argument & bool = default(False)

    # gae lambda parameter (default: 0.95)
    gae_lambda: Argument & float = default(0.95)

    # entropy term coefficient (default: 0.01)
    entropy_coef: Argument & float = default(0.01)

    # value loss coefficient (default: 0.5)
    value_loss_coef: Argument & float = default(0.5)

    # max norm of gradients (default: 0.5)
    max_grad_norm: Argument & float = default(0.5)

    # sets flags for determinism when using CUDA (potentially slow!)
    cuda_deterministic: Argument & bool = default(False)

    # how many training CPU processes to use (default: 16)
    num_processes: Argument & int = default(16)

    # number of forward steps in A2C (default: 5)
    num_steps: Argument & int = default(5)

    # number of ppo epochs (default: 4)
    ppo_epoch: Argument & int = default(4)

    # number of batches for ppo (default: 32)
    num_mini_batch: Argument & int = default(32)

    # ppo clip parameter (default: 0.2)
    clip_param: Argument & float = default(0.2)

    # # log interval, one log per n updates (default: 10)
    # log_interval: Argument & int = default(10)

    # # save interval, one save per n updates (default: 100)
    # save_interval: Argument & int = default(100)

    # # eval interval, one eval per n updates (default: None)
    # eval_interval: Argument & int = default(None)

    # number of environment steps to train (default: 10e6)
    num_env_steps: Argument & int = default(10e6)

    # environment to train on (default: PongNoFrameskip-v4)
    env_name: Argument = default('PongNoFrameskip-v4')

    # directory to save agent logs (default: /tmp/gym)
    log_dir: Argument = default(None)

    # directory to save agent logs (default: ./trained_models/)
    save_dir: Argument = default('./trained_models/')

    # compute returns taking into account time limits
    use_proper_time_limits: Argument & bool = default(False)

    # use a recurrent policy
    recurrent_policy: Argument & bool = default(False)

    # use a linear schedule on the learning rate
    use_linear_lr_decay: Argument & bool = default(False)

    # Seed to use
    seed: Argument & int = default(1234)

    # Number of iterations
    iterations: Argument & int = default(10)

    # we compute steps/sec
    batch_size = num_processes

    torch_settings = init_torch()
    device = torch_settings.device

    assert algorithm in ['a2c', 'ppo', 'acktr']

    if recurrent_policy:
        assert algorithm in ['a2c', 'ppo'], \
            'Recurrent policy is not implemented for ACKTR'

    num_updates = int(num_env_steps) // num_steps // num_processes

    envs = make_vec_envs(env_name, seed, num_processes, gamma, log_dir, device,
                         False)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={'recurrent': recurrent_policy})
    actor_critic.to(device)

    if algorithm == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               value_loss_coef,
                               entropy_coef,
                               lr=lr,
                               eps=eps,
                               alpha=alpha,
                               max_grad_norm=max_grad_norm)
    elif algorithm == 'ppo':
        agent = algo.PPO(actor_critic,
                         clip_param,
                         ppo_epoch,
                         num_mini_batch,
                         value_loss_coef,
                         entropy_coef,
                         lr=lr,
                         eps=eps,
                         max_grad_norm=max_grad_norm)
    elif algorithm == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               value_loss_coef,
                               entropy_coef,
                               acktr=True)

    rollouts = RolloutStorage(num_steps, num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)
    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()

    wrapper = iteration_wrapper(exp, sync=torch_settings.sync)

    for it, j in dataloop(count(), wrapper=wrapper):
        it.set_count(batch_size)

        if use_linear_lr_decay:
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if algorithm == "acktr" else lr)

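        # Collect a rollout of num_steps transitions from the vectorized environments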
        for step in range(num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])

            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()
        # ---
        rollouts.compute_returns(next_value, use_gae, gamma, gae_lambda,
                                 use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        it.log(
            value_loss=value_loss,
            action_loss=action_loss,
        )

        rollouts.after_update()

        total_num_steps = (j + 1) * num_processes * num_steps

        # if j % log_interval == 0 and len(episode_rewards) > 1:
        #     end = time.time()
        #     print(
        #         "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n".
        #         format(j, total_num_steps,
        #             int(total_num_steps / (end - start)),
        #             len(episode_rewards),
        #             np.mean(episode_rewards),
        #             np.median(episode_rewards),
        #             np.min(episode_rewards),
        #             np.max(episode_rewards), dist_entropy,
        #             value_loss, action_loss))
    envs.close()