Esempio n. 1
0
 def get_conf_matrix_for_preds(self, preds, gt_labels):
     """Build a 2x2 confusion matrix from thresholded binary predictions.

     Args:
         preds: Prediction scores; values above 0.5 are counted as class 1.
             Presumably a torch tensor (it is passed through `to_numpy`).
         gt_labels: Ground-truth labels exposing `.size(0)`; values are
             rounded to the nearest integer before use.

     Returns:
         Tuple `(conf, gt_label_arr, pred_label_arr)` where
         `conf[gt, pred]` counts samples, `gt_label_arr` is the rounded
         int32 label array (original shape kept) and `pred_label_arr` is
         the flattened int32 prediction array.
     """
     # Threshold the scores and flatten to one prediction per entry.
     pred_mask = to_numpy(preds > 0.5)
     pred_label_arr = pred_mask.astype(np.int32).reshape(-1)

     # Round before casting so that a label stored as e.g. 0.999 becomes 1
     # instead of being truncated to 0.
     gt_values = to_numpy(gt_labels)
     gt_label_arr = np.around(gt_values).astype(np.int32)

     conf = np.zeros((2, 2), dtype=np.int32)
     num_samples = gt_labels.size(0)
     for sample_idx in range(num_samples):
         row = gt_label_arr[sample_idx]
         col = pred_label_arr[sample_idx]
         conf[row, col] += 1
     return conf, gt_label_arr, pred_label_arr
Esempio n. 2
0
def get_returns(tasks, episodes):
    """Compute goal-distance returns for each task's episodes.

    For every `(task, episode)` pair the return is the negated Euclidean
    distance between the observations and the task goal, taken along
    axis 2 and summed over axis 0.

    Args:
        tasks: Indexable collection aligned with `episodes`; the goal is
            read from `tasks[i][1]['goal']`.
        episodes: Iterable of objects whose `observations` attribute
            supports `.cpu()`. Presumably shaped (time, batch, obs_dim),
            given the norm over axis 2 and the sum over axis 0 -- TODO
            confirm against the episode class.

    Returns:
        Whatever `to_numpy` produces for the list of per-task returns.
    """
    per_task_returns = []
    for task_idx, episode in enumerate(episodes):
        observations = np.array(episode.observations.cpu())
        # Add a leading axis so the goal broadcasts against the observations.
        goal = np.expand_dims(tasks[task_idx][1]['goal'], 0)
        distance_to_goal = np.linalg.norm(observations - goal, axis=2)
        per_task_returns.append(-distance_to_goal.sum(0))
    return to_numpy(per_task_returns)
