Example #1
File: models.py Project: yyht/VLAE
    def importance_sample(self, x):
        with torch.enable_grad():
            q_z_x, _ = self.encoder.forward(x)
            mu_svi = q_z_x.mu
            logvar_svi = q_z_x.logvar

            for i in range(self.n_svi_step):
                q_z_x = distribution.DiagonalGaussian(mu_svi, logvar_svi)
                z = q_z_x.sample()
                p_x_z = self.decoder.forward(z)
                loss = self.loss(x, z, p_x_z, self.prior, q_z_x)
                # create_graph=True so we can backprop through this inner gradient step when updating the whole model
                mu_svi_grad, logvar_svi_grad = torch.autograd.grad(
                    loss, inputs=(mu_svi, logvar_svi), create_graph=True)

                # mu_svi_grad = utils.clip_grad(mu_svi_grad, 1)
                # logvar_svi_grad = utils.clip_grad(logvar_svi_grad, 1)
                mu_svi_grad = utils.clip_grad_norm(mu_svi_grad, 5)
                logvar_svi_grad = utils.clip_grad_norm(logvar_svi_grad, 5)
                # gradient ascent.
                mu_svi = mu_svi + self.svi_lr * mu_svi_grad
                logvar_svi = logvar_svi + self.svi_lr * logvar_svi_grad

            # obtain z_K
            q_z_x = distribution.DiagonalGaussian(mu_svi, logvar_svi)
            q_z_x = q_z_x.repeat(n_importance_sample)
            z = q_z_x.sample()
            p_x_z = self.decoder(z)
            x = x.view(-1, self.image_size).unsqueeze(1).repeat(
                1, n_importance_sample, 1).view(-1, self.image_size)

            return self.importance_weighting(x, z, p_x_z, self.prior, q_z_x)
Example #2
    def step(self, minibatch):
        self.optimizer.zero_grad()
        loss = self.loss(minibatch)
        loss.backward()
        utils.clip_grad_norm(self.optimizer, 100)
        self.optimizer.step()
        return loss
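
Note: the snippets above call a project-local utils.clip_grad_norm in two different ways: on raw gradient tensors returned by torch.autograd.grad (Example #1) and on an optimizer (Example #2). A minimal sketch of the first kind of helper, assuming it rescales the given tensor(s) so their combined L2 norm is at most max_norm and returns the result, might look like the following; the actual utils module in each project may differ.

import torch

def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Hypothetical helper: rescale a tensor (or tuple of tensors) so the
    combined L2 norm is at most max_norm, and return the result, matching
    the call pattern in Example #1."""
    single = torch.is_tensor(grads)
    grads = (grads,) if single else tuple(grads)
    total_norm = torch.sqrt(sum((g ** 2).sum() for g in grads))
    clip_coef = (max_norm / (total_norm + eps)).clamp(max=1.0)
    clipped = tuple(g * clip_coef for g in grads)
    return clipped[0] if single else clipped

Under this reading, mu_svi_grad = utils.clip_grad_norm(mu_svi_grad, 5) leaves small gradients untouched and rescales large ones down to norm 5.
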
Example #3
def train(epoch, net, trainloader, device, optimizer, loss_fn, max_grad_norm,
          writer):
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = utils.AverageMeter()
    acc_meter = utils.AverageMeter()
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, _ in trainloader:
            x = x.to(device)
            optimizer.zero_grad()
            z = net(x)
            sldj = net.module.logdet()
            loss = loss_fn(z, sldj=sldj)
            loss_meter.update(loss.item(), x.size(0))

            loss.backward()
            utils.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()

            progress_bar.set_postfix(loss=loss_meter.avg,
                                     bpd=utils.bits_per_dim(x, loss_meter.avg))
            progress_bar.update(x.size(0))
    writer.add_scalar("train/loss", loss_meter.avg, epoch)
    writer.add_scalar("train/bpd", utils.bits_per_dim(x, loss_meter.avg),
                      epoch)
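
Several of the training loops on this page pass an optimizer rather than tensors, as in utils.clip_grad_norm(optimizer, max_grad_norm) above. A plausible sketch of such a wrapper, assuming it gathers every parameter the optimizer manages and defers to PyTorch's built-in torch.nn.utils.clip_grad_norm_, is shown below; the real helpers in these repositories may be implemented differently.

import itertools
import torch

def clip_grad_norm(optimizer, max_norm):
    """Hypothetical wrapper: clip, in place, the gradients of all parameters
    registered with the optimizer so their global L2 norm is at most max_norm."""
    params = itertools.chain.from_iterable(
        group['params'] for group in optimizer.param_groups)
    return torch.nn.utils.clip_grad_norm_(params, max_norm)

With a wrapper like this, the call sits between loss.backward() and optimizer.step(), exactly as in the loop above.
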
Example #4
File: models.py Project: yyht/VLAE
    def forward(self, x):
        with torch.enable_grad():
            q_z_x, _ = self.encoder.forward(x)
            mu_svi = q_z_x.mu
            logvar_svi = q_z_x.logvar

            for i in range(self.n_svi_step):
                q_z_x = distribution.DiagonalGaussian(mu_svi, logvar_svi)
                z = q_z_x.sample()
                p_x_z = self.decoder.forward(z)
                loss = self.loss(x, z, p_x_z, self.prior, q_z_x)
                # create_graph=True so we can backprop through this inner gradient step when updating the whole model
                mu_svi_grad, logvar_svi_grad = torch.autograd.grad(
                    loss, inputs=(mu_svi, logvar_svi), create_graph=True)

                # mu_svi_grad = utils.clip_grad(mu_svi_grad, 1)
                # logvar_svi_grad = utils.clip_grad(logvar_svi_grad, 1)
                mu_svi_grad = utils.clip_grad_norm(mu_svi_grad, 5)
                logvar_svi_grad = utils.clip_grad_norm(logvar_svi_grad, 5)

                # gradient ascent.
                mu_svi = mu_svi + self.svi_lr * mu_svi_grad
                logvar_svi = logvar_svi + self.svi_lr * logvar_svi_grad

            # obtain z_K
            q_z_x = distribution.DiagonalGaussian(mu_svi, logvar_svi)
            z_K = q_z_x.sample()
            p_x_z = self.decoder.forward(z_K)

            loss = self.loss(x, z_K, p_x_z, self.prior, q_z_x)
            return loss
Example #5
def train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn,
          max_grad_norm, conditional):
    global global_step
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = util.AverageMeter()
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x in trainloader:
            optimizer.zero_grad()
            if conditional:
                x, x2 = x
                x = x.to(device)
                x2 = x2.to(device)
                z, sldj = net(x, x2, reverse=False)
            else:
                x = x.to(device)
                z, sldj = net(x, reverse=False)
            loss = loss_fn(z, sldj)
            loss_meter.update(loss.item(), x.size(0))
            loss.backward()
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            scheduler.step(global_step)

            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg),
                                     lr=optimizer.param_groups[0]['lr'])
            progress_bar.update(x.size(0))
            global_step += x.size(0)
Example #6
    def _perform_train(self, X, Act, Rew, Done, Mask, n_samples):
        # clear grad
        batch_size = Act.size(0)
        seq_len = Act.size(1)
        self.optim.zero_grad()
        # forward pass
        X = Variable(X)
        P, L, V, _ = self.policy(X, self.init_h)
        L = L[:, :seq_len, :]  # logits
        P = P[:, :seq_len, :]  # remove last one
        # compute accumulative Reward
        V_data = V.data
        cur_r = V_data[:, seq_len]
        V = V[:, :seq_len]
        V_data = V_data[:, :seq_len]
        R_list = []
        for t in range(seq_len - 1, -1, -1):
            cur_r = Rew[:, t] + self.gamma * Done[:, t] * cur_r
            R_list.append(cur_r)
        R_list.reverse()
        R = torch.stack(R_list, dim=1)
        # Advantage Normalization
        Adv = (R - V_data) * Mask  # advantage
        avg_val = Adv.sum() / n_samples
        Adv = (Adv - avg_val) * Mask  # reduce mean
        std_val = np.sqrt(torch.sum(Adv**2) / n_samples)  # standard dev
        Adv = Variable(Adv / max(std_val, 0.1))
        # critic loss
        R = Variable(R)
        Mask = Variable(Mask)
        critic_loss = torch.sum(Mask * (R - V)**2) / n_samples
        # policy gradient loss
        Act = Variable(Act)  # [batch_size, seq_len]
        Act = Act.unsqueeze(2)  # [batch_size, seq_len, 1]
        P_Act = torch.gather(P, 2, Act).squeeze(dim=2)  # [batch_size, seq_len]
        pg_loss = -torch.sum(P_Act * Adv * Mask) / n_samples
        # entropy bonus
        P_Ent = torch.sum(self.policy.entropy(L) * Mask) / n_samples
        pg_loss -= self.entropy_penalty * P_Ent
        # backprop
        loss = pg_loss + critic_loss
        loss.backward()
        L_norm = torch.sum(torch.sum(L**2, dim=-1) * Mask) / n_samples
        ret_dict = dict(pg_loss=pg_loss.data.cpu().numpy()[0],
                        policy_entropy=P_Ent.data.cpu().numpy()[0],
                        critic_loss=critic_loss.data.cpu().numpy()[0],
                        logits_norm=L_norm.data.cpu().numpy()[0])
        # gradient clip
        utils.clip_grad_norm(self.policy.parameters(), self.grad_clip)
        # apply SGD step
        self.optim.step()
        return ret_dict
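
When the helper is instead called on model parameters, as in utils.clip_grad_norm(self.policy.parameters(), self.grad_clip) above, its signature matches PyTorch's own gradient clipping utility (clip_grad_norm in the 0.x releases these snippets appear to target, renamed clip_grad_norm_ in later versions), so it may simply be a thin alias. A minimal usage sketch with the current built-in, using a toy model purely for illustration:

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(4, 10), torch.randn(4, 2)
loss = F.mse_loss(model(x), y)

optimizer.zero_grad()
loss.backward()
# Clip in place so the global grad norm is at most 1.0, then take the step.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
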
Example #7
def train(epoch, net, trainloader, device, optimizer, loss_fn, max_grad_norm,
          writer):
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = utils.AverageMeter()
    loss_unsup_meter = utils.AverageMeter()
    loss_nll_meter = utils.AverageMeter()
    acc_meter = utils.AverageMeter()
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, y in trainloader:
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            z = net(x)
            sldj = net.module.logdet()

            logits = loss_fn.prior.class_logits(z.reshape((len(z), -1)))
            loss_nll = F.cross_entropy(logits, y)

            loss_unsup = loss_fn(z, sldj=sldj)
            loss = loss_nll + loss_unsup

            loss.backward()
            utils.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()

            preds = torch.argmax(logits, dim=1)
            acc = (preds == y).float().mean().item()

            acc_meter.update(acc, x.size(0))
            loss_meter.update(loss.item(), x.size(0))
            loss_unsup_meter.update(loss_unsup.item(), x.size(0))
            loss_nll_meter.update(loss_nll.item(), x.size(0))

            progress_bar.set_postfix(loss=loss_meter.avg,
                                     bpd=utils.bits_per_dim(
                                         x, loss_unsup_meter.avg),
                                     acc=acc_meter.avg)
            progress_bar.update(x.size(0))
    x_img = torchvision.utils.make_grid(x[:10],
                                        nrow=2,
                                        padding=2,
                                        pad_value=255)
    writer.add_image("data/x", x_img)
    writer.add_scalar("train/loss", loss_meter.avg, epoch)
    writer.add_scalar("train/loss_unsup", loss_unsup_meter.avg, epoch)
    writer.add_scalar("train/loss_nll", loss_nll_meter.avg, epoch)
    writer.add_scalar("train/acc", acc_meter.avg, epoch)
    writer.add_scalar("train/bpd", utils.bits_per_dim(x, loss_unsup_meter.avg),
                      epoch)
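
Most of the flow-model loops on this page also report utils.bits_per_dim(x, loss_meter.avg). A plausible sketch of that helper, assuming the loss is an average negative log-likelihood in nats per example, converts it to bits and normalizes by the number of dimensions of a single input; the exact definition in each repository may differ.

import numpy as np

def bits_per_dim(x, nll):
    """Hypothetical helper: convert an average NLL in nats per example to
    bits per dimension, using the shape of one sample from the batch x."""
    dim = float(np.prod(x.shape[1:]))  # e.g. C * H * W for image batches
    return nll / (np.log(2) * dim)
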
Example #8
File: models.py Project: yyht/VLAE
    def write_summary(self, x, writer, epoch):
        q_z_x, _ = self.encoder.forward(x)
        mu_svi = q_z_x.mu
        logvar_svi = q_z_x.logvar

        for i in range(self.n_svi_step):
            q_z_x = distribution.DiagonalGaussian(mu_svi, logvar_svi)
            z = q_z_x.sample()
            p_x_z = self.decoder.forward(z)
            loss = self.loss(x, z, p_x_z, self.prior, q_z_x)
            # create_graph=True so we can backprop through this inner gradient step when updating the whole model
            mu_svi_grad, logvar_svi_grad = torch.autograd.grad(
                loss, inputs=(mu_svi, logvar_svi), create_graph=True)

            # mu_svi_grad = utils.clip_grad(mu_svi_grad, 1)
            # logvar_svi_grad = utils.clip_grad(logvar_svi_grad, 1)
            mu_svi_grad = utils.clip_grad_norm(mu_svi_grad, 5)
            logvar_svi_grad = utils.clip_grad_norm(logvar_svi_grad, 5)
            # gradient ascent.
            mu_svi = mu_svi + self.svi_lr * mu_svi_grad
            logvar_svi = logvar_svi + self.svi_lr * logvar_svi_grad

        # obtain z_K
        q_z_x = distribution.DiagonalGaussian(mu_svi, logvar_svi)
        z_K = q_z_x.sample()
        p_x_z = self.decoder.forward(z_K)

        writer.add_scalar(
            'kl_div',
            torch.mean(-self.prior.log_probability(z_K) +
                       q_z_x.log_probability(z_K)).item(), epoch)
        writer.add_scalar('recon_error',
                          -torch.mean(p_x_z.log_probability(x)).item(), epoch)
        writer.add_image('data',
                         vutils.make_grid(self.dataset.unpreprocess(x)), epoch)
        writer.add_image(
            'reconstruction_z',
            vutils.make_grid(self.dataset.unpreprocess(p_x_z.mu).clamp(0, 1)),
            epoch)

        sample = torch.randn(len(x), z.shape[1]).cuda()
        sample = self.decoder(sample).mu
        writer.add_image(
            'generated',
            vutils.make_grid(self.dataset.unpreprocess(sample).clamp(0, 1)),
            epoch)
Example #9
def train(epoch,
          net,
          trainloader,
          device,
          optimizer,
          loss_fn,
          max_grad_norm,
          writer,
          num_samples=10,
          sampling=True,
          tb_freq=100):
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = utils.AverageMeter()
    iter_count = 0
    batch_count = 0
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, _ in trainloader:
            iter_count += 1
            batch_count += x.size(0)
            x = x.to(device)
            optimizer.zero_grad()
            z = net(x)
            sldj = net.module.logdet()
            loss = loss_fn(z, sldj=sldj)
            loss.backward()
            utils.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()

            loss_meter.update(loss.item(), x.size(0))
            progress_bar.set_postfix(loss=loss_meter.avg,
                                     bpd=utils.bits_per_dim(x, loss_meter.avg))
            progress_bar.update(x.size(0))

            if iter_count % tb_freq == 0 or batch_count == len(
                    trainloader.dataset):
                tb_step = epoch * (len(trainloader.dataset)) + batch_count
                writer.add_scalar("train/loss", loss_meter.avg, tb_step)
                writer.add_scalar("train/bpd",
                                  utils.bits_per_dim(x, loss_meter.avg),
                                  tb_step)
                if sampling:
                    net.eval()
                    draw_samples(net, writer, loss_fn, num_samples, device,
                                 tuple(x[0].shape), tb_step)
                    net.train()
Example #10
File: train.py Project: yyht/daga
def train(model, optimizer, train_iter, epoch, args):
    """Train with mini-batches."""
    model.train()
    total_stats = Statistics()
    batch_stats = Statistics()
    num_batches = len(train_iter)
    for i, batch in enumerate(train_iter):
        if args.warmup > 0:
            args.beta = min(1, args.beta + 1.0 / (args.warmup * num_batches))

        sents = batch.sent
        loss, stats = model(sents, args.beta)
        optimizer.zero_grad()
        loss.backward()
        utils.clip_grad_norm(optimizer, args)
        optimizer.step()
        total_stats.update(stats)
        batch_stats.update(stats)
        batch_stats = report_batch(batch_stats, epoch, i, num_batches, args)
        torch.cuda.empty_cache()
    return total_stats
Example #11
    def update(self):
        if (self.sample_counter < self.args['update_freq']) or \
           not self.replay_buffer.can_sample(self.batch_size * self.args['episode_len']):
            return None
        self.sample_counter = 0
        self.train()
        tt = time.time()

        obs, full_act, rew, obs_next, done = \
            self.replay_buffer.sample(self.batch_size)
        #act = split_batched_array(full_act, self.act_shape)
        time_counter[-1] += time.time() - tt
        tt = time.time()

        # convert to variables
        obs_n = self._process_frames(obs)
        obs_next_n = self._process_frames(obs_next, volatile=True)
        full_act_n = Variable(torch.from_numpy(full_act)).type(FloatTensor)
        rew_n = Variable(torch.from_numpy(rew),
                         volatile=True).type(FloatTensor)
        done_n = Variable(torch.from_numpy(done),
                          volatile=True).type(FloatTensor)

        time_counter[0] += time.time() - tt
        tt = time.time()

        # train q network
        common.debugger.print('Grad Stats of Q Update ...', False)
        target_act_next = self.target_p(obs_next_n)
        target_q_next = self.target_q(obs_next_n, target_act_next)
        target_q = rew_n + self.gamma * (1.0 - done_n) * target_q_next
        target_q.volatile = False
        current_q = self.q(obs_n, full_act_n)
        q_norm = (current_q * current_q).mean().squeeze()  # l2 norm
        q_loss = F.smooth_l1_loss(
            current_q,
            target_q) + self.args['critic_penalty'] * q_norm  # huber

        common.debugger.print('>> Q_Loss = {}'.format(q_loss.data.mean()),
                              False)

        self.q_optim.zero_grad()
        q_loss.backward()

        common.debugger.print('Stats of Q Network (*before* clip and opt)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.q)

        if self.grad_norm_clip is not None:
            #nn.utils.clip_grad_norm(self.q.parameters(), self.grad_norm_clip)
            utils.clip_grad_norm(self.q.parameters(), self.grad_norm_clip)
        self.q_optim.step()

        # train p network
        new_act_n = self.p(obs_n)  # NOTE: maybe use <gumbel_noise=None> ?
        q_val = self.q(obs_n, new_act_n)
        p_loss = -q_val.mean().squeeze()
        p_ent = self.p.entropy().mean().squeeze()
        if self.args['ent_penalty'] is not None:
            p_loss -= self.args['ent_penalty'] * p_ent  # encourage exploration

        common.debugger.print('>> P_Loss = {}'.format(p_loss.data.mean()),
                              False)

        self.p_optim.zero_grad()
        self.q_optim.zero_grad()  # important!! clear the grad in Q
        p_loss.backward()

        if self.grad_norm_clip is not None:
            #nn.utils.clip_grad_norm(self.p.parameters(), self.grad_norm_clip)
            utils.clip_grad_norm(self.p.parameters(), self.grad_norm_clip)
        self.p_optim.step()

        common.debugger.print(
            'Stats of Q Network (in the phase of P-Update)....', False)
        utils.log_parameter_stats(common.debugger, self.q)
        common.debugger.print('Stats of P Network (after clip and opt)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.p)

        time_counter[1] += time.time() - tt
        tt = time.time()

        # update target networks
        make_update_exp(self.p, self.target_p, rate=self.target_update_rate)
        make_update_exp(self.q, self.target_q, rate=self.target_update_rate)

        common.debugger.print('Stats of Q Target Network (After Update)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.target_q)
        common.debugger.print('Stats of P Target Network (After Update)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.target_p)

        time_counter[2] += time.time() - tt

        return dict(policy_loss=p_loss.data.cpu().numpy()[0],
                    policy_entropy=p_ent.data.cpu().numpy()[0],
                    critic_norm=q_norm.data.cpu().numpy()[0],
                    critic_loss=q_loss.data.cpu().numpy()[0])
Example #12
    def update(self,
               obs,
               init_hidden,
               act,
               rew,
               done,
               target=None,
               supervision_mask=None,
               mask_input=None,
               return_kl_divergence=True):
        """
        :param obs:  list of list of [dims]...
        :param init_hidden: list of [layer, 1, units]
        :param act: [batch, seq_len]
        :param rew: [batch, seq_len]
        :param done: [batch, seq_len]
        :param target: [batch, seq_len, n_instruction] or None (when single-target)
        :param supervision_mask: timesteps marked with supervised learning loss [batch, seq_len] or None (pure RL)
        """
        tt = time.time()

        # reward clipping
        if self.rew_clip is not None:
            rew = np.clip(rew, -self.rew_clip, self.rew_clip)

        # convert data to Variables
        obs = self._create_gpu_tensor(
            obs, return_variable=True)  # [batch, t_max+1, dims...]
        init_hidden = self._create_gpu_hidden(
            init_hidden, return_variable=True)  # [layers, batch, units]
        if target is not None:
            target = self._create_target_tensor(target, return_variable=True)
        if mask_input is not None:
            mask_input = self._create_feature_tensor(mask_input,
                                                     return_variable=True)
        act = Variable(
            torch.from_numpy(act).type(LongTensor))  # [batch, t_max]
        mask = 1.0 - torch.from_numpy(done).type(FloatTensor)  # [batch, t_max]
        mask_var = Variable(mask)
        sup_mask = None if supervision_mask is None else torch.from_numpy(
            supervision_mask).type(ByteTensor)  # [batch, t_max]

        time_counter[0] += time.time() - tt

        batch_size = self.batch_size
        t_max = self.t_max
        gamma = self.gamma

        tt = time.time()

        if self.accu_grad_steps == 0:  # clear grad
            self.optim.zero_grad()

        # forward pass
        logits = []
        logprobs = []
        values = []
        t_obs_slices = torch.chunk(obs, t_max + 1, dim=1)
        obs_slices = [t.contiguous() for t in t_obs_slices]
        if target is not None:
            t_target_slices = torch.chunk(target, t_max + 1, dim=1)
            target_slices = [t.contiguous() for t in t_target_slices]
        if mask_input is not None:
            t_mask_input_slices = torch.chunk(mask_input, t_max + 1, dim=1)
            mask_input_slices = [m.contiguous() for m in t_mask_input_slices]
        cur_h = init_hidden
        for t in range(t_max):
            #cur_obs = obs[:, t:t+1, ...].contiguous()
            cur_obs = obs_slices[t]
            t_target = None if target is None else target_slices[t]
            t_mask = None if mask_input is None else mask_input_slices[t]
            cur_logp, cur_val, nxt_h = self.policy(cur_obs,
                                                   cur_h,
                                                   target=t_target,
                                                   extra_input_feature=t_mask)
            cur_h = self.policy.mark_hidden_states(nxt_h, mask_var[:, t:t + 1])
            values.append(cur_val)
            logprobs.append(cur_logp)
            logits.append(self.policy.logits)
        #cur_obs = obs[:, t_max:t_max + 1, ...].contiguous()
        cur_obs = obs_slices[-1]
        t_target = None if target is None else target_slices[-1]
        t_mask = None if mask_input is None else mask_input_slices[-1]
        nxt_val = self.policy(cur_obs,
                              cur_h,
                              only_value=True,
                              return_tensor=True,
                              target=t_target,
                              extra_input_feature=t_mask)
        V = torch.cat(values, dim=1)  # [batch, t_max]
        P = torch.cat(logprobs, dim=1)  # [batch, t_max, n_act]
        L = torch.cat(logits, dim=1)
        p_ent = torch.mean(self.policy.entropy(L))  # compute entropy
        #L_norm = torch.mean(torch.norm(L, dim=-1))
        L_norm = torch.mean(torch.sum(L * L, dim=-1))  # L^2 penalty

        # estimate accumulative rewards
        rew = torch.from_numpy(rew).type(FloatTensor)  # [batch, t_max]
        R = []
        cur_R = nxt_val.squeeze()  # [batch]
        for t in range(t_max - 1, -1, -1):
            cur_mask = mask[:, t]
            cur_R = rew[:, t] + gamma * cur_R * cur_mask
            R.append(cur_R)
        R.reverse()
        R = Variable(torch.stack(R, dim=1))  # [batch, t_max]

        # estimate advantage
        A_dat = R.data - V.data  # stop gradient here
        std_val = None
        if self.adv_norm:  # perform advantage normalization
            std_val = max(A_dat.std(), 0.1)
            A_dat = (A_dat - A_dat.mean()) / (std_val + 1e-10)
        if sup_mask is not None:  # supervision
            # change A * log P(a) to log P(supervised_a); act has been modified in zmq_util
            A_dat[sup_mask > 0] = 1.0
        A = Variable(A_dat)
        # [optional]  A = Variable(rew) - V

        # compute loss
        #critic_loss = F.smooth_l1_loss(V, R)
        critic_loss = torch.mean((R - V)**2)
        pg_loss = -torch.mean(self.policy.logprob(act, P) * A)
        if self.args['entropy_penalty'] is not None:
            pg_loss -= self.args[
                'entropy_penalty'] * p_ent  # encourage exploration
        loss = self.q_loss_coef * critic_loss + pg_loss
        if self.logit_loss_coef is not None:
            loss += self.logit_loss_coef * L_norm

        # backprop
        if self.grad_batch > 1:
            loss = loss / float(self.grad_batch)
        loss.backward()

        ret_dict = dict(pg_loss=pg_loss.data.cpu().numpy()[0],
                        policy_entropy=p_ent.data.cpu().numpy()[0],
                        critic_loss=critic_loss.data.cpu().numpy()[0],
                        logits_norm=L_norm.data.cpu().numpy()[0])
        if std_val is not None:
            ret_dict['adv_norm'] = std_val

        if self.accu_grad_steps == 0:
            self.accu_ret_dict = ret_dict
        else:
            for k in ret_dict:
                self.accu_ret_dict[k] += ret_dict[k]

        self.accu_grad_steps += 1
        if self.accu_grad_steps < self.grad_batch:  # do not update parameter now
            time_counter[1] += time.time() - tt
            return None

        # update parameters
        for k in self.accu_ret_dict:
            self.accu_ret_dict[k] /= self.grad_batch
        ret_dict = self.accu_ret_dict
        self.accu_grad_steps = 0

        # grad clip
        if self.grad_norm_clip is not None:
            utils.clip_grad_norm(self.policy.parameters(), self.grad_norm_clip)
        self.optim.step()

        if return_kl_divergence:
            cur_h = init_hidden
            new_logprobs = []
            for t in range(t_max):
                # cur_obs = obs[:, t:t+1, ...].contiguous()
                cur_obs = obs_slices[t]
                t_target = target_slices[t] if self.multi_target else None
                t_mask = None if mask_input is None else mask_input_slices[t]
                cur_logp, nxt_h = self.policy(cur_obs,
                                              cur_h,
                                              return_value=False,
                                              target=t_target,
                                              extra_input_feature=t_mask)
                cur_h = self.policy.mark_hidden_states(nxt_h,
                                                       mask_var[:, t:t + 1])
                new_logprobs.append(cur_logp)
            new_P = torch.cat(new_logprobs, dim=1)
            kl = self.policy.kl_divergence(new_P, P).mean().data.cpu()[0]
            ret_dict['KL(P_new||P_old)'] = kl

            if kl > flag_max_kl_diff:
                self.lrate /= flag_lrate_coef
                self.optim.__dict__['param_groups'][0]['lr'] = self.lrate
                ret_dict['!!![NOTE]:'] = (
                    '------>>>> KL is too large (%.6f), decrease lrate to %.5f'
                    % (kl, self.lrate))
            elif (kl < flag_min_kl_diff) and (self.lrate < flag_max_lrate):
                self.lrate *= flag_lrate_coef
                self.optim.__dict__['param_groups'][0]['lr'] = self.lrate
                ret_dict['!!![NOTE]:'] = (
                    '------>>>> KL is too small (%.6f), increase lrate to %.5f'
                    % (kl, self.lrate))

        time_counter[1] += time.time() - tt
        return ret_dict
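
Examples #12 and #17 rescale the learning rate whenever the measured KL divergence between the new and old policy leaves a target band, reaching into self.optim.__dict__['param_groups'] to do so. Since param_groups is a plain attribute on every PyTorch optimizer, the same adjustment can be written directly against it; the sketch below uses illustrative threshold values, as the flag_* constants are defined outside these snippets.

def adjust_lr_by_kl(optimizer, kl, lrate,
                    max_kl=0.02, min_kl=0.002, coef=1.5, max_lr=1e-3):
    """Illustrative sketch: shrink the learning rate when the KL is too large,
    grow it (up to max_lr) when it is too small, mirroring Examples #12/#17.
    The threshold values here are placeholders, not the projects' settings."""
    if kl > max_kl:
        lrate /= coef
    elif kl < min_kl and lrate < max_lr:
        lrate *= coef
    for group in optimizer.param_groups:
        group['lr'] = lrate
    return lrate
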
Example #13
    def update(self):
        if (self.a is not None) or \
           not self.replay_buffer.can_sample(self.batch_size * 4):
            return None
        self.sample_counter = 0
        self.train()
        tt = time.time()

        obs, full_act, rew, msk, done, total_length = \
            self.replay_buffer.sample(self.batch_size, seq_len=self.batch_len)
        total_length = float(total_length)
        #act = split_batched_array(full_act, self.act_shape)
        time_counter[-1] += time.time() - tt
        tt = time.time()

        # convert to variables
        _full_obs_n = self._process_frames(
            obs, merge_dim=False,
            return_variable=False)  # [batch, seq_len+1, ...]
        batch = _full_obs_n.size(0)
        seq_len = _full_obs_n.size(1) - 1
        full_obs_n = Variable(_full_obs_n, volatile=True)
        obs_n = Variable(
            _full_obs_n[:, :-1, ...]).contiguous()  # [batch, seq_len, ...]
        obs_next_n = Variable(_full_obs_n[:, 1:, ...],
                              volatile=True).contiguous()
        img_c, img_h, img_w = obs_n.size(-3), obs_n.size(-2), obs_n.size(-1)
        packed_obs_n = obs_n.view(-1, img_c, img_h, img_w)
        packed_obs_next_n = obs_next_n.view(-1, img_c, img_h, img_w)
        full_act_n = Variable(torch.from_numpy(full_act)).type(
            FloatTensor)  # [batch, seq_len, ...]
        act_padding = Variable(
            torch.zeros(self.batch_size, 1,
                        full_act_n.size(-1))).type(FloatTensor)
        pad_act_n = torch.cat([act_padding, full_act_n],
                              dim=1)  # [batch, seq_len+1, ...]
        rew_n = Variable(torch.from_numpy(rew),
                         volatile=True).type(FloatTensor)
        msk_n = Variable(torch.from_numpy(msk)).type(
            FloatTensor)  # [batch, seq_len]
        done_n = Variable(torch.from_numpy(done)).type(
            FloatTensor)  # [batch, seq_len]

        time_counter[0] += time.time() - tt
        tt = time.time()

        # train q network
        common.debugger.print('Grad Stats of Q Update ...', False)

        full_target_act, _ = self.target_p(
            full_obs_n, act=pad_act_n)  # list([batch, seq_len+1, act_dim])
        target_act_next = torch.cat(full_target_act, dim=-1)[:, 1:, :]
        act_dim = target_act_next.size(-1)
        target_act_next = target_act_next.resize(batch * seq_len, act_dim)

        target_q_next = self.target_q(packed_obs_next_n,
                                      act=target_act_next)  #[batch * seq_len]
        target_q_next = target_q_next.view(batch, seq_len)  # reshape to [batch, seq_len]
        target_q = (rew_n + self.gamma * done_n * target_q_next) * msk_n
        target_q = target_q.view(-1)
        target_q.volatile = False

        current_q = self.q(packed_obs_n, act=full_act_n.view(
            -1, act_dim)) * msk_n.view(-1)
        q_norm = (current_q * current_q).sum() / total_length  # l2 norm
        q_loss = F.smooth_l1_loss(current_q, target_q, size_average=False) / total_length \
                 + self.args['critic_penalty']*q_norm  # huber

        common.debugger.print('>> Q_Loss = {}'.format(q_loss.data.mean()),
                              False)

        self.q_optim.zero_grad()
        q_loss.backward()

        common.debugger.print('Stats of Q Network (*before* clip and opt)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.q)

        if self.grad_norm_clip is not None:
            #nn.utils.clip_grad_norm(self.q.parameters(), self.grad_norm_clip)
            utils.clip_grad_norm(self.q.parameters(), self.grad_norm_clip)
        self.q_optim.step()

        # train p network
        new_act_n, _ = self.p(
            obs_n, act=pad_act_n[:, :-1, :])  # [batch, seq_len, act_dim]
        new_act_n = torch.cat(new_act_n, dim=-1)
        new_act_n = new_act_n.view(-1, act_dim)
        q_val = self.q(packed_obs_n, new_act_n) * msk_n.view(-1)
        p_loss = -q_val.sum() / total_length
        p_ent = self.p.entropy(weight=msk_n).sum() / total_length
        if self.args['ent_penalty'] is not None:
            p_loss -= self.args['ent_penalty'] * p_ent  # encourage exploration

        common.debugger.print('>> P_Loss = {}'.format(p_loss.data.mean()),
                              False)

        self.p_optim.zero_grad()
        self.q_optim.zero_grad()  # important!! clear the grad in Q
        p_loss.backward()

        if self.grad_norm_clip is not None:
            #nn.utils.clip_grad_norm(self.p.parameters(), self.grad_norm_clip)
            utils.clip_grad_norm(self.p.parameters(), self.grad_norm_clip)
        self.p_optim.step()

        common.debugger.print(
            'Stats of Q Network (in the phase of P-Update)....', False)
        utils.log_parameter_stats(common.debugger, self.q)
        common.debugger.print('Stats of P Network (after clip and opt)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.p)

        time_counter[1] += time.time() - tt
        tt = time.time()

        # update target networks
        make_update_exp(self.p, self.target_p, rate=self.target_update_rate)
        make_update_exp(self.q, self.target_q, rate=self.target_update_rate)

        common.debugger.print('Stats of Q Target Network (After Update)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.target_q)
        common.debugger.print('Stats of P Target Network (After Update)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.target_p)

        time_counter[2] += time.time() - tt

        return dict(policy_loss=p_loss.data.cpu().numpy()[0],
                    policy_entropy=p_ent.data.cpu().numpy()[0],
                    critic_norm=q_norm.data.cpu().numpy()[0],
                    critic_loss=q_loss.data.cpu().numpy()[0])
Example #14
def train(epoch, net, trainloader, ood_loader, device, optimizer, loss_fn,
          max_grad_norm, writer, negative_val=-1e5, num_samples=10, tb_freq=100):
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = utils.AverageMeter()
    loss_positive_meter = utils.AverageMeter()
    loss_negative_meter = utils.AverageMeter()
    iter_count = 0
    batch_count = 0
    pooler = MedianPool2d(7, padding=3)
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for (x, _), (x_transposed, _) in zip(trainloader, ood_loader):

            bs = x.shape[0]
            x = torch.cat((x, x_transposed), dim=0)
            iter_count += 1
            batch_count += bs
            x = x.to(device)
            optimizer.zero_grad()
            z = net(x)
            sldj = net.module.logdet()
            loss = loss_fn(z, sldj=sldj, mean=False)
            loss[bs:] *= (-1)
            loss_positive = loss[:bs]
            loss_negative = loss[bs:]
            if (loss_negative > negative_val).sum() > 0:
                loss_negative = loss_negative[loss_negative > negative_val]
                loss_negative = loss_negative.mean()
                loss_positive = loss_positive.mean()
                loss = 0.5*(loss_positive + loss_negative)
            else:
                loss_negative = torch.tensor(0.)
                loss_positive = loss_positive.mean()
                loss = loss_positive
            loss.backward()
            utils.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()

            loss_meter.update(loss.item(), bs)
            loss_positive_meter.update(loss_positive.item(), bs)
            loss_negative_meter.update(loss_negative.item(), bs)
            progress_bar.set_postfix(
                pos_bpd=utils.bits_per_dim(x[:bs], loss_positive_meter.avg),
                neg_bpd=utils.bits_per_dim(x[bs:], -loss_negative_meter.avg),
                neg_loss=loss_negative.mean().item())
            progress_bar.update(bs)

            if iter_count % tb_freq == 0 or batch_count == len(trainloader.dataset):
                tb_step = epoch*(len(trainloader.dataset))+batch_count
                writer.add_scalar("train/loss", loss_meter.avg, tb_step)
                writer.add_scalar("train/loss_positive", loss_positive_meter.avg, tb_step)
                writer.add_scalar("train/loss_negative", loss_negative_meter.avg, tb_step)
                writer.add_scalar("train/bpd_positive", utils.bits_per_dim(x[:bs], loss_positive_meter.avg), tb_step)
                writer.add_scalar("train/bpd_negative", utils.bits_per_dim(x[bs:], -loss_negative_meter.avg), tb_step)
                x1_img = torchvision.utils.make_grid(x[:10], nrow=2, padding=2, pad_value=255)
                x2_img = torchvision.utils.make_grid(x[-10:], nrow=2, padding=2, pad_value=255)
                writer.add_image("data/x", x1_img)
                writer.add_image("data/x_transposed", x2_img)
                net.eval()
                draw_samples(net, writer, loss_fn, num_samples, device, tuple(x[0].shape), tb_step)
                net.train()
Example #15
    def update(self, cpu_batch, gpu_batch):

        #print('[elf_ddpg] update!!!!')
        self.update_counter += 1
        self.train()
        tt = time.time()

        obs_n, obs_next_n, full_act_n, rew_n, done_n = self._process_elf_frames(
            gpu_batch, keep_time=False)  # collapse all the samples
        obs_n = (obs_n.type(FloatTensor) - 128.0) / 256.0
        obs_n = Variable(obs_n)
        obs_next_n = (obs_next_n.type(FloatTensor) - 128.0) / 256.0
        obs_next_n = Variable(obs_next_n, volatile=True)
        full_act_n = Variable(full_act_n)
        rew_n = Variable(rew_n, volatile=True)
        done_n = Variable(done_n, volatile=True)

        self.sample_counter += obs_n.size(0)

        time_counter[0] += time.time() - tt

        #print('[elf_ddpg] data loaded!!!!!')

        tt = time.time()

        self.optim.zero_grad()

        # train p network
        q_val = self.net(obs_n, action=None, output_critic=True)
        p_loss = -q_val.mean().squeeze()
        p_ent = self.net.entropy().mean().squeeze()
        if self.args['ent_penalty'] is not None:
            p_loss -= self.args['ent_penalty'] * p_ent  # encourage exploration
        common.debugger.print('>> P_Loss = {}'.format(p_loss.data.mean()),
                              False)
        p_loss.backward()
        self.net.clear_critic_specific_grad()  # we do not need to compute q_grad for actor!!!

        # train q network
        common.debugger.print('Grad Stats of Q Update ...', False)
        target_q_next = self.target_net(obs_next_n, output_critic=True)
        target_q = rew_n + self.gamma * (1.0 - done_n) * target_q_next
        target_q.volatile = False
        current_q = self.net(obs_n, action=full_act_n, output_critic=True)
        q_norm = (current_q * current_q).mean().squeeze()  # l2 norm
        q_loss = F.smooth_l1_loss(
            current_q,
            target_q) + self.args['critic_penalty'] * q_norm  # huber
        common.debugger.print('>> Q_Loss = {}'.format(q_loss.data.mean()),
                              False)
        q_loss = q_loss * self.q_loss_coef
        q_loss.backward()

        # total_loss = q_loss + p_loss
        # grad clip
        if self.grad_norm_clip is not None:
            utils.clip_grad_norm(self.net.parameters(), self.grad_norm_clip)
        self.optim.step()

        common.debugger.print('Stats of P Network (after clip and opt)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.net)

        time_counter[1] += time.time() - tt
        tt = time.time()

        # update target networks
        make_update_exp(self.net,
                        self.target_net,
                        rate=self.target_update_rate)

        common.debugger.print('Stats of Target Network (After Update)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.target_net)

        time_counter[2] += time.time() - tt

        stats = dict(policy_loss=p_loss.data.cpu().numpy()[0],
                     policy_entropy=p_ent.data.cpu().numpy()[0],
                     critic_norm=q_norm.data.cpu().numpy()[0],
                     critic_loss=q_loss.data.cpu().numpy()[0] /
                     self.q_loss_coef,
                     eplen=cpu_batch[-1]['stats_eplen'].mean(),
                     avg_rew=cpu_batch[-1]['stats_rew'].mean())
        self.print_log(stats)
Example #16
    def update(self):
        if (self.sample_counter < self.args['update_freq']) or \
           not self.replay_buffer.can_sample(self.batch_size * min(self.args['update_freq'], 20)):
            return None
        self._update_counter += 1
        self.sample_counter = 0
        self.train()
        tt = time.time()

        obs, act, rew, obs_next, done = \
            self.replay_buffer.sample(self.batch_size)
        if self.multi_target:
            target_idx = self.target_buffer[self.replay_buffer._idxes]
            targets = np.zeros((self.batch_size, common.n_target_instructions), dtype=np.uint8)
            targets[list(range(self.batch_size)), target_idx] = 1
        #act = split_batched_array(full_act, self.act_shape)
        time_counter[-1] += time.time() - tt
        tt = time.time()

        # convert to variables
        obs_n = self._process_frames(obs)
        obs_next_n = self._process_frames(obs_next, volatile=True)
        act_n = Variable(torch.from_numpy(act)).type(LongTensor)
        rew_n = Variable(torch.from_numpy(rew), volatile=True).type(FloatTensor)
        done_n = Variable(torch.from_numpy(done), volatile=True).type(FloatTensor)
        if self.multi_target:
            target_n = Variable(torch.from_numpy(targets).type(FloatTensor))
        else:
            target_n = None

        time_counter[0] += time.time() - tt
        tt = time.time()

        # compute critic loss
        target_q_val_next = self.target_net(obs_next_n, only_q_value=True, target=target_n)
        # double Q learning
        target_act_next = torch.max(self.net(obs_next_n, only_q_value=True, target=target_n), dim=1, keepdim=True)[1]
        target_q_next = torch.gather(target_q_val_next, 1, target_act_next).squeeze()
        target_q = rew_n + self.gamma * (1.0 - done_n) * target_q_next
        target_q.volatile = False
        current_q_val = self.net(obs_n, only_q_value=True, target=target_n)
        current_q = torch.gather(current_q_val, 1, act_n.view(-1, 1)).squeeze()
        q_norm = (current_q * current_q).mean().squeeze()
        q_loss = F.smooth_l1_loss(current_q, target_q)

        common.debugger.print('>> Q_Loss = {}'.format(q_loss.data.mean()), False)
        common.debugger.print('>> Q_Norm = {}'.format(q_norm.data.mean()), False)

        total_loss = q_loss.mean()
        if self.args['critic_penalty'] > 1e-10:
            total_loss += self.args['critic_penalty']*q_norm

        # compute gradient
        self.optim.zero_grad()
        #autograd.backward([total_loss, current_act], [torch.ones(1), None])
        total_loss.backward()
        if self.grad_norm_clip is not None:
            #nn.utils.clip_grad_norm(self.q.parameters(), self.grad_norm_clip)
            utils.clip_grad_norm(self.net.parameters(), self.grad_norm_clip)
        self.optim.step()
        common.debugger.print('Stats of Model (*after* clip and opt)....', False)
        utils.log_parameter_stats(common.debugger, self.net)

        time_counter[1] += time.time() - tt
        tt = time.time()

        # update target networks
        if self.target_net_update_freq is not None:
            if self._update_counter == self.target_net_update_freq:
                self._update_counter = 0
                self.target_net.load_state_dict(self.net.state_dict())
        else:
            make_update_exp(self.net, self.target_net, rate=self.target_update_rate)
        common.debugger.print('Stats of Target Network (After Update)....', False)
        utils.log_parameter_stats(common.debugger, self.target_net)

        time_counter[2] += time.time() - tt

        return dict(critic_norm=q_norm.data.cpu().numpy()[0],
                    critic_loss=q_loss.data.cpu().numpy()[0])
Example #17
    def update(self,
               obs,
               init_hidden,
               act,
               rew,
               done,
               target=None,
               aux_target=None,
               return_kl_divergence=True):
        """
        :param obs:  list of list of [dims]...
        :param init_hidden: list of [layer, 1, units]
        :param act: [batch, seq_len]
        :param rew: [batch, seq_len]
        :param done: [batch, seq_len]
        :param target: [batch, seq_len, n_instruction] or None (when single-target)
        :param aux_target: 0/1 label matrix [batch, seq_len, n_aux_pred] or None (not updating the aux-loss)
        """
        assert (aux_target
                is not None), 'AuxTrainer must be given <aux_target>'
        tt = time.time()

        # reward clipping
        rew = np.clip(rew, -1, 1)

        # convert data to Variables
        obs = self._create_gpu_tensor(
            obs, return_variable=True)  # [batch, t_max+1, dims...]
        init_hidden = self._create_gpu_hidden(
            init_hidden, return_variable=True)  # [layers, batch, units]
        if target is not None:
            target = self._create_target_tensor(target, return_variable=True)
        aux_target = self._create_aux_target_tensor(aux_target)
        act = Variable(
            torch.from_numpy(act).type(LongTensor))  # [batch, t_max]
        mask = 1.0 - torch.from_numpy(done).type(FloatTensor)  # [batch, t_max]
        mask_var = Variable(mask)

        time_counter[0] += time.time() - tt

        batch_size = self.batch_size
        t_max = self.t_max
        gamma = self.gamma

        tt = time.time()

        self.optim.zero_grad()

        # forward pass
        logits = []
        logprobs = []
        values = []
        aux_preds = []
        obs = obs
        t_obs_slices = torch.chunk(obs, t_max + 1, dim=1)
        obs_slices = [t.contiguous() for t in t_obs_slices]
        if target is not None:
            t_target_slices = torch.chunk(target, t_max + 1, dim=1)
            target_slices = [t.contiguous() for t in t_target_slices]
        cur_h = init_hidden
        for t in range(t_max):
            #cur_obs = obs[:, t:t+1, ...].contiguous()
            cur_obs = obs_slices[t]
            if target is not None:
                ret_vals = self.policy(
                    cur_obs,
                    cur_h,
                    target=target_slices[t],
                    compute_aux_pred=True,
                    return_aux_logprob=self.use_supervised_loss)
            else:
                ret_vals = self.policy(
                    cur_obs,
                    cur_h,
                    compute_aux_pred=True,
                    return_aux_logprob=self.use_supervised_loss)
            cur_logp, cur_val, nxt_h, aux_p = ret_vals
            cur_h = self.policy.mark_hidden_states(nxt_h, mask_var[:, t:t + 1])
            values.append(cur_val)
            logprobs.append(cur_logp)
            logits.append(self.policy.logits)
            aux_preds.append(aux_p)
        #cur_obs = obs[:, t_max:t_max + 1, ...].contiguous()
        cur_obs = obs_slices[-1]
        if target is not None:
            nxt_val = self.policy(cur_obs,
                                  cur_h,
                                  only_value=True,
                                  return_tensor=True,
                                  target=target_slices[-1])
        else:
            nxt_val = self.policy(cur_obs,
                                  cur_h,
                                  only_value=True,
                                  return_tensor=True)
        V = torch.cat(values, dim=1)  # [batch, t_max]
        P = torch.cat(logprobs, dim=1)  # [batch, t_max, n_act]
        L = torch.cat(logits, dim=1)
        p_ent = torch.mean(self.policy.entropy(L))  # compute entropy
        Aux_P = torch.cat(aux_preds, dim=1)  # [batch, t_max, n_aux_pred]

        # estimate accumulative rewards
        rew = torch.from_numpy(rew).type(FloatTensor)  # [batch, t_max]
        R = []
        cur_R = nxt_val.squeeze()  # [batch]
        for t in range(t_max - 1, -1, -1):
            cur_mask = mask[:, t]
            cur_R = rew[:, t] + gamma * cur_R * cur_mask
            R.append(cur_R)
        R.reverse()
        R = Variable(torch.stack(R, dim=1))  # [batch, t_max]

        # estimate advantage
        A = Variable(R.data - V.data)  # stop gradient here
        # [optional]  A = Variable(rew) - V

        # compute loss
        #critic_loss = F.smooth_l1_loss(V, R)
        critic_loss = torch.mean((R - V)**2)
        pg_loss = -torch.mean(self.policy.logprob(act, P) * A)
        if self.args['entropy_penalty'] is not None:
            pg_loss -= self.args[
                'entropy_penalty'] * p_ent  # encourage exploration

        # aux task loss
        aux_loss = -(Aux_P * aux_target).sum(dim=-1).mean()

        loss = self.q_loss_coef * critic_loss + pg_loss + self.aux_loss_coef * aux_loss

        # backprop
        loss.backward()

        # grad clip
        if self.grad_norm_clip is not None:
            utils.clip_grad_norm(self.policy.parameters(), self.grad_norm_clip)
        self.optim.step()

        ret_dict = dict(pg_loss=pg_loss.data.cpu().numpy()[0],
                        aux_task_loss=aux_loss.data.cpu().numpy()[0],
                        policy_entropy=p_ent.data.cpu().numpy()[0],
                        critic_loss=critic_loss.data.cpu().numpy()[0])

        if return_kl_divergence:
            cur_h = init_hidden
            new_logprobs = []
            for t in range(t_max):
                # cur_obs = obs[:, t:t+1, ...].contiguous()
                cur_obs = obs_slices[t]
                if self.multi_target:
                    cur_target = target_slices[t]
                else:
                    cur_target = None
                cur_logp, nxt_h = self.policy(cur_obs,
                                              cur_h,
                                              return_value=False,
                                              target=cur_target)
                cur_h = self.policy.mark_hidden_states(nxt_h,
                                                       mask_var[:, t:t + 1])
                new_logprobs.append(cur_logp)
            new_P = torch.cat(new_logprobs, dim=1)
            kl = self.policy.kl_divergence(new_P, P).mean().data.cpu()[0]
            ret_dict['KL(P_new||P_old)'] = kl

            if kl > flag_max_kl_diff:
                self.lrate /= flag_lrate_coef
                self.optim.__dict__['param_groups'][0]['lr'] = self.lrate
                ret_dict['!!![NOTE]:'] = (
                    '------>>>> KL is too large (%.6f), decrease lrate to %.5f'
                    % (kl, self.lrate))
            elif (kl < flag_min_kl_diff) and (self.lrate < flag_max_lrate):
                self.lrate *= flag_lrate_coef
                self.optim.__dict__['param_groups'][0]['lr'] = self.lrate
                ret_dict['!!![NOTE]:'] = (
                    '------>>>> KL is too small (%.6f), increase lrate to %.5f'
                    % (kl, self.lrate))

        time_counter[1] += time.time() - tt
        return ret_dict
Example #18
def train(epoch,
          net,
          trainloader,
          device,
          optimizer,
          loss_fn,
          max_grad_norm,
          writer,
          num_samples=10,
          sampling=True,
          tb_freq=100):
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = utils.AverageMeter()
    loss_unsup_meter = utils.AverageMeter()
    loss_reconstr_meter = utils.AverageMeter()
    kl_loss_meter = utils.AverageMeter()
    acc_meter = utils.AverageMeter()
    iter_count = 0
    batch_count = 0
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, _ in trainloader:
            iter_count += 1
            batch_count += x.size(0)
            x = x.to(device)
            optimizer.zero_grad()
            z = net(x)
            sldj = net.module.logdet()
            loss_unsup = loss_fn(z, sldj=sldj)

            # if vae_loss:
            #     logvar_z = -logvar_net(z)
            #     z_perturbed = z + torch.randn_like(z) * torch.exp(0.5 * logvar_z)
            #     x_reconstr = net.module.inverse(z_perturbed)
            #     if decoder_likelihood == 'binary_ce':
            #         loss_reconstr = F.binary_cross_entropy(x_reconstr, x, reduction='sum') / x.size(0)
            #     else:
            #         loss_reconstr = F.mse_loss(x_reconstr, x, reduction='sum') / x.size(0)
            #     kl_loss = -0.5 * (logvar_z - logvar_z.exp()).sum(dim=[1])
            #     kl_loss = kl_loss.mean()
            #     loss = loss_unsup + loss_reconstr * reconstr_weight + kl_loss * reconstr_weight
            # else:
            logvar_z = torch.tensor([0.])
            loss_reconstr = torch.tensor([0.])
            kl_loss = torch.tensor([0.])
            loss = loss_unsup

            loss.backward()
            utils.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()

            loss_unsup_meter.update(loss_unsup.item(), x.size(0))
            loss_reconstr_meter.update(loss_reconstr.item(), x.size(0))
            kl_loss_meter.update(kl_loss.item(), x.size(0))
            loss_meter.update(loss.item(), x.size(0))

            progress_bar.set_postfix(loss=loss_meter.avg,
                                     bpd=utils.bits_per_dim(x, loss_meter.avg))
            progress_bar.update(x.size(0))

            if iter_count % tb_freq == 0 or batch_count == len(
                    trainloader.dataset):
                tb_step = epoch * (len(trainloader.dataset)) + batch_count
                writer.add_scalar("train/loss", loss_meter.avg, tb_step)
                writer.add_scalar("train/loss_unsup", loss_unsup_meter.avg,
                                  tb_step)
                writer.add_scalar("train/loss_reconstr",
                                  loss_reconstr_meter.avg, tb_step)
                writer.add_scalar("train/kl_loss", kl_loss_meter.avg, tb_step)
                writer.add_scalar("train/bpd",
                                  utils.bits_per_dim(x, loss_unsup_meter.avg),
                                  tb_step)
                writer.add_histogram('train/logvar_z', logvar_z, tb_step)
                if sampling:
                    net.eval()
                    draw_samples(net, writer, loss_fn, num_samples, device,
                                 tuple(x[0].shape), tb_step)
                    net.train()
Example #19
def train(
    epoch,
    net,
    trainloader,
    device,
    optimizer,
    loss_fn,
    label_weight,
    max_grad_norm,
    writer,
    use_unlab=True,
):
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = utils.AverageMeter()
    loss_unsup_meter = utils.AverageMeter()
    loss_nll_meter = utils.AverageMeter()
    jaclogdet_meter = utils.AverageMeter()
    acc_meter = utils.AverageMeter()
    with tqdm(total=trainloader.batch_sampler.num_labeled) as progress_bar:
        for x1, y in trainloader:

            x1 = x1.to(device)
            y = y.to(device)

            labeled_mask = (y != NO_LABEL)

            optimizer.zero_grad()

            z1 = net(x1)
            sldj = net.module.logdet()

            z_labeled = z1.reshape((len(z1), -1))
            z_labeled = z_labeled[labeled_mask]
            y_labeled = y[labeled_mask]

            logits_labeled = loss_fn.prior.class_logits(z_labeled)
            loss_nll = F.cross_entropy(logits_labeled, y_labeled)

            if use_unlab:
                loss_unsup = loss_fn(z1, sldj=sldj)
                loss = loss_nll * label_weight + loss_unsup
            else:
                loss_unsup = torch.tensor([0.])
                loss = loss_nll

            loss.backward()
            utils.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()

            preds = torch.argmax(logits_labeled, dim=1)
            acc = (preds == y_labeled).float().mean().item()

            acc_meter.update(acc, x1.size(0))
            loss_meter.update(loss.item(), x1.size(0))
            loss_unsup_meter.update(loss_unsup.item(), x1.size(0))
            loss_nll_meter.update(loss_nll.item(), x1.size(0))
            jaclogdet_meter.update(sldj.mean().item(), x1.size(0))

            progress_bar.set_postfix(loss=loss_meter.avg,
                                     bpd=utils.bits_per_dim(
                                         x1, loss_unsup_meter.avg),
                                     acc=acc_meter.avg)
            progress_bar.update(y_labeled.size(0))

    writer.add_scalar("train/loss", loss_meter.avg, epoch)
    writer.add_scalar("train/loss_unsup", loss_unsup_meter.avg, epoch)
    writer.add_scalar("train/loss_nll", loss_nll_meter.avg, epoch)
    writer.add_scalar("train/jaclogdet", jaclogdet_meter.avg, epoch)
    writer.add_scalar("train/acc", acc_meter.avg, epoch)
    writer.add_scalar("train/bpd",
                      utils.bits_per_dim(x1, loss_unsup_meter.avg), epoch)
Example #20
    def update(self,
               obs,
               act,
               length_mask,
               target=None,
               mask_input=None,
               hidden=None):
        """
        all input params are numpy arrays
        :param obs: [batch, seq_len, n, m, channel]
        :param act: [batch, seq_len]
        :param length_mask: [batch, seq_len]
        :param target: [batch] or None (omitted in the single-target case)
        :param mask_input: (optional) [batch, seq_len, feat_dim]
        :param hidden: (optional) initial recurrent state; a zero state is created when None
        """
        tt = time.time()
        # convert data to Variables
        batch_size = obs.shape[0]
        seq_len = obs.shape[1]
        total_samples = float(np.sum(length_mask))
        obs = self._create_gpu_tensor(
            obs, return_variable=True)  # [batch, t_max, dims...]
        if hidden is None:
            hidden = self.policy.get_zero_state(
                batch=batch_size,
                return_variable=True,
                hidden_batch_first=self._is_multigpu)
        if target is not None:
            target = self._create_target_tensor(target,
                                                seq_len,
                                                return_variable=True)
        if mask_input is not None:
            mask_input = self._create_feature_tensor(mask_input,
                                                     return_variable=True)
        length_mask = self._create_feature_tensor(
            length_mask, return_variable=True)  # [batch, t_max]

        # create action tensor
        #act = Variable(torch.from_numpy(act).type(LongTensor))  # [batch, t_max]
        act_n = torch.zeros(batch_size, seq_len,
                            self.policy.out_dim).type(FloatTensor)
        ids = torch.from_numpy(np.array(act)).type(LongTensor).view(
            batch_size, seq_len, 1)
        act_n.scatter_(2, ids, 1.0)
        act_n = Variable(act_n)

        time_counter[0] += time.time() - tt

        tt = time.time()

        if self.accu_grad_steps == 0:  # clear grad
            self.optim.zero_grad()

        # forward pass
        # logits: [batch, seq_len, n_act]
        logits, _ = self.net(obs,
                             hidden,
                             return_value=False,
                             sample_action=False,
                             return_tensor=False,
                             target=target,
                             extra_input_feature=mask_input,
                             return_logits=True,
                             hidden_batch_first=self._is_multigpu)

        # compute loss
        #critic_loss = F.smooth_l1_loss(V, R)
        block_size = batch_size * seq_len
        act_size = logits.size(-1)
        flat_logits = logits.view(block_size, act_size)
        logp = torch.sum(F.log_softmax(flat_logits, dim=-1).view(
            batch_size, seq_len, act_size) * act_n,
                         dim=-1) * length_mask
        loss = -torch.sum(logp) / total_samples

        # entropy penalty
        L_ent = torch.sum(
            self.policy.entropy(logits=logits) * length_mask) / total_samples
        if self.args['entropy_penalty'] is not None:
            loss -= self.args['entropy_penalty'] * L_ent

        # L^2 penalty
        L_norm = torch.sum(
            torch.sum(logits * logits, dim=-1) * length_mask) / total_samples
        if self.args['logits_penalty'] is not None:
            loss += self.args['logits_penalty'] * L_norm

        # compute accuracy
        _, max_idx = torch.max(logits.data, dim=-1, keepdim=True)
        L_accu = torch.sum(
            (max_idx == ids).type(FloatTensor) *
            length_mask.data.view(batch_size, seq_len, 1)) / total_samples

        ret_dict = dict(loss=loss.data.cpu().numpy()[0],
                        entropy=L_ent.data.cpu().numpy()[0],
                        logits_norm=L_norm.data.cpu().numpy()[0],
                        accuracy=L_accu)

        # backprop
        if self.grad_batch > 1:
            loss = loss / float(self.grad_batch)
        loss.backward()

        # accumulative stats
        if self.accu_grad_steps == 0:
            self.accu_ret_dict = ret_dict
        else:
            for k in ret_dict:
                self.accu_ret_dict[k] += ret_dict[k]

        self.accu_grad_steps += 1
        if self.accu_grad_steps < self.grad_batch:  # do not update parameter now
            time_counter[1] += time.time() - tt
            return None

        # update stats
        for k in self.accu_ret_dict:
            self.accu_ret_dict[k] /= self.grad_batch
        ret_dict = self.accu_ret_dict
        self.accu_grad_steps = 0

        # grad clip
        if self.grad_norm_clip is not None:
            utils.clip_grad_norm(self.policy.parameters(), self.grad_norm_clip)
        self.optim.step()

        time_counter[1] += time.time() - tt
        return ret_dict
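As a side note on the action encoding above: the `scatter_` call builds a one-hot tensor over the action dimension. A self-contained toy check of that idiom (the values here are made up for illustration):

import torch

# For each (batch, t) position, set the entry at the action index to 1.
batch_size, seq_len, n_act = 2, 3, 5
act = torch.tensor([[0, 2, 4], [1, 1, 3]])       # [batch, seq_len]
act_n = torch.zeros(batch_size, seq_len, n_act)
act_n.scatter_(2, act.unsqueeze(-1), 1.0)        # one-hot along the last dim
assert act_n.sum().item() == batch_size * seq_len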
Exemplo n.º 21
0
    def update(self, obs, label):
        """
        all input params are numpy arrays
        :param obs: [batch, n, m, channel] or [batch, stack_frame, n, m, channel]
        :param label: [batch, n_class] (sigmoid) or [batch] (softmax)
        """
        tt = time.time()
        # convert data to Variables
        batch_size = obs.shape[0]
        obs = self._create_gpu_tensor(obs, return_variable=True)  # [batch, channel, n, m]

        # create label tensor
        if self.multi_label:
            t_label = torch.from_numpy(np.array(label)).type(FloatTensor)
        else:
            t_label = torch.from_numpy(np.array(label)).type(LongTensor)
        label = Variable(t_label)

        time_counter[0] += time.time() - tt

        tt = time.time()

        if self.accu_grad_steps == 0:  # clear grad
            self.optim.zero_grad()

        # forward pass
        # logits: [batch, n_class]
        logits = self.policy(obs, return_logits=True)

        # compute loss
        if self.multi_label:
            loss = torch.mean(F.binary_cross_entropy_with_logits(logits, label))
        else:
            loss = torch.mean(F.cross_entropy(logits, label))

        # entropy penalty
        L_ent = torch.mean(self.policy.entropy(logits=logits))
        if self.args['entropy_penalty'] is not None:
            loss -= self.args['entropy_penalty'] * L_ent

        # L^2 penalty
        L_norm = torch.mean(torch.sum(logits * logits, dim=-1))
        if self.args['logits_penalty'] is not None:
            loss += self.args['logits_penalty'] * L_norm

        # compute accuracy
        if self.multi_label:
            # a logit above 0 corresponds to a sigmoid probability above 0.5
            max_idx = (logits.data > 0).type(FloatTensor)
            total_sample = batch_size * self.out_dim
        else:
            _, max_idx = torch.max(logits.data, dim=-1, keepdim=False)
            total_sample = batch_size
        L_accu = torch.sum((max_idx == t_label).type(FloatTensor)) / total_sample

        ret_dict = dict(loss=loss.data.cpu().numpy()[0],
                        entropy=L_ent.data.cpu().numpy()[0],
                        logits_norm=L_norm.data.cpu().numpy()[0],
                        accuracy=L_accu)

        # backprop
        if self.grad_batch > 1:
            loss = loss / float(self.grad_batch)
        loss.backward()

        # accumulative stats
        if self.accu_grad_steps == 0:
            self.accu_ret_dict = ret_dict
        else:
            for k in ret_dict:
                self.accu_ret_dict[k] += ret_dict[k]

        self.accu_grad_steps += 1
        if self.accu_grad_steps < self.grad_batch:  # do not update parameter now
            time_counter[1] += time.time() - tt
            return None

        # update stats
        for k in self.accu_ret_dict:
            self.accu_ret_dict[k] /= self.grad_batch
        ret_dict = self.accu_ret_dict
        self.accu_grad_steps = 0

        # grad clip
        if self.grad_norm_clip is not None:
            utils.clip_grad_norm(self.policy.parameters(), self.grad_norm_clip)
        self.optim.step()

        time_counter[1] += time.time() - tt
        return ret_dict
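One detail worth checking in the multi-label branch is the decision threshold: comparing the raw logits against 0 is equivalent to comparing the sigmoid probabilities against 0.5. A quick stand-alone check of that convention:

import torch

logits = torch.tensor([-1.2, 0.0, 0.7, 3.5])
assert torch.equal(logits > 0, torch.sigmoid(logits) > 0.5)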
Exemplo n.º 22
0
    def update(self):
        if (self.sample_counter < self.args['update_freq']) or \
           not self.replay_buffer.can_sample(self.batch_size * self.args['episode_len']):
            return None
        self.sample_counter = 0
        self.train()
        tt = time.time()

        obs, full_act, rew, obs_next, done = \
            self.replay_buffer.sample(self.batch_size)
        #act = split_batched_array(full_act, self.act_shape)
        time_counter[-1] += time.time() - tt
        tt = time.time()

        # convert to variables
        obs_n = self._process_frames(obs)
        obs_next_n = self._process_frames(obs_next, volatile=True)
        full_act_n = Variable(torch.from_numpy(full_act)).type(FloatTensor)
        rew_n = Variable(torch.from_numpy(rew),
                         volatile=True).type(FloatTensor)
        done_n = Variable(torch.from_numpy(done),
                          volatile=True).type(FloatTensor)

        time_counter[0] += time.time() - tt
        tt = time.time()

        self.optim.zero_grad()

        # train p network
        q_val = self.net(obs_n, action=None, output_critic=True)
        p_loss = -q_val.mean().squeeze()
        p_ent = self.net.entropy().mean().squeeze()
        if self.args['ent_penalty'] is not None:
            p_loss -= self.args['ent_penalty'] * p_ent  # encourage exploration
        common.debugger.print('>> P_Loss = {}'.format(p_loss.data.mean()),
                              False)
        p_loss.backward()
        self.net.clear_critic_specific_grad()  # the actor step must not apply critic-specific gradients
        if self.grad_norm_clip is not None:
            utils.clip_grad_norm(self.net.parameters(), self.grad_norm_clip)
        self.optim.step()

        # train q network
        self.optim.zero_grad()
        common.debugger.print('Grad Stats of Q Update ...', False)
        target_q_next = self.target_net(obs_next_n, output_critic=True)
        target_q = rew_n + self.gamma * (1.0 - done_n) * target_q_next
        target_q.volatile = False
        current_q = self.net(obs_n, action=full_act_n, output_critic=True)
        q_norm = (current_q * current_q).mean().squeeze()  # l2 norm
        q_loss = F.smooth_l1_loss(
            current_q,
            target_q) + self.args['critic_penalty'] * q_norm  # huber
        common.debugger.print('>> Q_Loss = {}'.format(q_loss.data.mean()),
                              False)
        #q_loss = q_loss * 50
        q_loss.backward()

        # total_loss = q_loss + p_loss
        # grad clip
        if self.grad_norm_clip is not None:
            utils.clip_grad_norm(self.net.parameters(), self.grad_norm_clip)
        self.optim.step()

        common.debugger.print('Stats of P Network (after clip and opt)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.net)

        time_counter[1] += time.time() - tt
        tt = time.time()

        # update target networks
        make_update_exp(self.net,
                        self.target_net,
                        rate=self.target_update_rate)

        common.debugger.print('Stats of Target Network (After Update)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.target_net)

        time_counter[2] += time.time() - tt

        return dict(policy_loss=p_loss.data.cpu().numpy()[0],
                    policy_entropy=p_ent.data.cpu().numpy()[0],
                    critic_norm=q_norm.data.cpu().numpy()[0],
                    critic_loss=q_loss.data.cpu().numpy()[0])
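`make_update_exp` is defined elsewhere in the project; judging by its name and the `rate=self.target_update_rate` argument, it presumably performs a soft (Polyak) target update. A hypothetical minimal version, assuming `rate` is the interpolation weight toward the online network:

import torch

def make_update_exp_sketch(source_net, target_net, rate=1e-3):
    # Soft target update: target <- (1 - rate) * target + rate * source.
    with torch.no_grad():
        for p_src, p_tgt in zip(source_net.parameters(), target_net.parameters()):
            p_tgt.mul_(1.0 - rate).add_(rate * p_src)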
Exemplo n.º 23
0
    def update(self):
        if (self.sample_counter < self.args['update_freq']) or \
           not self.replay_buffer.can_sample(self.batch_size * min(self.args['update_freq'], 10)):
            return None
        self.sample_counter = 0
        self.train()
        tt = time.time()

        obs, act, rew, obs_next, done = \
            self.replay_buffer.sample(self.batch_size)
        #act = split_batched_array(full_act, self.act_shape)
        time_counter[-1] += time.time() - tt
        tt = time.time()

        # convert to variables
        obs_n = self._process_frames(obs)
        obs_next_n = self._process_frames(obs_next, volatile=True)
        act_n = torch.from_numpy(act).type(LongTensor)
        rew_n = Variable(torch.from_numpy(rew), volatile=True).type(FloatTensor)
        done_n = Variable(torch.from_numpy(done), volatile=True).type(FloatTensor)

        time_counter[0] += time.time() - tt
        tt = time.time()

        # compute critic loss
        target_q_next = self.target_net(obs_next_n, only_value=True)
        target_q = rew_n + self.gamma * (1.0 - done_n) * target_q_next
        target_q.volatile = False
        current_act, current_q = self.net(obs_n, return_value=True)
        q_norm = (current_q * current_q).mean().squeeze()
        q_loss = F.smooth_l1_loss(current_q, target_q)

        common.debugger.print('>> Q_Loss = {}'.format(q_loss.data.mean()), False)
        common.debugger.print('>> Q_Norm = {}'.format(q_norm.data.mean()), False)

        total_loss = q_loss.mean()
        if self.args['critic_penalty'] > 1e-10:
            total_loss += self.args['critic_penalty'] * q_norm

        # compute policy loss
        # NOTE: currently 1-step lookahead!!! TODO: multiple-step lookahead
        raw_adv_ts = (rew_n - current_q).data
        #raw_adv_ts = (target_q - current_q).data   # use estimated advantage??
        adv_ts = (raw_adv_ts - raw_adv_ts.mean()) / (raw_adv_ts.std() + 1e-15)
        #current_act.reinforce(adv_ts)
        p_ent = self.net.entropy().mean()
        p_loss = self.net.logprob(act_n)
        p_loss = p_loss * Variable(adv_ts)
        p_loss = p_loss.mean()
        total_loss -= p_loss
        if self.args['ent_penalty'] is not None:
            total_loss -= self.args['ent_penalty'] * p_ent  # encourage exploration
        common.debugger.print('>> P_Loss = {}'.format(p_loss.data.mean()), False)
        common.debugger.print('>> P_Entropy = {}'.format(p_ent.data.mean()), False)

        # compute gradient
        self.optim.zero_grad()
        #autograd.backward([total_loss, current_act], [torch.ones(1), None])
        total_loss.backward()
        if self.grad_norm_clip is not None:
            #nn.utils.clip_grad_norm(self.q.parameters(), self.grad_norm_clip)
            utils.clip_grad_norm(self.net.parameters(), self.grad_norm_clip)
        self.optim.step()
        common.debugger.print('Stats of Model (*after* clip and opt)....', False)
        utils.log_parameter_stats(common.debugger, self.net)

        time_counter[1] += time.time() - tt
        tt = time.time()

        # update target networks
        make_update_exp(self.net, self.target_net, rate=self.target_update_rate)
        common.debugger.print('Stats of Target Network (After Update)....', False)
        utils.log_parameter_stats(common.debugger, self.target_net)

        time_counter[2] += time.time() - tt

        return dict(policy_loss=p_loss.data.cpu().numpy()[0],
                    policy_entropy=p_ent.data.cpu().numpy()[0],
                    critic_norm=q_norm.data.cpu().numpy()[0],
                    critic_loss=q_loss.data.cpu().numpy()[0])
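The policy term above standardizes the one-step advantage `rew_n - current_q` before weighting the log-probabilities, which keeps the gradient scale roughly constant across batches. The same normalization as a stand-alone helper:

import torch

def normalize_advantage(adv, eps=1e-15):
    # Zero-mean, unit-variance advantage, mirroring the in-line computation above.
    return (adv - adv.mean()) / (adv.std() + eps)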