def train(self, batch_num):
    """Run ``batch_num`` training batches with gradient accumulation.

    For every batch: fetch it with ``self.get_next_batch``, attach an empty
    ``Storage`` as ``incoming.args`` and the current epoch as
    ``incoming.now_epoch``, call ``self.net.forward(incoming)``, read the
    loss from ``incoming.result.loss``, report it through
    ``self.trainSummary`` and ``logging``, and back-propagate.

    The optimizer steps once per ``args.batch_num_per_gradient`` batches,
    after the gradient norm has been clipped to ``args.grad_clip``.
    Gradients of the batches inside one window are summed (the loss is not
    rescaled by the window size).

    Args:
        batch_num: Number of batches to process in this call.  Batches
            after the last complete accumulation window are still
            forwarded and back-propagated but trigger no optimizer step.
    """
    args = self.param.args
    dm = self.param.volatile.dm
    datakey = 'train'
    accumulate_every = args.batch_num_per_gradient

    for i in range(batch_num):
        self.now_batch += 1
        incoming = self.get_next_batch(dm, datakey)
        incoming.args = Storage()
        incoming.now_epoch = self.now_epoch

        # Reset gradients at the START of each accumulation window.
        # Resetting on the same iteration as the optimizer step (the old
        # `(i + 1) % k == 0` guard) discarded everything accumulated by
        # the earlier batches of the window, so only the last batch
        # contributed to the update.
        # NOTE(review): assumes self.zero_grad() clears the parameters'
        # gradients -- it is defined outside this block.
        if i % accumulate_every == 0:
            self.zero_grad()
        self.net.forward(incoming)

        loss = incoming.result.loss
        self.trainSummary(self.now_batch, storage_to_list(incoming.result))
        logging.info("batch %d : gen loss=%f", self.now_batch, loss.detach().cpu().numpy())

        # Adds this batch's gradient on top of the ones already accumulated.
        loss.backward()

        # End of the window: clip the accumulated gradient, then update.
        if (i + 1) % accumulate_every == 0:
            nn.utils.clip_grad_norm_(self.net.parameters(), args.grad_clip)
            self.optimizer.step()
def train(self, batch_num):
    """Run ``batch_num`` classification training batches.

    For every batch: fetch it with ``self.get_next_batch``, attach an empty
    ``Storage`` as ``incoming.args``, call ``self.net.forward(incoming)``,
    read the loss from ``incoming.result.loss``, compute the batch accuracy
    as the mean of ``incoming.result.label == incoming.result.prediction``,
    report both through ``self.trainSummary`` and ``logging``, and
    back-propagate.

    The optimizer steps once per ``args.batch_num_per_gradient`` batches,
    after the gradient norm has been clipped to ``args.grad_clip``.
    Gradients of the batches inside one window are summed (the loss is not
    rescaled by the window size).

    Args:
        batch_num: Number of batches to process in this call.  Batches
            after the last complete accumulation window are still
            forwarded and back-propagated but trigger no optimizer step.
    """
    args = self.param.args
    dm = self.param.volatile.dm
    datakey = 'train'
    accumulate_every = args.batch_num_per_gradient

    for i in range(batch_num):
        self.now_batch += 1
        incoming = self.get_next_batch(dm, datakey)
        incoming.args = Storage()

        # Reset gradients at the START of each accumulation window.
        # Resetting on the same iteration as the optimizer step (the old
        # `(i + 1) % k == 0` guard) discarded everything accumulated by
        # the earlier batches of the window, so only the last batch
        # contributed to the update.
        # NOTE(review): assumes self.zero_grad() clears the parameters'
        # gradients -- it is defined outside this block.
        if i % accumulate_every == 0:
            self.zero_grad()
        self.net.forward(incoming)

        loss = incoming.result.loss
        # Fraction of positions where the prediction equals the label.
        accuracy = np.mean(
            (incoming.result.label == incoming.result.prediction)
            .float().detach().cpu().numpy())
        detail_arr = storage_to_list(incoming.result)
        detail_arr.update({'accuracy_on_batch': accuracy})
        self.trainSummary(self.now_batch, detail_arr)
        logging.info("batch %d : classification loss=%f, batch accuracy=%f",
                     self.now_batch, loss.detach().cpu().numpy(), accuracy)

        # Adds this batch's gradient on top of the ones already accumulated.
        loss.backward()

        # End of the window: clip the accumulated gradient, then update.
        if (i + 1) % accumulate_every == 0:
            nn.utils.clip_grad_norm_(self.net.parameters(), args.grad_clip)
            self.optimizer.step()
def train(self, batch_num, total_step_counter):
    """Run ``batch_num`` training batches with a decaying sampling probability.

    For every batch: fetch it with ``self.get_next_batch``, attach a
    ``Storage`` as ``incoming.args`` whose ``sampling_proba`` is
    ``1 - inverse_sigmoid_decay(args.decay_factor, total_step_counter)``,
    call ``self.net.forward(incoming)``, read the loss from
    ``incoming.result.loss``, report it through ``self.trainSummary`` and
    ``logging``, and back-propagate.  ``total_step_counter`` is incremented
    after every batch.

    The optimizer steps once per ``args.batch_num_per_gradient`` batches,
    after the gradient norm has been clipped to ``args.grad_clip``.
    Gradients of the batches inside one window are summed (the loss is not
    rescaled by the window size).

    Args:
        batch_num: Number of batches to process in this call.  Batches
            after the last complete accumulation window are still
            forwarded and back-propagated but trigger no optimizer step.
        total_step_counter: Step count fed to ``inverse_sigmoid_decay``
            for the first batch of this call.

    Returns:
        ``total_step_counter`` advanced by ``batch_num``, to be passed to
        the next call.
    """
    args = self.param.args
    dm = self.param.volatile.dm
    datakey = 'train'
    accumulate_every = args.batch_num_per_gradient

    for i in range(batch_num):
        self.now_batch += 1
        incoming = self.get_next_batch(dm, datakey)
        incoming.args = Storage()
        incoming.args.sampling_proba = 1. - \
            inverse_sigmoid_decay(args.decay_factor, total_step_counter)

        # Reset gradients at the START of each accumulation window.
        # Resetting on the same iteration as the optimizer step (the old
        # `(i + 1) % k == 0` guard) discarded everything accumulated by
        # the earlier batches of the window, so only the last batch
        # contributed to the update.
        # NOTE(review): assumes self.zero_grad() clears the parameters'
        # gradients -- it is defined outside this block.
        if i % accumulate_every == 0:
            self.zero_grad()
        self.net.forward(incoming)

        loss = incoming.result.loss
        self.trainSummary(self.now_batch, storage_to_list(incoming.result))
        logging.info("batch %d : gen loss=%f", self.now_batch, loss.detach().cpu().numpy())

        # Adds this batch's gradient on top of the ones already accumulated.
        loss.backward()

        # End of the window: clip the accumulated gradient, then update.
        if (i + 1) % accumulate_every == 0:
            nn.utils.clip_grad_norm_(self.net.parameters(), args.grad_clip)
            self.optimizer.step()

        total_step_counter += 1

    return total_step_counter