Esempio n. 3
0
    def train(self,
              train=True,
              viz_images=False,
              save_embedding=True,
              use_emb_data=False,
              log_prefix=''):
        """Run the training loop (or a single pass when `train` is False).

        For every epoch the data is split into full batches (a trailing
        partial batch is dropped because the batch count uses floor
        division), each batch is run through the model, and predictions,
        embeddings and confusion matrices are accumulated. While training,
        the model is periodically logged, checkpointed and evaluated on the
        test split every `args.test_freq_iters` steps.

        Args:
            train: When True, iterate `args.num_epochs` epochs in shuffled
                order with `args.batch_size`. When False, do one in-order
                pass with a batch size of 32 and no logging, checkpointing
                or test evaluation.
            viz_images: When True (and training), images are plotted to
                tensorboard every `log_freq_iters` steps.
            save_embedding: Not used by this method.
            use_emb_data: When True, read from the h5 embedding dataset and
                use `run_emb_model_on_batch`; otherwise use the raw data
                and `run_model_on_batch`.
            log_prefix: Not used by this method.

        Returns:
            dict with the keys 'data_info', 'emb', 'output' and 'conf'
            holding the collected embeddings, labels, F1 scores and
            confusion matrices.
        """
        print("Begin training")
        args = self.config.args
        log_freq_iters = args.log_freq_iters if train else 10
        dataloader = self.dataloader
        device = self.config.get_device()
        if use_emb_data:
            train_data_size = dataloader.get_h5_data_size(train)
        else:
            train_data_size = dataloader.get_data_size(train)
        train_data_idx_list = list(range(0, train_data_size))

        # Reset log counter
        train_step_count, test_step_count = 0, 0
        self.set_model_device(device)

        result_dict = {
            'data_info': {
                'path': [],
                'info': [],
            },
            'emb': {
                'train_img_emb': [],
                'test_img_emb': [],
                'train_gt': [],
                'train_pred': [],
                'test_gt': [],
                # Was misspelled 'tese_pred', which left a stray key behind
                # and made 'test_pred' appear only after the first test run.
                'test_pred': [],
            },
            'output': {
                'gt': [],
                'pred': [],
                'test_f1_score': [],
                'test_wt_f1_score': [],
                'test_conf': [],
            },
            'conf': {
                'train': [],
                'test': [],
            }
        }
        num_epochs = args.num_epochs if train else 1

        for e in range(num_epochs):
            if train:
                iter_order = np.random.permutation(train_data_idx_list)
            else:
                iter_order = np.arange(train_data_size)

            batch_size = args.batch_size if train else 32
            # Floor division: a trailing partial batch is not processed.
            num_batches = train_data_size // batch_size
            data_idx = 0

            n_classes = args.classif_num_classes
            result_dict['conf']['train'].append(
                np.zeros((n_classes, n_classes), dtype=np.int32))
            # Per-epoch accumulators are reset at the start of every epoch.
            for k in ['gt', 'pred']:
                result_dict['output'][k] = []
            for k in ['train_img_emb', 'train_gt', 'train_pred']:
                result_dict['emb'][k] = []

            for batch_idx in range(num_batches):
                # Get raw data from the dataloader.
                batch_data = []
                while len(batch_data) < batch_size and data_idx < len(
                        iter_order):
                    actual_data_idx = iter_order[data_idx]
                    if use_emb_data:
                        data = dataloader.get_h5_train_data_at_idx(
                            actual_data_idx, train=train)
                    else:
                        data = dataloader.get_train_data_at_idx(
                            actual_data_idx, train=train)
                    batch_data.append(data)
                    data_idx = data_idx + 1

                # Process raw batch data
                x_dict = self.process_raw_batch_data(batch_data)
                # Now collate the batch data together
                x_tensor_dict = self.collate_batch_data_to_tensors(x_dict)

                model_fn = self.run_emb_model_on_batch \
                    if use_emb_data else self.run_model_on_batch
                batch_result_dict = model_fn(x_tensor_dict,
                                             batch_size,
                                             train=train,
                                             save_preds=True)

                if args.loss_type == 'classif':
                    result_dict['conf']['train'][-1] += batch_result_dict[
                        'conf']

                result_dict['output']['gt'].append(
                    batch_result_dict['gt_label'])
                result_dict['output']['pred'].append(
                    batch_result_dict['pred_label'])
                for b in range(len(batch_data)):
                    result_dict['emb']['train_img_emb'].append(
                        to_numpy(batch_result_dict['img_emb'][b]))
                    result_dict['emb']['train_gt'].append(
                        batch_result_dict['gt_label'][b])
                    result_dict['emb']['train_pred'].append(
                        batch_result_dict['pred_label'][b])

                self.print_train_update_to_console(e, num_epochs, batch_idx,
                                                   num_batches,
                                                   train_step_count,
                                                   batch_result_dict)

                plot_images = viz_images and train \
                    and train_step_count % log_freq_iters == 0
                plot_loss = train \
                    and train_step_count % args.print_freq_iters == 0

                if train:
                    self.plot_train_update_to_tensorboard(
                        x_dict,
                        x_tensor_dict,
                        batch_result_dict,
                        train_step_count,
                        plot_loss=plot_loss,
                        plot_images=plot_images,
                    )

                    if train_step_count % log_freq_iters == 0:
                        self.log_model_to_tensorboard(train_step_count)

                    # Save trained models
                    if train_step_count % args.save_freq_iters == 0:
                        self.save_checkpoint(train_step_count)

                    # Run current model on val/test data.
                    if train_step_count % args.test_freq_iters == 0:
                        # Remove old stuff from memory
                        x_dict = None
                        x_tensor_dict = None
                        batch_result_dict = None
                        torch.cuda.empty_cache()
                        for k in ['test_img_emb', 'test_gt', 'test_pred']:
                            result_dict['emb'][k] = []

                        test_batch_size = args.batch_size
                        if use_emb_data:
                            test_data_size = self.dataloader.get_h5_data_size(
                                train=False)
                        else:
                            test_data_size = self.dataloader.get_data_size(
                                train=False)
                        # Unlike training, the test pass also covers the
                        # trailing partial batch.
                        num_batch_test = test_data_size // test_batch_size
                        if test_data_size % test_batch_size != 0:
                            num_batch_test += 1
                        # Do NOT sort the test data.
                        test_iter_order = np.arange(test_data_size)
                        test_data_idx, total_test_loss = 0, 0

                        all_gt_label_list, all_pred_label_list = [], []

                        self.set_model_to_eval()

                        result_dict['conf']['test'].append(
                            np.zeros((n_classes, n_classes), dtype=np.int32))

                        print(bcolors.c_yellow("==== Test begin ==== "))
                        for test_e in range(num_batch_test):
                            batch_data = []

                            while (len(batch_data) < test_batch_size
                                   and test_data_idx < len(test_iter_order)):
                                if use_emb_data:
                                    data = dataloader.get_h5_train_data_at_idx(
                                        test_iter_order[test_data_idx],
                                        train=False)
                                else:
                                    data = dataloader.get_train_data_at_idx(
                                        test_iter_order[test_data_idx],
                                        train=False)
                                batch_data.append(data)
                                test_data_idx = test_data_idx + 1

                            # Process raw batch data
                            x_dict = self.process_raw_batch_data(batch_data)
                            # Now collate the batch data together
                            x_tensor_dict = self.collate_batch_data_to_tensors(
                                x_dict)
                            with torch.no_grad():
                                model_fn = self.run_emb_model_on_batch \
                                    if use_emb_data else self.run_model_on_batch
                                batch_result_dict = model_fn(x_tensor_dict,
                                                             test_batch_size,
                                                             train=False,
                                                             save_preds=True)
                                total_test_loss += batch_result_dict[
                                    'total_loss']
                                all_gt_label_list.append(
                                    batch_result_dict['gt_label'])
                                all_pred_label_list.append(
                                    batch_result_dict['pred_label'])
                                for b in range(len(batch_data)):
                                    result_dict['emb']['test_img_emb'].append(
                                        to_numpy(
                                            batch_result_dict['img_emb'][b]))

                            # NOTE(review): unlike the training path this is
                            # not guarded by `args.loss_type == 'classif'`;
                            # confirm 'conf' is always returned here.
                            result_dict['conf']['test'][
                                -1] += batch_result_dict['conf']

                            self.print_train_update_to_console(
                                e,
                                num_epochs,
                                test_e,
                                num_batch_test,
                                train_step_count,
                                batch_result_dict,
                                train=False)

                            plot_images = test_e == 0
                            plot_loss = True
                            self.plot_train_update_to_tensorboard(
                                x_dict,
                                x_tensor_dict,
                                batch_result_dict,
                                test_step_count,
                                plot_loss=plot_loss,
                                plot_images=plot_images,
                                log_prefix='/test/')

                            test_step_count += 1
                        # Calculate metrics
                        gt_label = np.concatenate(all_gt_label_list)
                        pred_label = np.concatenate(all_pred_label_list)
                        normal_f1 = f1_score(gt_label, pred_label)
                        wt_f1 = f1_score(gt_label,
                                         pred_label,
                                         average='weighted')
                        self.logger.summary_writer.add_scalar(
                            '/metrics/test/normal_f1', normal_f1,
                            test_step_count)
                        self.logger.summary_writer.add_scalar(
                            '/metrics/test/wt_f1', wt_f1, test_step_count)
                        result_dict['output']['test_f1_score'].append(
                            normal_f1)
                        result_dict['output']['test_wt_f1_score'].append(wt_f1)
                        result_dict['output']['test_conf'].append(
                            result_dict['conf']['test'][-1])
                        result_dict['emb']['test_gt'] = np.copy(gt_label)
                        result_dict['emb']['test_pred'] = np.copy(pred_label)

                        # Plot the total loss on the entire dataset.
                        # Hopefully, this would decrease over time.
                        self.logger.summary_writer.add_scalar(
                            '/test/all_batch_loss/loss_avg',
                            total_test_loss / max(num_batch_test, 1),
                            test_step_count)
                        self.logger.summary_writer.add_scalar(
                            '/test/all_batch_loss/loss', total_test_loss,
                            test_step_count)

                        print(
                            bcolors.c_yellow(
                                "Test:  \t          F1: {:.4f}\n"
                                "       \t       Wt-F1: {:.4f}\n"
                                "       \t        conf:\n{}".format(
                                    normal_f1, wt_f1,
                                    np.array_str(
                                        result_dict['conf']['test'][-1],
                                        precision=0))))
                        print(
                            bcolors.c_yellow(' ==== Test Epoch conf end ===='))

                    # Drop batch references before the next step and make
                    # sure the model is back in train mode (the test pass
                    # above switches it to eval).
                    x_dict = None
                    x_tensor_dict = None
                    batch_result_dict = None
                    torch.cuda.empty_cache()
                    self.set_model_to_train()

                train_step_count += 1
                torch.cuda.empty_cache()

            self.did_end_train_epoch()

            for k in ['gt', 'pred']:
                result_dict['output'][k] = np.concatenate(
                    result_dict['output'][k]).astype(np.int32)

            normal_f1 = f1_score(result_dict['output']['gt'],
                                 result_dict['output']['pred'])
            wt_f1 = f1_score(result_dict['output']['gt'],
                             result_dict['output']['pred'],
                             average='weighted')
            self.logger.summary_writer.add_scalar('/metrics/train/normal_f1',
                                                  normal_f1, train_step_count)
            self.logger.summary_writer.add_scalar('/metrics/train/wt_f1',
                                                  wt_f1, train_step_count)

            if args.loss_type == 'classif':
                print(
                    bcolors.c_red("Train:  \t          F1: {:.4f}\n"
                                  "        \t       Wt-F1: {:.4f}\n"
                                  "        \t        conf:\n{}".format(
                                      normal_f1, wt_f1,
                                      np.array_str(
                                          result_dict['conf']['train'][-1],
                                          precision=0))))
                # Report the test evaluation with the best (max) weighted F1.
                if len(result_dict['output']['test_wt_f1_score']) > 0:
                    max_f1_idx = np.argmax(
                        result_dict['output']['test_wt_f1_score'])
                    print(
                        bcolors.c_cyan(
                            "Max test wt f1:\n"
                            "               \t    F1: {:.4f}\n"
                            "               \t    Wt-F1: {:.4f}\n"
                            "               \t    conf:\n{}".format(
                                result_dict['output']['test_f1_score']
                                [max_f1_idx], result_dict['output']
                                ['test_wt_f1_score'][max_f1_idx],
                                np.array_str(
                                    result_dict['conf']['test'][max_f1_idx],
                                    precision=0))))

                save_emb_data_to_h5(args.result_dir, result_dict)
                print(' ==== Epoch done ====')

        for k in ['train_gt', 'train_pred']:
            result_dict['emb'][k] = np.array(result_dict['emb'][k])

        return result_dict
