def val_dataloader(self):
    bs = BatchSampler(
        RandomSampler(self.test_buffer, replacement=True,
                      num_samples=self.hparams['num_val_batches'] * self.hparams['batch_size']),
        batch_size=self.hparams['batch_size'],
        drop_last=False)
    return DataLoader(self.test_buffer, batch_sampler=bs)
def train_dataloader(self):
    bs = BatchSampler(
        RandomSampler(self.train_buffer, replacement=True,
                      num_samples=self.hparams['num_grad_steps'] * self.hparams['batch_size']),
        batch_size=self.hparams['batch_size'],
        drop_last=False)
    return DataLoader(self.train_buffer, batch_sampler=bs)
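# A minimal, self-contained sketch of the pattern used by the two loaders
# above (names like `buffer` are illustrative, not from the original code):
# RandomSampler with replacement=True and an explicit num_samples decouples
# the "epoch" length from the dataset size, so each epoch yields exactly
# num_batches batches no matter how large the buffer is.
import torch
from torch.utils.data import (BatchSampler, DataLoader, RandomSampler,
                              TensorDataset)

buffer = TensorDataset(torch.randn(1000, 4))  # stand-in for a replay buffer
num_batches, batch_size = 50, 32
bs = BatchSampler(
    RandomSampler(buffer, replacement=True,
                  num_samples=num_batches * batch_size),
    batch_size=batch_size, drop_last=False)
loader = DataLoader(buffer, batch_sampler=bs)
assert len(list(loader)) == num_batches  # fixed epoch length, not len(buffer) // batch_size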
def mean_importance(model, dataset, loss, batch_size, bar=False):
    '''
    Calculate feature importance by measuring performance reduction when
    features are imputed with their mean value.

    Args:
      model: sklearn model.
      dataset: PyTorch dataset, such as data.utils.TabularDataset.
      loss: string descriptor of loss function ('mse', 'cross entropy').
      batch_size: number of examples to be processed at once.
      bar: whether to display progress bar.
    '''
    # Add wrapper if necessary.
    if isinstance(model, sklearn.base.ClassifierMixin):
        model = SklearnClassifierWrapper(model)

    # Setup.
    input_size = dataset.input_size
    loader = DataLoader(dataset, batch_sampler=BatchSampler(
        SequentialSampler(dataset), batch_size=batch_size, drop_last=False))
    loss_fn = utils.get_loss_np(loss, reduction='none')
    scores = []

    # Performance with all features.
    base_loss = validate_sklearn(model, loader,
                                 utils.get_loss_np(loss, reduction='mean'))

    # For imputing with mean.
    imputation = utils.ReferenceImputation(
        torch.mean(torch.tensor(dataset.data), dim=0))

    if bar:
        bar = tqdm(total=len(dataset) * input_size)

    for ind in range(input_size):
        # Running average of the loss over the dataset.
        score = 0
        N = 0

        for x, y in loader:
            # Impute with mean and make predictions.
            n = len(x)
            y_hat = model.predict(
                imputation.impute_ind(x, ind).cpu().data.numpy())

            # Measure loss and update the running average. (Renamed from
            # `loss` to avoid shadowing the string argument of that name.)
            batch_loss = np.mean(loss_fn(y_hat, y.cpu().data.numpy()))
            score = (score * N + batch_loss * n) / (N + n)
            N += n

            if bar:
                bar.update(n)

        scores.append(score)

    return np.stack(scores) - base_loss
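# A self-contained sketch of the same idea with plain numpy + sklearn
# (no dependence on the utils helpers above): the importance of feature j
# is the loss with feature j mean-imputed minus the loss with all features.
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = 2.0 * X[:, 0] + 0.1 * X[:, 1] + rng.normal(scale=0.1, size=200)
model = LinearRegression().fit(X, y)

base = np.mean((model.predict(X) - y) ** 2)
scores = []
for j in range(X.shape[1]):
    Xj = X.copy()
    Xj[:, j] = X[:, j].mean()  # impute feature j with its mean
    scores.append(np.mean((model.predict(Xj) - y) ** 2) - base)
# scores[0] >> scores[1] > scores[2]: feature 0 matters far more than the others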
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
             batch_sampler=None, num_workers=0, collate_fn=default_collate,
             pin_memory=False, drop_last=False, timeout=0,
             worker_init_fn=None, multiton_id=0):
    self.dataset = dataset
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.collate_fn = collate_fn
    self.pin_memory = pin_memory
    self.drop_last = drop_last
    self.timeout = timeout
    self.worker_init_fn = worker_init_fn
    self.multiton_id = multiton_id

    if timeout < 0:
        raise ValueError('timeout option should be non-negative')

    if batch_sampler is not None:
        if batch_size > 1 or shuffle or sampler is not None or drop_last:
            raise ValueError('batch_sampler option is mutually exclusive '
                             'with batch_size, shuffle, sampler, and '
                             'drop_last')
        self.batch_size = None
        self.drop_last = None

    if sampler is not None and shuffle:
        raise ValueError('sampler option is mutually exclusive with '
                         'shuffle')

    if self.num_workers < 0:
        raise ValueError('num_workers option cannot be negative; '
                         'use num_workers=0 to disable multiprocessing.')

    if batch_sampler is None:
        if sampler is None:
            if shuffle:
                sampler = RandomSampler(dataset)
            else:
                sampler = SequentialSampler(dataset)
        batch_sampler = BatchSampler(sampler, batch_size, drop_last)

    self.sampler = sampler
    self.batch_sampler = batch_sampler
    self.dataiter = _DataLoaderIter(self.num_workers, self.pin_memory,
                                    self.timeout, self.worker_init_fn,
                                    self.multiton_id)
    self.__initialized = True
def init_loaders(self, sample_keys=None):
    """
    Convert data and labels to instances of torch.utils.data.DataLoader.

    Returns:
        A dictionary with the same keys as data_dict and label_dict. Each
        element is an instance of torch.utils.data.DataLoader that yields
        paired elements of data and labels.
    """
    # Convert the data to Datasets.
    dataset_dict = self.init_datasets()

    # If the Dataset implements collate_fn, it is used; otherwise,
    # default_collate is used.
    if hasattr(dataset_dict["train"], "collate_fn") and callable(
            getattr(dataset_dict["train"], "collate_fn")):
        collate_fn = dataset_dict["train"].collate_fn
    else:
        collate_fn = default_collate

    # If 'iters_per_epoch' is defined, a fixed number of random sample
    # batches from the training set are drawn per epoch. Otherwise, an
    # epoch is a full pass through all of the data in the dataloader.
    if self.config_dict.get("iters_per_epoch") is not None:
        num_samples = (self.config_dict["iters_per_epoch"] *
                       self.config_dict["batch_size"])
        if sample_keys is None:
            sample_keys = ["train"]
    else:
        if sample_keys is None:
            sample_keys = []

    loaders_dict = {}
    for key in dataset_dict.keys():
        if key in sample_keys:
            loaders_dict[key] = DataLoader(
                dataset_dict[key],
                batch_sampler=BatchSampler(
                    RandomSampler(dataset_dict[key],
                                  replacement=True,
                                  num_samples=num_samples),
                    batch_size=self.config_dict["batch_size"],
                    drop_last=False,
                ),
                collate_fn=collate_fn,
                num_workers=self.num_workers,
                pin_memory=True,
            )
        else:
            loaders_dict[key] = DataLoader(
                dataset_dict[key],
                batch_size=self.config_dict["batch_size"],
                collate_fn=collate_fn,
                num_workers=self.num_workers,
                pin_memory=True,
            )

    return loaders_dict
def init_dataloader(dataset, batch_size=32, random=True):
    sampler = RandomSampler(dataset) if random else SequentialSampler(dataset)
    batch_sampler = BatchSampler(sampler, batch_size=batch_size,
                                 drop_last=False)
    loader = DataLoader(dataset, batch_sampler=batch_sampler,
                        collate_fn=dataset.collater, num_workers=4)
    return loader
def build_dataloader(dataset: Dataset, batch_size: int,
                     sequential: bool = True) -> DataLoader:
    if sequential:
        sampler = SequentialSampler(dataset)
    else:
        sampler = RandomSampler(dataset, replacement=True)
    batch_sampler = BatchSampler(sampler, batch_size, drop_last=False)
    dataloader = DataLoader(dataset, batch_sampler=batch_sampler)
    return dataloader
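# For reference, BatchSampler only groups indices from an underlying sampler
# into lists; it never touches the data itself. A quick check:
from torch.utils.data import BatchSampler, SequentialSampler

assert list(BatchSampler(SequentialSampler(range(10)), batch_size=4,
                         drop_last=False)) == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]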
def __call__(self,
             inputs: Union[List[List[Union[str, int]]],
                           List[Tuple[List[Union[str, int]], Any]]],
             *args):
    dataset = MapStyleDataset(inputs, self.min_seq_len, self.max_seq_len)
    batch_sampler = BatchSampler(self.sampler(dataset),
                                 batch_size=self.batch_size,
                                 drop_last=self.drop_last)
    loader = DataLoader(dataset, batch_sampler=batch_sampler,
                        collate_fn=dataset.batch_sequences)
    return loader
def __iter__(self):
    if self.generator is None:
        self.generator = torch.Generator().manual_seed(self.initial_seed)
    self.sample_task_indices()
    self.create_sampler()
    self.batch_sampler = BatchSampler(self.sampler,
                                      batch_size=self.batch_size,
                                      drop_last=self.drop_last)
    self.iter_sampler = iter(self.batch_sampler)
    self.is_first_batch = True
    self.current_task_iteration += 1
    return self
def __iter__(self):
    smp = RandomSampler(torch.arange(self.n_chunks))
    for top in BatchSampler(smp, self.batch_size, False):  # don't drop last!
        offsets = torch.randint(0, self.remainder, (self.batch_size,))
        top = tuple(o + (t * self.chunk_length) for t, o in zip(top, offsets))
        for start in range(self.n_per_chunk):  # start indices of the batch
            yield tuple(t + (start * self.seq_len) for t in top)
def __init__(self, dataset, batch_size, negative_sampling=False,
             num_sampling_users=0, num_workers=0, collate_fn=None):
    self.dataset = dataset  # type: RecommendationDataset
    self.num_sampling_users = num_sampling_users
    self.num_workers = num_workers
    self.batch_size = batch_size
    self.negative_sampling = negative_sampling

    if self.num_sampling_users == 0:
        self.num_sampling_users = batch_size

    assert self.num_sampling_users >= batch_size, (
        'num_sampling_users should be at least equal to the batch_size')

    self.batch_collator = BatchCollator(
        batch_size=self.batch_size,
        negative_sampling=self.negative_sampling)

    # Wrapping a BatchSampler within a BatchSampler in order to fetch the
    # whole mini-batch at once from the dataset instead of fetching each
    # sample on its own.
    batch_sampler = BatchSampler(
        BatchSampler(RandomSampler(dataset),
                     batch_size=self.num_sampling_users,
                     drop_last=False),
        batch_size=1, drop_last=False)

    if collate_fn is None:
        self._collate_fn = self.batch_collator.collate
        self._use_default_data_generator = True
    else:
        self._collate_fn = collate_fn
        self._use_default_data_generator = False

    self._dataloader = DataLoader(dataset, batch_sampler=batch_sampler,
                                  num_workers=num_workers,
                                  collate_fn=self._collate_fn)
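# Sketch of what the nested BatchSampler above yields: the outer sampler
# (batch_size=1) wraps each inner index batch in a singleton list, so the
# Dataset's __getitem__ receives the whole mini-batch of indices in one call.
from torch.utils.data import BatchSampler, SequentialSampler

inner = BatchSampler(SequentialSampler(range(7)), batch_size=3, drop_last=False)
outer = BatchSampler(inner, batch_size=1, drop_last=False)
assert list(outer) == [[[0, 1, 2]], [[3, 4, 5]], [[6]]]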
def test_batch_up_image_loader_with_batch_sampler(image):
    batch_size = 3
    dataset = ()
    batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size,
                                 drop_last=False)
    loader = DataLoader(dataset, batch_sampler=batch_sampler)

    batched_up_image = utils.batch_up_image(image, loader=loader)
    assert extract_batch_size(batched_up_image) == batch_size
def __init__(self, samplers, batch_size):
    for sampler in samplers:
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got "
                             "sampler={}".format(sampler))
    if (not isinstance(batch_size, _int_classes)
            or isinstance(batch_size, bool) or batch_size <= 0):
        raise ValueError("batch_size should be a positive integer value, "
                         "but got batch_size={}".format(batch_size))
    self.samplers = samplers
    self.batch_size = batch_size
    self.batch_samplers = [BatchSampler(sampler, self.batch_size, True)
                           for sampler in self.samplers]
def get_val_dataloader(source='scientsbank', origin='answer'):
    dataset = ds.SemEvalDataset('data/flat_semeval5way_test.csv')
    dataset.to_val_mode(source, origin)
    sampler = SequentialSampler(dataset)
    batch_sampler = BatchSampler(sampler, batch_size=1, drop_last=False)
    loader = DataLoader(dataset, batch_sampler=batch_sampler,
                        collate_fn=dataset.collater, num_workers=4)
    return loader
def __init__(self, sampler, batch_size, shuffle_batches=False,
             drop_incomplete=False):
    batches = list(BatchSampler(sampler, batch_size, drop_incomplete))
    super().__init__(batches, shuffle_batches)
def load_data(pkl_paths, use_attr, no_img, batch_size, uncertain_label=False,
              n_class_attr=2, image_dir='images', resampling=False, resol=299):
    """
    Load data with transformations applied, and upsample the minority class
    if there is class imbalance and weighted loss is not used.

    Note: Inception needs (299, 299, 3) images with inputs scaled between
    -1 and 1.
    NOTE: resampling is customized for the first attribute only, so change
    sampler.py if necessary.
    """
    resized_resol = int(resol * 256 / 224)
    is_training = any(['train.pkl' in f for f in pkl_paths])

    if is_training:
        transform = transforms.Compose([
            #transforms.Resize((resized_resol, resized_resol)),
            #transforms.RandomSizedCrop(resol),
            transforms.ColorJitter(brightness=32 / 255, saturation=(0.5, 1.5)),
            transforms.RandomResizedCrop(resol),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),  # implicitly divides by 255
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[2, 2, 2])
            #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    else:
        transform = transforms.Compose([
            #transforms.Resize((resized_resol, resized_resol)),
            transforms.CenterCrop(resol),
            transforms.ToTensor(),  # implicitly divides by 255
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[2, 2, 2])
            #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    dataset = CUBDataset(pkl_paths, use_attr, no_img, uncertain_label,
                         image_dir, n_class_attr, transform)
    if is_training:
        drop_last = True
        shuffle = True
    else:
        drop_last = False
        shuffle = False

    if resampling:
        sampler = BatchSampler(ImbalancedDatasetSampler(dataset),
                               batch_size=batch_size, drop_last=drop_last)
        loader = DataLoader(dataset, batch_sampler=sampler)
    else:
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                            drop_last=drop_last)
    return loader
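# ImbalancedDatasetSampler is a third-party helper. A hedged sketch of the
# same resampling idea using only torch's built-in WeightedRandomSampler
# (the toy labels and sizes are illustrative): each sample is drawn with
# probability inversely proportional to its class frequency.
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

labels = torch.tensor([0] * 90 + [1] * 10)  # 9:1 class imbalance
data = torch.randn(100, 4)
class_counts = torch.bincount(labels)
weights = 1.0 / class_counts[labels].float()  # inverse-frequency per sample
sampler = WeightedRandomSampler(weights, num_samples=len(labels),
                                replacement=True)
loader = DataLoader(TensorDataset(data, labels), batch_size=16,
                    sampler=sampler)
# batches drawn from `loader` are now roughly class-balanced in expectation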
def unsupervised_dataloaders(seed=123, val_portion=0.1, test_portion=0.1,
                             device=None, mean_adjustment=False,
                             normalization=False):
    '''
    Create dataloaders with no label information (for self-reconstruction).

    Args:
      seed: for shuffling dataset.
      val_portion: portion for validation.
      test_portion: portion for test.
      mean_adjustment: whether to adjust x for mean.
      normalization: whether to adjust x for standard deviation.
      device: device for torch.Tensors.
    '''
    # Load data.
    data = sio.loadmat('Mouse-V1-ALM-20180520_thr_7-5.mat')
    lopge = data['lOPGE']

    # Split data.
    train, val, test = split_data(lopge, seed=seed, val_portion=val_portion,
                                  test_portion=test_portion)

    # Calculate mean and std from training data.
    train_mean = np.mean(train, axis=0)
    train_std = np.std(train, axis=0)
    mean = train_mean if mean_adjustment else 0
    std = train_std if normalization else 1
    train_mean = torch.tensor(train_mean, device=device, dtype=torch.float32)
    train_std = torch.tensor(train_std, device=device, dtype=torch.float32)

    # Calculate total variance.
    full_mean = np.mean(lopge, axis=0, keepdims=True)
    total_variance = np.mean(np.sum((lopge - full_mean) ** 2, axis=1))

    # Create datasets.
    train_set = RNASeq(train, mean=mean, std=std, device=device)
    val_set = RNASeq(val, mean=mean, std=std, device=device)
    test_set = RNASeq(test, mean=mean, std=std, device=device)

    # Create data loaders.
    random_sampler = RandomSampler(train_set, replacement=True)
    batch_sampler = BatchSampler(random_sampler, batch_size=512,
                                 drop_last=True)
    train_loader = DataLoader(train_set, batch_sampler=batch_sampler)
    val_loader = DataLoader(val_set, batch_size=len(val))
    test_loader = DataLoader(test_set, batch_size=len(test))

    return (train_loader, val_loader, test_loader, train_mean, train_std,
            total_variance)
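# A minimal sketch of the standardization convention used above (toy data,
# illustrative names): the mean and std come from the training split only
# and are then applied unchanged to the held-out splits, avoiding leakage.
import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(loc=5.0, scale=2.0, size=(100, 4))
train, val = data[:80], data[80:]

mean, std = train.mean(axis=0), train.std(axis=0)
train_n = (train - mean) / std
val_n = (val - mean) / std  # same train-derived statistics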
def eval_model(self, gen, batch_size):
    n = len(gen.playids)
    batches = BatchSampler(SequentialSampler(range(n)),
                           batch_size=batch_size, drop_last=False)
    self.eval()
    total_loss = 0
    for batch in batches:
        xbatch, ybatch = gen.get_features(batch)
        loss = self.compute_loss(xbatch, ybatch)
        total_loss += len(batch) * loss
    return total_loss / n
def train_epochs(self, nepochs, gen):
    for _ in range(nepochs):
        batches = BatchSampler(RandomSampler(range(len(gen.playids))),
                               batch_size=self.batch_size, drop_last=False)
        for batch in batches:
            self.train()
            xbatch, ybatch = gen.get_features(batch)
            loss = self.compute_loss(xbatch, ybatch)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
def __init__(self, sampler, batch_size, drop_last, sort_key,
             bucket_size_multiplier=100):
    super().__init__(sampler, batch_size, drop_last)
    self.sort_key = sort_key
    self.bucket_sampler = BatchSampler(
        sampler, min(batch_size * bucket_size_multiplier, len(sampler)),
        False)
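# A hedged, self-contained sketch of the bucketing idea behind the sampler
# above (the helper name and numbers are illustrative, not the original
# implementation): draw a large "bucket" of indices, sort the bucket by
# length, then emit batches of similar-length items so padding inside each
# batch is minimized. Note BatchSampler happily batches any iterable of
# indices here, though its annotated type is Sampler.
import random
from torch.utils.data import BatchSampler, RandomSampler

def bucket_batches(lengths, batch_size, bucket_size_multiplier=4):
    bucket_size = min(batch_size * bucket_size_multiplier, len(lengths))
    for bucket in BatchSampler(RandomSampler(range(len(lengths))),
                               bucket_size, drop_last=False):
        bucket.sort(key=lambda i: lengths[i])
        batches = list(BatchSampler(bucket, batch_size, drop_last=False))
        random.shuffle(batches)  # don't always emit the shortest batch first
        yield from batches

lengths = [random.randint(1, 50) for _ in range(100)]
batches = list(bucket_batches(lengths, batch_size=8))
# each batch now groups items of similar length, cutting padding waste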
def test_index_batch_sampler(tmpdir):
    """Test `IndexBatchSampler` properly extracts indices."""
    dataset = range(15)
    sampler = SequentialSampler(dataset)
    batch_sampler = BatchSampler(sampler, 3, False)
    index_batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

    assert batch_sampler.batch_size == index_batch_sampler.batch_size
    assert batch_sampler.drop_last == index_batch_sampler.drop_last
    assert batch_sampler.sampler is sampler
    assert list(index_batch_sampler) == index_batch_sampler.seen_batch_indices
def __init__(self, dataset, batch_size):
    super().__init__(None)
    self.epoch = 0
    self.dataset = dataset
    self.batch_size = batch_size
    # Assume that data in the dataset is sorted w.r.t. duration.
    self.batches = list(
        BatchSampler(SequentialSampler(self.dataset),
                     batch_size=self.batch_size, drop_last=False))
def train_valid_loader(self, mnist_dataset):
    batch_size = 100
    train_indices, valid_indices = (
        mnist_dataset.get_train_and_validation_set_indices(
            train_valid_split_ratio=0.8, seed=2))

    train_loader = DatasetLoader(
        mnist_dataset,
        batch_sampler=BatchSampler(
            sampler=SubsetRandomSampler(train_indices),
            batch_size=batch_size, drop_last=False),
        collate_fn=DatasetLoader.square_matrix_collate_fn)
    valid_loader = DatasetLoader(
        mnist_dataset,
        batch_sampler=BatchSampler(
            sampler=SubsetRandomSampler(valid_indices),
            batch_size=batch_size, drop_last=False),
        collate_fn=DatasetLoader.square_matrix_collate_fn)

    yield train_loader, valid_loader
def __init__(self, X, Y, kern, minibatch_size=None, n_filters=256,
             name: str = None):
    super(ConvNet, self).__init__(name=name)
    if not hasattr(kern, 'W_'):
        # Create W_ and b_ as attributes in the kernel.
        _ = kern.equivalent_BNN(
            X=torch.zeros([1] + kern.input_shape),
            n_samples=1,
            n_filters=n_filters)
    self._kern = kern

    # Make minibatches if necessary. (These are batches of *indices*;
    # stacking them into a rectangular Tensor assumes the dataset size is
    # divisible by minibatch_size.)
    if minibatch_size is None:
        self.train_inputs = X
        self.train_targets = Y
        self.scale_factor = 1.
    else:
        self.train_inputs = torch.Tensor(list(BatchSampler(
            SequentialSampler(X), batch_size=minibatch_size,
            drop_last=False)))
        self.train_targets = torch.Tensor(list(BatchSampler(
            SequentialSampler(Y), batch_size=minibatch_size,
            drop_last=False)))
        self.scale_factor = X.shape[0] / minibatch_size
    self.n_labels = int(np.max(Y) + 1)

    # Create parameters with the relevant sizes of the network.
    Ws, bs = [], []
    for i, (W, b) in enumerate(zip(kern._W, kern._b)):
        if i == kern.n_layers:
            W_shape = [int(W.shape[1]), self.n_labels]
            b_shape = [self.n_labels]
        else:
            W_shape = list(map(int, W.shape[1:]))
            b_shape = [n_filters]
        W_var = kern.var_weight.read_value() / W_shape[-2]
        b_var = kern.var_bias.read_value()
        W_init = np.sqrt(W_var) * np.random.randn(*W_shape)
        b_init = np.sqrt(b_var) * np.random.randn(*b_shape)
        Ws.append(W_init)  # prior=ZeroMeanGauss(W_var)
        bs.append(b_init)  # prior=ZeroMeanGauss(b_var)

    # torch.nn.Parameter requires a Tensor, so the per-layer arrays are
    # registered as ParameterLists rather than single parameters.
    self.Ws = torch.nn.ParameterList(
        [torch.nn.Parameter(torch.as_tensor(W, dtype=torch.float32))
         for W in Ws])
    self.bs = torch.nn.ParameterList(
        [torch.nn.Parameter(torch.as_tensor(b, dtype=torch.float32))
         for b in bs])
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
             batch_sampler=None, num_workers=0, collate_fn=default_collate,
             pin_memory=False, drop_last=False, timeout=0,
             worker_init_fn=None):
    self.dataset = dataset
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.collate_fn = collate_fn
    self.pin_memory = pin_memory
    self.drop_last = drop_last
    self.timeout = timeout
    self.worker_init_fn = worker_init_fn

    if timeout < 0:
        raise ValueError("timeout option should be non-negative")

    if batch_sampler is not None:
        if batch_size > 1 or shuffle or sampler is not None or drop_last:
            raise ValueError("batch_sampler option is mutually exclusive "
                             "with batch_size, shuffle, sampler, and "
                             "drop_last")
        self.batch_size = None
        self.drop_last = None

    if sampler is not None and shuffle:
        raise ValueError("sampler option is mutually exclusive with "
                         "shuffle")

    if self.num_workers < 0:
        raise ValueError("num_workers option cannot be negative; "
                         "use num_workers=0 to disable multiprocessing.")

    if batch_sampler is None:
        if sampler is None:
            if shuffle:
                sampler = RandomSampler(dataset)
            else:
                sampler = SequentialSampler(dataset)
        batch_sampler = BatchSampler(sampler, batch_size, drop_last)

    self.sampler = sampler
    self.batch_sampler = batch_sampler
    self.__initialized = True
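# The fallthrough at the end of the __init__ above means these two loaders
# are equivalent: the convenience arguments are just sugar for an explicit
# RandomSampler wrapped in a BatchSampler.
import torch
from torch.utils.data import (BatchSampler, DataLoader, RandomSampler,
                              TensorDataset)

ds = TensorDataset(torch.arange(10.0))
sugar = DataLoader(ds, batch_size=4, shuffle=True, drop_last=True)
explicit = DataLoader(ds, batch_sampler=BatchSampler(RandomSampler(ds),
                                                     batch_size=4,
                                                     drop_last=True))
assert len(list(sugar)) == len(list(explicit)) == 2  # 10 // 4 batches each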
def __init__(self, config, lmdb_path, train=True):
    self.config = config
    transform = transforms.Compose([transforms.ToTensor()])

    augmentation_methods = []
    if train:
        if self.config.noise_augmentation:
            augmentation_methods.append(augmentation.add_noise)
        if self.config.contrast_augmentation:
            augmentation_methods.append(augmentation.change_contrast)

    if self.config.pose_mean is not None:
        pose_label_transform = self.normalize_pose_labels
    else:
        pose_label_transform = None

    self._dataset = LMDB(config, lmdb_path, transform,
                         pose_label_transform, augmentation_methods)

    if not train and config.distributed:
        batch_size = 1
    else:
        batch_size = config.batch_size

    if config.distributed:
        self._sampler = DistributedSampler(self._dataset)
        if train:
            self._sampler = BatchSampler(self._sampler, batch_size,
                                         drop_last=True)
        super(LMDBDataLoader, self).__init__(
            self._dataset,
            batch_sampler=self._sampler,
            pin_memory=config.pin_memory,
            num_workers=config.workers,
            collate_fn=collate_fn,
        )
    else:
        super(LMDBDataLoader, self).__init__(
            self._dataset,
            batch_size=batch_size,
            shuffle=train,
            pin_memory=config.pin_memory,
            num_workers=config.workers,
            drop_last=True,
            collate_fn=collate_fn,
        )
def fit(self, zeta, xu, nb_iter=100, batch_size=None, lr=1e-3):
    self.optim = Adam(self.parameters(), lr=lr)

    batch_size = xu.shape[0] if batch_size is None else batch_size
    batches = list(BatchSampler(SubsetRandomSampler(range(xu.shape[0])),
                                batch_size, False))

    for n in range(nb_iter):
        for batch in batches:
            self.optim.zero_grad()
            loss = -self.elbo(zeta[batch], xu[batch])
            loss.backward()
            self.optim.step()
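# A self-contained toy version of the same minibatching pattern (the
# least-squares problem and all names here are illustrative): index batches
# from BatchSampler(SubsetRandomSampler(...)) drive Adam updates directly,
# with no Dataset or DataLoader involved.
import torch
from torch.optim import Adam
from torch.utils.data import BatchSampler, SubsetRandomSampler

X = torch.randn(256, 3)
w_true = torch.tensor([1.0, -2.0, 0.5])
y = X @ w_true + 0.01 * torch.randn(256)

w = torch.zeros(3, requires_grad=True)
optim = Adam([w], lr=1e-2)
for epoch in range(200):
    for batch in BatchSampler(SubsetRandomSampler(range(X.shape[0])),
                              batch_size=32, drop_last=False):
        optim.zero_grad()
        loss = ((X[batch] @ w - y[batch]) ** 2).mean()
        loss.backward()
        optim.step()
# w now approximates w_true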
def return_dataloaders(self, batch_size, num_workers=0):
    # Perhaps rewrite this using a return_dataloader method.
    from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, RandomSampler

    if self.reweighter:
        N_targets = len(self.targets.split(', '))

        def collate(batch):
            # `batch` is a one-element list holding the whole index batch,
            # because a BatchSampler is passed as `sampler` below.
            batch = Batch.from_data_list(batch[0])
            batch.weight = self.reweighter(batch)  # torch.tensor(self.reweighter.predict_weights(batch.y.view(-1, N_targets))).view(-1, 1)
            return batch
    else:
        def collate(batch):
            return Batch.from_data_list(batch[0])

    train_loader = DataLoader(dataset=self.train(),
                              collate_fn=collate,
                              num_workers=num_workers,
                              # persistent_workers=True,
                              pin_memory=True,
                              sampler=BatchSampler(RandomSampler(self.train()),
                                                   batch_size=batch_size,
                                                   drop_last=False))
    test_loader = DataLoader(dataset=self.test(extra_targets=', energy_log10'),
                             collate_fn=collate,
                             num_workers=num_workers,
                             # persistent_workers=True,
                             pin_memory=True,
                             sampler=BatchSampler(
                                 SequentialSampler(self.test(extra_targets=', energy_log10')),
                                 batch_size=batch_size,
                                 drop_last=False))
    val_loader = DataLoader(dataset=self.val(),
                            collate_fn=collate,
                            num_workers=num_workers,
                            # persistent_workers=True,
                            pin_memory=True,
                            sampler=BatchSampler(RandomSampler(self.val()),
                                                 batch_size=batch_size,
                                                 drop_last=False))
    return train_loader, test_loader, val_loader
def train(self, memory: Memory):
    # Train over training data.
    obs, action, reward, done = memory.get('obs', 'action', 'reward', 'done')
    old_mean, old_logstd, old_value = memory.get('mean', 'logstd', 'value')

    # Compute advantages.
    advantage = generalized_advantage_estimation(reward, old_value, done)
    advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)

    # Compute target values.
    target_value = discounted_reward(reward, done)

    policy_losses = torch.zeros(self.NUM_TRAIN_EPOCHS)
    entropies = torch.zeros(self.NUM_TRAIN_EPOCHS)
    value_losses = torch.zeros(self.NUM_TRAIN_EPOCHS)

    for i_epoch in range(self.NUM_TRAIN_EPOCHS):
        for indices in BatchSampler(SubsetRandomSampler(range(obs.shape[0])),
                                    self.TRAIN_BATCH_SIZE, drop_last=True):
            new_mean, new_logstd, new_value = self.model(obs[indices])
            policy_loss, entropy = ppo_loss(old_mean[indices],
                                            old_logstd[indices],
                                            new_mean, new_logstd,
                                            action[indices],
                                            advantage[indices])
            value_loss = (target_value[indices] - new_value).pow(2).mean()
            loss = policy_loss - self.ENTROPY_COEF * entropy + value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            policy_losses[i_epoch] += policy_loss
            entropies[i_epoch] += entropy
            value_losses[i_epoch] += value_loss

    print(f'{self.world_rank} {self.rank} Rewards {reward.sum().item()}',
          flush=True)
    print(f'{self.world_rank} {self.rank} Policy {policy_losses.detach().numpy()}',
          flush=True)
    print(f'{self.world_rank} {self.rank} Value {value_losses.detach().numpy()}',
          flush=True)
    print(f'{self.world_rank} {self.rank} Entropy {entropies.detach().numpy()}',
          flush=True)
def train(self, batch_size=6):
    # Unroll rewards into discounted returns.
    rewards = np.array(self.rewards)
    reward = 0
    for i in reversed(range(len(self.rewards))):
        rewards[i] += self.gamma * reward
        reward = rewards[i]

    states = torch.tensor(self.states, dtype=torch.float)
    actions = torch.tensor(self.actions, dtype=torch.long).view(-1, 1)
    rewards = rewards.reshape(-1, 1)

    losses = []
    entropies = []

    # Deprecated (LR is now fixed within the net).
    if self._adapt_lr_on_ep_len:
        self.episode_lengths.append(len(self.rewards))
        # Calculate LR from the average episode length.
        avg_length_window = np.mean(self.episode_lengths[-100:])
        exp = -0.02 * avg_length_window - 2
        learning_rate = 10 ** exp
    else:
        learning_rate = self.lr

    for batch in BatchSampler(SubsetRandomSampler(range(len(self.states))),
                              batch_size, drop_last=False):
        states_batch = states[batch].numpy()
        actions_batch = actions[batch].numpy()
        rewards_batch = rewards[batch]

        loss = []
        entropy = []
        for state, action, reward in zip(states_batch, actions_batch,
                                         rewards_batch):
            state = state.reshape((-1, 1))
            probs = self.net.forward(state)
            probs = probs.squeeze() + 1e-8
            entropy.append(-np.sum(np.log(probs) * probs))
            action_prob = probs[action]
            loss.append(-(np.log(action_prob) * reward))
            dLoss = 1 / action_prob * reward
            self.net.backward(dLoss, action)

        losses.append(np.mean(loss))
        entropies.append(np.mean(entropy))
        self.net.mean_grads(batch_size)
        self.net.update()

    self._clear_buffers()
    return np.mean(losses), np.mean(entropies), self.lr
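# A worked check of the reward unrolling at the top of train(): with
# gamma = 0.5 and rewards [1, 0, 2], the discounted returns are
# [1 + 0.5 * (0 + 0.5 * 2), 0 + 0.5 * 2, 2] = [1.5, 1.0, 2.0].
import numpy as np

def discounted_returns(rewards, gamma):
    out = np.array(rewards, dtype=float)
    running = 0.0
    for i in reversed(range(len(out))):
        out[i] += gamma * running
        running = out[i]
    return out

assert np.allclose(discounted_returns([1, 0, 2], 0.5), [1.5, 1.0, 2.0])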