Esempio n. 4
0
    def get_embeddings_for_pretrained_model(self,
                                            checkpoint_path,
                                            use_train_data=True):
        '''Get embeddings for scenes where there only exists a pair of objects.

        Loads a voxel trainer from `checkpoint_path`, runs every sample of
        the selected split through it in order (batches of
        `args.batch_size`, including a trailing partial batch) and collects
        the embeddings together with their labels and paths.

        Args:
            checkpoint_path: Checkpoint passed to
                `create_voxel_trainer_with_checkpoint`.
            use_train_data: Selects the split, both for the dataset size
                and for fetching each sample.

        Returns:
            dict with two sub-dicts keyed by the running sample index (as
            a string): 'h5' holds 'emb' and 'precond_label', 'pkl' holds
            'path', 'before_img_path' and 'precond_label'.
        '''
        args = self.config.args
        emb_trainer = create_voxel_trainer_with_checkpoint(checkpoint_path,
                                                           cuda=args.cuda)

        emb_result_dict = dict(h5=dict(), pkl=dict())

        dataloader = self.dataloader
        device = self.config.get_device()
        data_size = dataloader.get_data_size(use_train_data)
        emb_trainer.model.to(device)
        # Iterate in dataset order so result indices match data indices.
        iter_order = np.arange(data_size)

        batch_size = args.batch_size
        num_batches = data_size // batch_size
        if data_size % batch_size != 0:
            num_batches += 1
        data_idx, emb_result_idx = 0, 0
        print_freq = max(num_batches // 10, 1)

        for batch_idx in range(num_batches):
            # Get raw data from the dataloader.
            batch_data = []
            while len(batch_data) < batch_size and data_idx < len(iter_order):
                actual_data_idx = iter_order[data_idx]
                # The second argument is the split flag. It used to receive
                # the sample index, which fetched sample 0 from the other
                # split and ignored `use_train_data`.
                data = dataloader.get_train_data_at_idx(
                    actual_data_idx, train=use_train_data)
                batch_data.append(data)
                data_idx = data_idx + 1

            voxel_data = torch.stack([d['voxels']
                                      for d in batch_data]).to(device)
            voxel_emb = emb_trainer.get_embedding_for_data(voxel_data)
            voxel_emb_arr = to_numpy(voxel_emb)

            for b, data in enumerate(batch_data):
                # Each running index must be written exactly once.
                assert emb_result_dict['h5'].get(str(emb_result_idx)) is None
                emb_result_dict['h5'][str(emb_result_idx)] = {
                    'emb': voxel_emb_arr[b],
                    'precond_label': data['precond_label'],
                }
                emb_result_dict['pkl'][str(emb_result_idx)] = {
                    'path': data['path'],
                    'before_img_path': data['before_img_path'],
                    'precond_label': data['precond_label'],
                }
                emb_result_idx += 1

            if batch_idx % print_freq == 0:
                print("Got emb for {}/{}".format(batch_idx, num_batches))

        return emb_result_dict
Esempio n. 5
0
    def get_embeddings_for_pretrained_model_with_multiple_objects(
            self, checkpoint_path, use_train_data=True):
        '''Get embeddings for scenes that contain multiple objects.

        Scenes are read one at a time from the (unshuffled) scene batch
        sampler; all object-pair voxels of a scene are embedded together
        and stored under a running scene index.

        Args:
            checkpoint_path: Checkpoint passed to
                `create_voxel_trainer_with_checkpoint`.
            use_train_data: Selects which split the scene sampler reads.

        Returns:
            dict with two sub-dicts keyed by the running scene index (as a
            string): 'h5' holds 'emb' and 'precond_label', 'pkl' holds
            'path', 'all_object_pair_path' and 'precond_label'.
        '''
        args = self.config.args
        emb_trainer = create_voxel_trainer_with_checkpoint(
            checkpoint_path, cuda=args.cuda)
        emb_result_dict = dict(h5=dict(), pkl=dict())

        target_device = self.config.get_device()
        emb_trainer.model.to(target_device)

        # Restart the scene sampler so scenes come back in a fixed order.
        dataloader = self.dataloader
        dataloader.reset_scene_batch_sampler(train=use_train_data,
                                             shuffle=False)

        # One scene per batch: the embedding array stored below is the whole
        # batch output, i.e. every object pair of that single scene.
        scene_batch_size = 1
        train_data_size = dataloader.number_of_scene_data(use_train_data)
        num_batch_scenes, leftover = divmod(train_data_size, scene_batch_size)
        if leftover > 0:
            num_batch_scenes += 1

        scenes_read = 0
        emb_result_idx = 0
        print_freq = max(num_batch_scenes // 5, 1)

        for batch_idx in range(num_batch_scenes):
            scenes = []
            pair_voxels = []
            # Timing probes; the values are currently not reported anywhere.
            fetch_started = time.time()

            while (len(scenes) < scene_batch_size
                   and scenes_read < train_data_size):
                scene = dataloader.get_next_all_object_pairs_for_scene(
                    use_train_data)
                pair_voxels.extend(scene['all_object_pair_voxels'])
                scenes.append(scene)
                scenes_read += 1

            fetch_finished = time.time()

            voxel_batch = torch.Tensor(pair_voxels).to(target_device)
            voxel_emb_arr = to_numpy(
                emb_trainer.get_embedding_for_data(voxel_batch))

            for scene in scenes:
                result_key = str(emb_result_idx)
                # Each running index must be written exactly once.
                assert emb_result_dict['h5'].get(result_key) is None
                emb_result_dict['h5'][result_key] = {
                    'emb': voxel_emb_arr,
                    'precond_label': scene['precond_label'],
                }
                emb_result_dict['pkl'][result_key] = {
                    'path': scene['scene_path'],
                    'all_object_pair_path': scene['all_object_pair_path'],
                    'precond_label': scene['precond_label'],
                }
                emb_result_idx += 1

            if batch_idx % print_freq == 0:
                print("Got emb for {}/{}".format(batch_idx, num_batch_scenes))

        return emb_result_dict
Esempio n. 6
0
    def run_model_on_batch(self,
                           x_tensor_dict,
                           batch_size,
                           contrastive_x_tensor_dict=None,
                           train=False,
                           save_preds=False):
        """Run one batch through the model and compute its losses.

        Args:
            x_tensor_dict: Batch tensors. Reads 'batch_voxel' and
                'batch_action_list'; additionally 'batch_bb_list' and
                'batch_obj_before_pose_list' when `args.use_bb_in_input`,
                'batch_obj_delta_pose_list' for the 'regr' loss or when
                saving predictions, and 'batch_obj_delta_pose_class_list'
                for the 'classif' loss or when saving predictions.
            batch_size: Not used by this method.
            contrastive_x_tensor_dict: Required when
                `args.use_contrastive_loss` and `train` are both set;
                entries 1 and 2 supply the similar / different samples'
                'batch_voxel' tensors for the triplet loss.
            train: When True, take an optimizer step on the total loss.
            save_preds: When True and not training, also store numpy copies
                of the ground truth and predictions.

        Returns:
            dict with the embeddings, the predicted delta pose, every loss
            term, and (for the 'classif' loss) per-axis confusion matrices
            'conf_x' / 'conf_y' indexed as [true_class, predicted_class].

        Raises:
            ValueError: If `args.loss_type` is neither 'regr' nor 'classif'.
            AssertionError: If contrastive data is required but missing.
        """
        batch_result_dict = {}
        args = self.config.args

        voxel_data = x_tensor_dict['batch_voxel']
        img_emb = self.model.forward_image(voxel_data)
        # TODO: Add hinge loss

        if args.use_bb_in_input:
            img_emb_with_action = torch.cat([
                img_emb,
                x_tensor_dict['batch_bb_list'],
                x_tensor_dict['batch_obj_before_pose_list'],
                x_tensor_dict['batch_action_list'],
            ], dim=1)
        else:
            img_emb_with_action = torch.cat(
                [img_emb, x_tensor_dict['batch_action_list']], dim=1)

        if args.use_contrastive_loss and train:
            assert contrastive_x_tensor_dict is not None, "Contrastive data is None"
            sim_img_emb = self.model.forward_image(
                contrastive_x_tensor_dict[1]['batch_voxel'])
            diff_img_emb = self.model.forward_image(
                contrastive_x_tensor_dict[2]['batch_voxel'])

            triplet_loss = args.weight_contrastive_loss * \
                self.triplet_loss(img_emb, sim_img_emb, diff_img_emb)
        else:
            triplet_loss = 0

        img_action_emb = self.model.forward_image_with_action(
            img_emb_with_action)

        pred_delta_pose = self.model.forward_predict_delta_pose(img_action_emb)

        if args.loss_type == 'regr':
            # First three outputs are scored as position, the rest as angle.
            position_pred_loss = args.weight_pos * self.pose_pred_loss(
                pred_delta_pose[:, :3],
                x_tensor_dict['batch_obj_delta_pose_list'][:, :3])
            angle_pred_loss = args.weight_angle * self.pose_pred_loss(
                pred_delta_pose[:, 3:],
                x_tensor_dict['batch_obj_delta_pose_list'][:, 3:])
        elif args.loss_type == 'classif':
            n_classes = args.classif_num_classes

            class_labels = x_tensor_dict['batch_obj_delta_pose_class_list']
            true_label_x = class_labels[:, 0]
            true_label_y = class_labels[:, 1]
            # The first `n_classes` outputs score x, the remainder score y.
            position_pred_loss_x = args.weight_pos * self.pose_pred_loss(
                pred_delta_pose[:, :n_classes], true_label_x)
            position_pred_loss_y = args.weight_pos * self.pose_pred_loss(
                pred_delta_pose[:, n_classes:], true_label_y)
            position_pred_loss = position_pred_loss_x + position_pred_loss_y

            angle_pred_loss = 0

        else:
            raise ValueError("Invalid loss type: {}".format(args.loss_type))

        # The inverse-model loss is currently disabled; the key is still
        # reported so consumers of the result dict keep working.
        inv_model_loss = 0.

        total_loss = position_pred_loss + angle_pred_loss + triplet_loss

        if train:
            self.opt.zero_grad()
            total_loss.backward()
            self.opt.step()

        batch_result_dict['img_emb'] = img_emb
        batch_result_dict['img_action_emb'] = img_action_emb
        batch_result_dict['pred_delta_pose'] = pred_delta_pose
        batch_result_dict[
            'pose_pred_loss'] = position_pred_loss + angle_pred_loss
        batch_result_dict['position_pred_loss'] = position_pred_loss
        if args.loss_type == 'classif':
            batch_result_dict['position_pred_loss_x'] = position_pred_loss_x
            batch_result_dict['position_pred_loss_y'] = position_pred_loss_y

            # Save conf matrix. Labels and predictions are moved to the host
            # once per tensor instead of calling `.item()` per sample, and
            # `np.add.at` accumulates repeated (true, pred) pairs correctly.
            _, pred_x = torch.max(pred_delta_pose[:, :n_classes], dim=1)
            _, pred_y = torch.max(pred_delta_pose[:, n_classes:], dim=1)
            true_x_arr = true_label_x.detach().cpu().numpy()
            true_y_arr = true_label_y.detach().cpu().numpy()
            pred_x_arr = pred_x.detach().cpu().numpy()
            pred_y_arr = pred_y.detach().cpu().numpy()
            conf_x = np.zeros((n_classes, n_classes), dtype=np.int32)
            conf_y = np.zeros((n_classes, n_classes), dtype=np.int32)
            np.add.at(conf_x, (true_x_arr, pred_x_arr), 1)
            np.add.at(conf_y, (true_y_arr, pred_y_arr), 1)
            batch_result_dict['conf_x'] = conf_x
            batch_result_dict['conf_y'] = conf_y

        batch_result_dict['angle_pred_loss'] = angle_pred_loss
        batch_result_dict['inv_model_loss'] = inv_model_loss
        batch_result_dict['triplet_loss'] = triplet_loss
        batch_result_dict['total_loss'] = total_loss

        if not train and save_preds:
            batch_result_dict['pos_gt'] = \
                to_numpy(x_tensor_dict['batch_obj_delta_pose_list'][:, :3])
            batch_result_dict['pos_class_gt'] = \
                to_numpy(x_tensor_dict['batch_obj_delta_pose_class_list'])
            if args.loss_type == 'regr':
                batch_result_dict['pos_pred'] = to_numpy(
                    pred_delta_pose[:, :3])
            else:
                batch_result_dict['pos_pred'] = to_numpy(pred_delta_pose)

        return batch_result_dict
Esempio n. 7
0
    def step(self,
             train_episodes,
             valid_episodes,
             params,
             max_kl=1e-3,
             cg_iters=10,
             cg_damping=1e-2,
             ls_max_steps=10,
             ls_backtrack_ratio=0.5):
        """Update the policy with a KL-constrained, line-searched step.

        The surrogate loss and KL are averaged over tasks, a step direction
        is obtained with conjugate gradient on the damped Hessian-vector
        product of the KL, and a backtracking line search looks for a step
        size that lowers the loss while keeping the KL below `max_kl`.

        Args:
            train_episodes: Per-task training episodes.
            valid_episodes: Per-task validation episodes.
            params: Per-task parameters forwarded to `surrogate_loss`.
            max_kl: Upper bound on the mean KL accepted by the line search.
            cg_iters: Number of conjugate-gradient iterations.
            cg_damping: Damping passed to `hessian_vector_product`.
            ls_max_steps: Maximum number of line-search attempts.
            ls_backtrack_ratio: Step-size multiplier after a rejected
                attempt.

        Returns:
            dict with 'loss_before' and 'kl_before', plus 'loss_after' and
            'kl_after' when the line search accepted a step.
        """
        num_tasks = len(train_episodes)
        logs = {}

        def _write_policy_params(flat_params):
            # Copy the values unpacked from a flat vector into the policy.
            unpacked = vector_to_parameters(flat_params,
                                            self.policy.parameters())
            for target, source in zip(self.policy.parameters(), unpacked):
                target.data.copy_(source)

        # Surrogate loss, KL and reference policy for every task, evaluated
        # before the update.
        old_losses, old_kls, old_pis = [], [], []
        for task_param, task_train, task_valid in zip(params, train_episodes,
                                                      valid_episodes):
            task_loss, task_kl, task_pi = self.surrogate_loss(task_train,
                                                              task_valid,
                                                              task_param,
                                                              old_pi=None)
            old_losses.append(task_loss)
            old_kls.append(task_kl)
            old_pis.append(task_pi)

        logs['loss_before'] = to_numpy(old_losses)
        logs['kl_before'] = to_numpy(old_kls)

        old_loss = sum(old_losses) / num_tasks
        # The graph is retained because the KL term is differentiated next.
        loss_grads = torch.autograd.grad(old_loss,
                                         self.policy.parameters(),
                                         retain_graph=True)
        flat_grads = parameters_to_vector(loss_grads)

        # Step direction via conjugate gradient.
        old_kl = sum(old_kls) / num_tasks
        hvp = self.hessian_vector_product(old_kl, damping=cg_damping)
        stepdir = conjugate_gradient(hvp, flat_grads, cg_iters=cg_iters)

        # Scale the direction so the quadratic KL estimate equals `max_kl`.
        shs = 0.5 * torch.dot(stepdir, hvp(stepdir, retain_graph=False))
        lagrange_multiplier = torch.sqrt(shs / max_kl)
        full_step = stepdir / lagrange_multiplier

        # Snapshot of the parameters the line search starts from.
        old_params = parameters_to_vector(self.policy.parameters())

        # Backtracking line search.
        step_size = 1.0
        for _ in range(ls_max_steps):
            _write_policy_params(old_params - step_size * full_step)

            losses, kls = [], []
            for task_param, task_train, task_valid, old_pi in zip(
                    params, train_episodes, valid_episodes, old_pis):
                task_loss, task_kl, _ = self.surrogate_loss(task_train,
                                                            task_valid,
                                                            task_param,
                                                            old_pi=old_pi)
                losses.append(task_loss)
                kls.append(task_kl)

            improve = (sum(losses) / num_tasks) - old_loss
            mean_kl = sum(kls) / num_tasks
            loss_decreased = improve.item() < 0.0
            kl_within_bound = mean_kl.item() < max_kl
            if loss_decreased and kl_within_bound:
                logs['loss_after'] = to_numpy(losses)
                logs['kl_after'] = to_numpy(kls)
                break
            step_size *= ls_backtrack_ratio
        else:
            # NOTE(review): when no attempt is accepted the policy ends up
            # at the next, untested backtracked step instead of being
            # restored to `old_params` -- confirm this is intended.
            _write_policy_params(old_params - step_size * full_step)

        return logs