Example #1
def run_vqvae(info,
              vqvae_model,
              opt,
              train_buffer,
              valid_buffer,
              num_samples_to_train=10000,
              save_every_samples=1000,
              batches=0):
    if len(info['model_train_cnts']):
        train_cnt = info['model_train_cnts'][-1]
    else:
        train_cnt = 0
    st = time.time()
    while train_cnt < num_samples_to_train:
        batch = train_buffer.get_minibatch(info['MODEL_BATCH_SIZE'])
        opt.zero_grad()
        avg_train_losses, _, _, _ = train_vqvae(vqvae_model, info, batch)
        parameters = list(vqvae_model.parameters())
        clip_grad_value_(parameters, 5)
        opt.step()
        opt.zero_grad()

        train_cnt += info['MODEL_BATCH_SIZE']
        if (((train_cnt - info['model_last_save']) >= save_every_samples)
                or batches == 0):
            valid_batch = valid_buffer.get_minibatch(info['MODEL_BATCH_SIZE'])
            save_vqvae(info, train_cnt, vqvae_model, opt, avg_train_losses,
                       valid_batch)
        batches += 1
        if not batches % 500:
            print("finished %s epoch after %s seconds at cnt %s" %
                  (batches, time.time() - st, train_cnt))
    return vqvae_model, opt
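
Every example on this page follows the same core pattern around clip_grad_value_: zero the gradients, call backward() on the loss, clip the gradient elements in place, then step the optimizer. A minimal, self-contained sketch of that pattern (the model, data, and clip threshold here are placeholders, not taken from any example):

import torch
import torch.nn.functional as F
from torch.nn.utils import clip_grad_value_

model = torch.nn.Linear(8, 1)                    # placeholder model
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
x, y = torch.randn(32, 8), torch.randn(32, 1)    # placeholder batch

opt.zero_grad()
loss = F.mse_loss(model(x), y)
loss.backward()
clip_grad_value_(model.parameters(), 5)  # clamp each gradient element to [-5, 5]
opt.step()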
Example #2
    def train_gp(self, seed=0, n_iter=300, pred_interval=5, test=None, verbose=True, n_epoch=50):
        set_seed(seed)
        print('SEED=', seed)
        optimizer = opt.Adam(self.model.parameters())
        for i in range(n_iter):
            batch_nll = 0.0
            epoch = 0
            for (X, Y) in self.data:
                if epoch > n_epoch:
                    break
                epoch += 1
                self.model.train()
                optimizer.zero_grad()
                nll = self.gp.NLL_batch(X, Y)
                batch_nll += nll.item() * X.shape[0]  # .item() detaches; avoids keeping each batch's graph
                nll.backward()
                clip_grad_value_(self.model.parameters(), 20)
                optimizer.step()
                torch.cuda.empty_cache()
            if i % pred_interval == 0:
                record = {'nll': batch_nll / self.train['X'].shape[0],
                          'iter': i}

                if test is not None:
                    err = rmse(self.gp(test['X'])[0], test['Y'])
                    record['rmse'] = err.item()

                if verbose:
                    print(record)

                self.history.append(record)
        return self.history
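
One detail worth noting in Example #2: a running loss should be accumulated as a detached Python float (via .item()); summing the raw loss tensors instead keeps every batch's autograd graph reachable. The corrected accumulation in isolation (a sketch; loader and model_nll are hypothetical stand-ins):

running_nll = 0.0
for X, Y in loader:                      # placeholder (input, target) iterable
    optimizer.zero_grad()
    nll = model_nll(X, Y)                # hypothetical per-batch mean NLL
    nll.backward()
    clip_grad_value_(model.parameters(), 20)
    optimizer.step()
    running_nll += nll.item() * X.shape[0]  # .item() detaches from the graph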
Example #3
def train_vqvae(train_cnt):
    st = time.time()
    #for batch_idx, (data, label, data_index) in enumerate(train_loader):
    batches = 0
    while train_cnt < args.num_examples_to_train:
        vqvae_model.train()
        opt.zero_grad()
        #states, actions, rewards, next_states, terminals, is_new_epoch, relative_indexes = train_data_loader.get_unique_minibatch()
        states, actions, rewards, values, pred_states, terminals, is_new_epoch, relative_indexes = train_data_loader.get_framediff_minibatch()
        # because we have 4 layers in vqvae, need to be divisible by 2, 4 times
        states = (2 * reshape_input(states) - 1).to(DEVICE)
        rec = (2 * reshape_input(pred_states[:, 0][:, None]) - 1).to(DEVICE)
        # dont normalize diff
        diff = (reshape_input(pred_states[:, 1][:, None])).to(DEVICE)
        x_d, z_e_x, z_q_x, latents = vqvae_model(states)
        # (args.nr_logistic_mix/2)*3 is needed for each reconstruction
        z_q_x.retain_grad()
        rec_est = x_d[:, :nmix]
        diff_est = x_d[:, nmix:]
        loss_rec = discretized_mix_logistic_loss(rec_est,
                                                 rec,
                                                 nr_mix=args.nr_logistic_mix,
                                                 DEVICE=DEVICE)
        loss_diff = discretized_mix_logistic_loss(diff_est,
                                                  diff,
                                                  nr_mix=args.nr_logistic_mix,
                                                  DEVICE=DEVICE)
        loss_2 = F.mse_loss(z_q_x, z_e_x.detach())
        loss_3 = args.beta * F.mse_loss(z_e_x, z_q_x.detach())
        loss_rec.backward(retain_graph=True)
        loss_diff.backward(retain_graph=True)
        vqvae_model.embedding.zero_grad()
        z_e_x.backward(z_q_x.grad, retain_graph=True)
        loss_2.backward(retain_graph=True)
        loss_3.backward()
        parameters = list(vqvae_model.parameters())
        clip_grad_value_(parameters, 10)
        opt.step()
        bs = float(x_d.shape[0])
        loss_list = [
            loss_rec.item() / bs,
            loss_diff.item() / bs,
            loss_2.item() / bs,
            loss_3.item() / bs
        ]
        if batches > 5:
            handle_checkpointing(train_cnt, loss_list)
        train_cnt += len(states)
        batches += 1
        if not batches % 1000:
            print("finished %s epoch after %s seconds at cnt %s" %
                  (batches, time.time() - st, train_cnt))
    return train_cnt
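
Example #3 and the later train_vqvae variants implement the VQ-VAE straight-through estimator by hand: the codebook lookup between the encoder output z_e_x and the quantized code z_q_x is not differentiable, so the decoder's gradient at z_q_x is copied onto z_e_x with z_e_x.backward(z_q_x.grad). Condensed to just that step (names follow Example #3; the model and the reconstruction loss loss_rec are assumed):

x_d, z_e_x, z_q_x, latents = vqvae_model(states)
z_q_x.retain_grad()                            # z_q_x is a non-leaf tensor
loss_rec.backward(retain_graph=True)           # populates z_q_x.grad
z_e_x.backward(z_q_x.grad, retain_graph=True)  # straight-through gradient copy
F.mse_loss(z_q_x, z_e_x.detach()).backward(retain_graph=True)   # codebook loss
(args.beta * F.mse_loss(z_e_x, z_q_x.detach())).backward()      # commitment loss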
Example #4
    def _train_step(self) -> None:
        """Run a training step over the training data."""
        self.model.train()

        tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""

        with torch.enable_grad():
            for i in range(self.iter_per_step):
                # Zero the gradients and clear the accumulated loss
                self.optimizer.zero_grad()
                accumulated_loss = 0.0
                for _ in range(self.batches_per_iter):
                    # Get next batch
                    try:
                        batch = next(self._train_iterator)
                    except StopIteration:
                        self._create_train_iterator()
                        batch = next(self._train_iterator)
                    batch = self._batch_to_device(batch)

                    # Compute loss
                    loss = self._compute_loss(batch) / self.batches_per_iter
                    accumulated_loss += loss.item()
                    loss.backward()

                # Log loss
                global_step = (self.iter_per_step * self._step) + i

                # Clip gradients if necessary
                if self.max_grad_norm:
                    clip_grad_norm_(self.model.parameters(),
                                    self.max_grad_norm)
                if self.max_grad_abs_val:
                    clip_grad_value_(self.model.parameters(),
                                     self.max_grad_abs_val)

                log(f'{tb_prefix}Training/Loss', accumulated_loss, global_step)
                log(f'{tb_prefix}Training/Gradient_Norm',
                    self.model.gradient_norm, global_step)
                log(f'{tb_prefix}Training/Parameter_Norm',
                    self.model.parameter_norm, global_step)

                # Optimize
                self.optimizer.step()

                # Update iter scheduler
                if self.iter_scheduler is not None:
                    learning_rate = self.iter_scheduler.get_lr()[0]  # type: ignore
                    log(f'{tb_prefix}Training/LR', learning_rate, global_step)
                    self.iter_scheduler.step()  # type: ignore

                # Zero the gradients when exiting a train step
                self.optimizer.zero_grad()
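
The inner loop of Example #4 is standard gradient accumulation: each of batches_per_iter micro-batch losses is divided by that count before backward(), so the summed gradients match those of one large batch, and clipping runs once on the accumulated result. The same idea stripped down (loader, model, and threshold are placeholders):

optimizer.zero_grad()
for _ in range(batches_per_iter):
    batch = next(train_iterator)
    loss = compute_loss(batch) / batches_per_iter  # scale each micro-batch
    loss.backward()                                # gradients accumulate in .grad
clip_grad_value_(model.parameters(), max_grad_abs_val)  # clip the accumulated sum once
optimizer.step()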
Example #5
def train_acn(train_cnt):
    train_kl_loss = 0.0
    train_rec_loss = 0.0
    init_cnt = train_cnt
    st = time.time()
    #for batch_idx, (data, label, data_index) in enumerate(train_loader):
    batches = 0
    while train_cnt < args.num_examples_to_train:
        encoder_model.train()
        prior_model.train()
        pcnn_decoder.train()
        opt.zero_grad()
        states, actions, rewards, next_states, terminals, is_new_epoch, relative_indexes = train_data_loader.get_unique_minibatch()
        states = states.to(DEVICE)
        # 1 channel expected
    next_states = next_states[:, args.number_condition - 1:].to(DEVICE)
        actions = actions.to(DEVICE)
        z, u_q = encoder_model(states)
        np_uq = u_q.detach().cpu().numpy()
        if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
            print('train bad')
            embed()

        # add the predicted codes to the input
        yhat_batch = torch.sigmoid(pcnn_decoder(x=next_states, class_condition=actions, float_condition=z))
        #yhat_batch = torch.sigmoid(pcnn_decoder(x=next_states, float_condition=z))
        #print(train_cnt)
        prior_model.codes[relative_indexes-args.number_condition] = u_q.detach().cpu().numpy()
        np_uq = u_q.detach().cpu().numpy()
        if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
            print('train bad')
            embed()
        mix, u_ps, s_ps = prior_model(u_q)
        kl_loss, rec_loss = acn_gmp_loss_function(yhat_batch, next_states, u_q, mix, u_ps, s_ps)
        loss = kl_loss + rec_loss
        loss.backward()
        parameters = list(encoder_model.parameters()) + list(prior_model.parameters()) + list(pcnn_decoder.parameters())
        clip_grad_value_(parameters, 10)
        train_kl_loss += kl_loss.item()
        train_rec_loss += rec_loss.item()
        opt.step()
        # add batch size because it hasn't been added to train cnt yet
        avg_train_kl_loss = train_kl_loss / float((train_cnt + states.shape[0]) - init_cnt)
        avg_train_rec_loss = train_rec_loss / float((train_cnt + states.shape[0]) - init_cnt)
        handle_checkpointing(train_cnt, avg_train_kl_loss, avg_train_rec_loss)
        train_cnt += len(states)

        batches += 1
        if not batches % 1000:
            print("finished %s epoch after %s seconds at cnt %s" %
                  (batches, time.time() - st, train_cnt))
    return train_cnt
Example #6
def train_vqvae(train_cnt):
    st = time.time()
    #for batch_idx, (data, label, data_index) in enumerate(train_loader):
    batches = 0
    while train_cnt < args.num_examples_to_train:
        vqenc.train()
        pcnn_decoder.train()
        opt.zero_grad()
        states, actions, rewards, next_states, terminals, is_new_epoch, relative_indexes = train_data_loader.get_unique_minibatch()
        # because we have 4 layers in vqvae, need to be divisible by 2, 4 times
        states = reshape_input(states).to(DEVICE)
        # only predict future observation - normalize
        targets = (2 * states[:, -1:] - 1).to(DEVICE)
        #actions = actions.to(DEVICE)
        x_d, z_e_x, z_q_x, latents = vqvae_model(states, targets)
        #z_e_x, z_q_x, latents = vqenc(states)
        #float_condition = latents.view(latents.shape[0], latents.shape[1]*latents.shape[2]).float()
        #x_d = pcnn_decoder(targets, class_condition=actions, float_condition=float_condition)
        z_q_x.retain_grad()
        vqvae_model.spatial_condition.retain_grad()
        loss_1 = discretized_mix_logistic_loss(x_d,
                                               targets,
                                               nr_mix=args.nr_logistic_mix,
                                               DEVICE=DEVICE)
        loss_2 = F.mse_loss(z_q_x, z_e_x.detach())
        loss_3 = args.beta * F.mse_loss(z_e_x, z_q_x.detach())
        #loss_1, loss_2, loss_3 = get_vqvae_loss(x_d, targets, z_e_x, z_q_x, nr_logistic_mix=args.nr_logistic_mix, beta=args.beta, device=DEVICE)
        loss_1.backward(retain_graph=True)
        #vqvae_model.encoder.embedding.zero_grad()
        #z_e_x.backward(z_q_x.grad, retain_graph=True)
        z_e_x.backward(vqvae_model.spatial_condition.grad, retain_graph=True)
        loss_2.backward(retain_graph=True)
        loss_3.backward()
        parameters = list(vqvae_model.parameters())
        clip_grad_value_(parameters, 10)
        opt.step()

        bs = float(x_d.shape[0])
        handle_checkpointing(train_cnt,
                             loss_1.item() / bs,
                             loss_2.item() / bs,
                             loss_3.item() / bs)
        train_cnt += len(states)

        batches += 1
        if not batches % 1000:
            print("finished %s epoch after %s seconds at cnt %s" %
                  (batches, time.time() - st, train_cnt))
    return train_cnt
Example #7
def train_acn(train_cnt):
    train_loss = 0
    init_cnt = train_cnt
    st = time.time()
    for batch_idx, (data, label, data_index) in enumerate(train_loader):
        encoder_model.train()
        prior_model.train()
        pcnn_decoder.train()
        lst = time.time()
        data = data.to(DEVICE)
        opt.zero_grad()
        z, u_q = encoder_model(data)
        #yhat_batch = encoder_model.decode(u_q, s_q, data)
        # add the predicted codes to the input
        # TODO - this isn't how you sample pcnn
        yhat_batch = torch.sigmoid(pcnn_decoder(x=data, float_condition=z))

        np_uq = u_q.detach().cpu().numpy()
        if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
            print('train bad')
            embed()
        prior_model.codes[data_index] = np_uq
        #prior_model.fit_knn(prior_model.codes)
        # output is gmp
        mixtures, u_ps, s_ps = prior_model(u_q)
        kl_reg, rec_loss = acn_gmp_loss_function(yhat_batch, data, u_q,
                                                 mixtures, u_ps, s_ps)
        loss = kl_reg + rec_loss
        if not batch_idx % 10:
            print(train_cnt, batch_idx, kl_reg.item(), rec_loss.item())
        loss.backward()
        parameters = list(encoder_model.parameters()) + list(
            prior_model.parameters()) + list(pcnn_decoder.parameters())
        clip_grad_value_(parameters, 10)
        train_loss += loss.item()
        opt.step()
        # add batch size because it hasn't been added to train cnt yet
        avg_train_loss = train_loss / float((train_cnt + data.shape[0]) -
                                            init_cnt)
        print('batch', train_cnt, avg_train_loss)
        handle_checkpointing(train_cnt, avg_train_loss)
        train_cnt += len(data)
    print(train_loss)
    print("finished epoch after %s seconds at cnt %s" %
          (time.time() - st, train_cnt))
    return train_cnt
Example #8
def train(train_input_images, train_actions, encoder, decoder, criterion,
          optimizer, model_paras, config):

    times = train_input_images.shape[1] // config.seq_len

    train_loss = 0

    for i in range(times):
        train_input_image_seq = train_input_images[
            :, i * config.seq_len:(i + 1) * config.seq_len].cuda()
        train_action_seq = {
            action: torch.cat(
                (config.init_y[action],
                 values[:, i * config.seq_len:(i + 1) * config.seq_len]),
                dim=1).cuda()
            for action, values in train_actions.items()}

        encoder.zero_grad()
        decoder.zero_grad()
        encoder_outputs = encoder.forward(train_input_image_seq,
                                          config.decoder_batch_size,
                                          config.seq_len)
        y = decoder.forward(encoder_outputs, train_action_seq,
                            config.decoder_batch_size, config.seq_len)

        losses = []
        for action in config.y_keys_info.keys():
            losses.append(criterion(y[action], train_action_seq[action][:, 1:]))
        total_loss = sum(losses)
        total_loss.backward()

        train_loss += total_loss.item()

        clip_grad_value_(model_paras, config.clip_value)
        optimizer.step()

        accuracy = {}
        for action in config.y_keys_info.keys():
            _, y_pred = y[action].max(dim=1)
            correct = (y_pred == train_action_seq[action][:, 1:]).sum().item()
            accuracy[action] = correct / (config.decoder_batch_size * config.seq_len)
        print(accuracy)

    return train_loss / times
Example #9
def train_vqvae(train_cnt):
    st = time.time()
    #for batch_idx, (data, label, data_index) in enumerate(train_loader):
    batches = 0
    while train_cnt < args.num_examples_to_train:
        vqvae_model.train()
        opt.zero_grad()
        states, actions, rewards, next_states, terminals, is_new_epoch, relative_indexes = train_data_loader.get_unique_minibatch()
        # because we have 4 layers in vqvae, need to be divisible by 2, 4 times
        states = (2 * reshape_input(states[:, -1:]) - 1).to(DEVICE)
        x_d, z_e_x, z_q_x, latents = vqvae_model(states)
        z_q_x.retain_grad()
        loss_1 = discretized_mix_logistic_loss(x_d,
                                               states,
                                               nr_mix=args.nr_logistic_mix,
                                               DEVICE=DEVICE)
        loss_2 = F.mse_loss(z_q_x, z_e_x.detach())
        loss_3 = args.beta * F.mse_loss(z_e_x, z_q_x.detach())
        loss_1.backward(retain_graph=True)
        vqvae_model.embedding.zero_grad()
        z_e_x.backward(z_q_x.grad, retain_graph=True)
        loss_2.backward(retain_graph=True)
        loss_3.backward()
        parameters = list(vqvae_model.parameters())
        clip_grad_value_(parameters, 10)
        opt.step()
        bs = float(x_d.shape[0])
        handle_checkpointing(train_cnt,
                             loss_1.item() / bs,
                             loss_2.item() / bs,
                             loss_3.item() / bs)
        train_cnt += len(states)

        batches += 1
        if not batches % 1000:
            print("finished %s epoch after %s seconds at cnt %s" %
                  (batches, time.time() - st, train_cnt))
    return train_cnt
Example #10
def pt_latent_learn():  # latent_states, actions, rewards, latent_next_states, terminal_flags, masks
    opt.zero_grad()
    batch = replay_memory.get_minibatch(info['BATCH_SIZE'])
    terminal_flags = torch.Tensor(batch[4].astype(np.int64)).to(info['DEVICE'])
    masks = torch.FloatTensor(batch[5].astype(np.int64)).to(info['DEVICE'])
    # do losses based on next state estimate
    # backward() for state_representation happens in train_vqvae - still need to
    # step
    avg_train_losses, next_z_q_x, next_z_e_x, pt_data = train_vqvae(
        vqvae_model, info, batch)
    states, actions, rewards, next_states = pt_data
    rewards = rewards.float()
    # get representation for current state
    x_d, z_e_x, z_q_x, latents, pred_actions, pred_rewards = vqvae_model(
        states)
    # min history to learn is 200,000 frames in dqn - 50000 steps
    losses = [0.0 for _ in range(info['N_ENSEMBLE'])]
    q_policy_vals = policy_net(z_e_x, None)
    next_q_target_vals = target_net(next_z_e_x, None)
    next_q_policy_vals = policy_net(next_z_e_x, None)
    z_e_x.retain_grad()
    next_z_e_x.retain_grad()
    # trying to work on e rather than q to look at grads
    # z_q_x does not give .grads to vqvae_model
    #q_policy_vals = policy_net(z_q_x, None)
    #next_q_target_vals = target_net(next_z_q_x, None)
    #next_q_policy_vals = policy_net(next_z_q_x, None)
    #z_q_x.retain_grad()
    #next_z_q_x.retain_grad()
    cnt_losses = []
    for k in range(info['N_ENSEMBLE']):
        #TODO finish masking
        total_used = torch.sum(masks[:, k])
        if total_used > 0.0:
            next_q_vals = next_q_target_vals[k].data
            if info['DOUBLE_DQN']:
                next_actions = next_q_policy_vals[k].data.max(1, True)[1]
                next_qs = next_q_vals.gather(1, next_actions).squeeze(1)
            else:
                next_qs = next_q_vals.max(1)[0]  # max returns a pair

            preds = q_policy_vals[k].gather(1, actions[:, None]).squeeze(1)
            targets = rewards + info['GAMMA'] * next_qs * (1 - terminal_flags)
            l1loss = F.smooth_l1_loss(preds, targets, reduction='mean')
            full_loss = masks[:, k] * l1loss
            loss = torch.sum(full_loss / total_used)
            cnt_losses.append(loss)
            losses[k] = loss.cpu().detach().item()

    loss = sum(cnt_losses) / info['N_ENSEMBLE']
    loss.backward()
    vqvae_parameters = list(vqvae_model.parameters())
    clip_grad_value_(vqvae_parameters, 5)
    for param in policy_net.core_net.parameters():
        if param.grad is not None:
            # divide grads in core
            param.grad.data *= 1.0 / float(info['N_ENSEMBLE'])
    nn.utils.clip_grad_norm_(policy_net.parameters(), info['CLIP_GRAD'])
    opt.step()
    return np.mean(losses) + np.sum(avg_train_losses)
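
pt_latent_learn above trains a bootstrapped ensemble of Q-heads on a shared core: the masked per-head losses are averaged, and after backward() the gradients of the shared trunk are rescaled by 1/N_ENSEMBLE so the trunk does not receive an outsized update relative to any single head. The trunk-rescaling pattern in isolation (a sketch; policy_net, core_net, and head_losses are assumed from the example):

loss = sum(head_losses) / n_ensemble
loss.backward()
for param in policy_net.core_net.parameters():   # shared trunk only
    if param.grad is not None:
        param.grad.data *= 1.0 / float(n_ensemble)
nn.utils.clip_grad_norm_(policy_net.parameters(), clip_grad)
opt.step()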
Example #11
def train_forward(train_cnt, conv_forward_model, opt, latents, actions,
                  rewards, next_latents):
    st = time.time()
    batches = 0
    while train_cnt < args.num_examples_to_train:
        conv_forward_model.train()
        opt.zero_grad()
        # we want the forward model to produce a next latent in which the vq
        # model can determine the action we gave it.
        latents = torch.FloatTensor(latents[:, None]).to(DEVICE)
        # next_latents is long because of prediction
        next_latents = torch.LongTensor(next_latents[:, None]).to(DEVICE)
        rewards = torch.LongTensor(rewards).to(DEVICE)
        actions = torch.LongTensor(actions).to(DEVICE)
        # put actions into channel for conditioning
        bs, _, h, w = latents.shape
        channel_actions = torch.zeros((bs, num_actions, h, w)).to(DEVICE)
        for a in range(num_actions):
            channel_actions[actions == a, a] = 1
        channel_rewards = torch.zeros((bs, num_rewards, h, w)).to(DEVICE)
        for r in range(num_rewards):
            channel_rewards[rewards == r, r] = 1

        # combine input together
        state_input = torch.cat(
            (channel_actions, channel_rewards, latents, next_latents), dim=1)
        bs = float(latents.shape[0])

        pred_next_latents = conv_forward_model(state_input)
        # pred_next_latents shape is bs, c, h, w - need to permute shape for the nll loss below
        # don't optimize vqmodel - just optimize against its understanding of
        # the latent data
        with torch.no_grad():
            N, _, H, W = latents.shape
            C = vq_largs.num_z
            pred_next_latent_inds = torch.argmax(pred_next_latents, dim=1)
            x_tilde, pred_z_q_x, pred_actions, pred_rewards = vqvae_model.decode_clusters(
                pred_next_latent_inds, N, H, W, C)

        # should be able to predict the input action that got us to this
        # timestep
        loss_act = args.alpha_act * F.nll_loss(
            pred_actions, actions, weight=actions_weight)
        loss_reward = args.alpha_rew * F.nll_loss(
            pred_rewards, rewards, weight=rewards_weight)

        # determine which values of latents change over this time step
        ts_change = (torch.abs(latents.long() - next_latents) > 1).view(-1)

        pred_next_latents = pred_next_latents.permute(0, 2, 3, 1).contiguous()
        next_latents = next_latents.permute(0, 2, 3, 1).contiguous()
        latents = latents.permute(0, 2, 3, 1).contiguous()
        loss_rec = args.alpha_rec * F.nll_loss(pred_next_latents.view(-1, num_k),
                                               next_latents.view(-1),
                                               reduction='mean')

        # we want to penalize when these are wrong in particular
        loss_diff_rec = args.alpha_rec * F.nll_loss(
            pred_next_latents.view(-1, num_k)[ts_change == 1],
            next_latents.view(-1)[ts_change == 1],
            reduction='mean')

        loss = loss_reward + loss_act + loss_rec + loss_diff_rec
        loss.backward(retain_graph=True)
        parameters = list(conv_forward_model.parameters())
        clip_grad_value_(parameters, 10)
        opt.step()
        loss_list = [
            loss_reward.item() / bs,
            loss_act.item() / bs,
            loss_rec.item() / bs,
            loss_diff_rec.item() / bs
        ]
        if batches > 1000:
            handle_checkpointing(train_cnt, loss_list)
        train_cnt += bs
        batches += 1
        if not batches % 1000:
            print("finished %s epoch after %s seconds at cnt %s" %
                  (batches, time.time() - st, train_cnt))
    return train_cnt
Example #12
    def _train_step(self) -> None:
        """Run a training step over the training data."""
        self.model.train()
        metrics_with_states: List[Tuple] = [
            (metric, {}) for metric in self.training_metrics
        ]
        self._last_train_log_step = 0

        log_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""
        log_prefix += 'Training'

        with torch.enable_grad():
            for i in range(self.iter_per_step):

                # print('MEMORY ALLOCATED %f' % float(torch.cuda.memory_allocated() / BYTES_IN_GB))
                # print('MEMORY CACHED %f' % float(torch.cuda.memory_cached() / BYTES_IN_GB))

                t = time.time()
                # Zero the gradients and clear the accumulated loss
                self.optimizer.zero_grad()
                accumulated_loss = 0.0
                for _ in range(self.batches_per_iter):

                    # Get next batch
                    try:
                        batch = next(self._train_iterator)
                    except StopIteration:
                        self._create_train_iterator()
                        batch = next(self._train_iterator)

                    if self.device_count == 1:
                        batch = self._batch_to_device(batch)

                    # Compute loss
                    if self.top_k == 'None':
                        _, _, loss = self._compute_batch(
                            batch, metrics_with_states)
                    else:
                        loss = self._compute_kl_loss(batch)
                    print('LOSS')
                    print(loss)
                    accumulated_loss += loss.item() / self.batches_per_iter

                    loss.backward()
                    # try:
                    #     loss.backward()
                    # except RuntimeError:
                    #     torch.cuda.empty_cache()
                    #     print('EMPTIED CACHE FOR LOSS')
                    #     continue

                # Log loss
                global_step = (self.iter_per_step * self._step) + i
                self.beta = self.get_beta(global_step)

                # Clip gradients if necessary
                if self.max_grad_norm:
                    clip_grad_norm_(self.model.parameters(),
                                    self.max_grad_norm)
                if self.max_grad_abs_val:
                    clip_grad_value_(self.model.parameters(),
                                     self.max_grad_abs_val)

                log(f'{log_prefix}/Loss', accumulated_loss, global_step)
                if self.device_count > 1:
                    log(f'{log_prefix}/Gradient_Norm',
                        self.model.module.gradient_norm, global_step)
                    log(f'{log_prefix}/Parameter_Norm',
                        self.model.module.parameter_norm, global_step)
                else:
                    log(f'{log_prefix}/Gradient_Norm',
                        self.model.gradient_norm, global_step)
                    log(f'{log_prefix}/Parameter_Norm',
                        self.model.parameter_norm, global_step)
                log(f'{log_prefix}/Beta', self.beta, global_step)

                # Optimize
                self.optimizer.step()

                # Update iter scheduler
                if self.iter_scheduler is not None:
                    lr = self.optimizer.param_groups[0]['lr']  # type: ignore
                    log(f'{log_prefix}/LR', lr, global_step)
                    self.iter_scheduler.step()  # type: ignore

                # Zero the gradients when exiting a train step
                self.optimizer.zero_grad()
                # logging train metrics
                if self.extra_training_metrics_log_interval > self._last_train_log_step:
                    self._log_metrics(log_prefix, metrics_with_states,
                                      global_step)
                    self._last_train_log_step = i

                print('TOTAL TIME: %f' % (time.time() - t))
            if self._last_train_log_step != i:
                # log again at end of step, if not logged at the end of
                # step before
                self._log_metrics(log_prefix, metrics_with_states, global_step)
Example #13
def train_acn(train_cnt):
    train_kl_loss = 0.0
    train_rec_loss = 0.0
    init_cnt = train_cnt
    st = time.time()
    #for batch_idx, (data, label, data_index) in enumerate(train_loader):
    batches = 0
    while train_cnt < args.num_examples_to_train:
        encoder_model.train()
        prior_model.train()
        pcnn_decoder.train()
        opt.zero_grad()
        lst = time.time()
        data, label, data_index, is_new_epoch = data_loader.next_unique_batch()
        if is_new_epoch:
            #    prior_model.new_epoch()
            print(train_cnt, 'train, is new epoch',
                  prior_model.available_indexes.shape)
        data = data.to(DEVICE)
        label = label.to(DEVICE)
        #  inf happens sometime after 0001,680,896
        z, u_q = encoder_model(data)
        np_uq = u_q.detach().cpu().numpy()
        if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
            print('train bad')
            embed()

        # add the predicted codes to the input
        yhat_batch = torch.sigmoid(pcnn_decoder(x=label, float_condition=z))
        #print(train_cnt)
        prior_model.codes[data_index -
                          args.number_condition] = u_q.detach().cpu().numpy()
        #mixtures, u_ps, s_ps = prior_model(u_q)
        #loss = acn_gmp_loss_function(yhat_batch, label, u_q, mixtures, u_ps, s_ps)
        np_uq = u_q.detach().cpu().numpy()
        if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
            print('train bad')
            embed()
        #loss.backward()
        #parameters = list(encoder_model.parameters()) + list(prior_model.parameters()) + list(pcnn_decoder.parameters())
        #clip_grad_value_(parameters, 10)
        #train_loss += loss.item()
        mix, u_ps, s_ps = prior_model(u_q)
        #kl_loss, rec_loss = acn_loss_function(yhat_batch, data, u_q, u_ps, s_ps)
        kl_loss, rec_loss = acn_gmp_loss_function(yhat_batch, label, u_q, mix,
                                                  u_ps, s_ps)
        loss = kl_loss + rec_loss
        loss.backward()
        parameters = list(encoder_model.parameters()) + list(
            prior_model.parameters()) + list(pcnn_decoder.parameters())
        clip_grad_value_(parameters, 10)
        train_kl_loss += kl_loss.item()
        train_rec_loss += rec_loss.item()
        opt.step()
        # add batch size because it hasn't been added to train cnt yet
        avg_train_kl_loss = train_kl_loss / float((train_cnt + data.shape[0]) -
                                                  init_cnt)
        avg_train_rec_loss = train_rec_loss / float(
            (train_cnt + data.shape[0]) - init_cnt)
        handle_checkpointing(train_cnt, avg_train_kl_loss, avg_train_rec_loss)
        train_cnt += len(data)

        # add batch size because it hasn't been added to train cnt yet
        #        avg_train_loss = train_loss/float((train_cnt+data.shape[0])-init_cnt)
        batches += 1
        if not batches % 1000:
            print("finished %s epoch after %s seconds at cnt %s" %
                  (batches, time.time() - st, train_cnt))
    return train_cnt
Example #14
    def train_epoch(self, model: nn.Module, train_loader: DataLoader,
                    val_clean_loader: DataLoader, val_triggered_loader: DataLoader,
                    epoch_num: int, use_amp: bool = False):
        """
        Runs one epoch of training on the specified model

        :param model: the model to train for one epoch
        :param train_loader: a DataLoader object pointing to the training dataset
        :param val_clean_loader: a DataLoader object pointing to the validation dataset that is clean
        :param val_triggered_loader: a DataLoader object pointing to the validation dataset that is triggered
        :param epoch_num: the epoch number that is being trained
        :param use_amp: if True, use automated mixed precision for FP16 training.
        :return: a list of statistics for batches where statistics were computed
        """

        pid = os.getpid()
        train_dataset_len = len(train_loader.dataset)
        loop = tqdm(train_loader, disable=self.optimizer_cfg.reporting_cfg.disable_progress_bar)

        scaler = None
        if use_amp:
            scaler = torch.cuda.amp.GradScaler()

        train_n_correct, train_n_total = None, None
        sum_batchmean_train_loss = 0
        running_train_acc = 0
        num_batches = len(train_loader)
        model.train()
        for batch_idx, (x, y_truth) in enumerate(loop):
            x = x.to(self.device)
            y_truth = y_truth.to(self.device)

            # put network into training mode & zero out previous gradient computations
            self.optimizer.zero_grad()

            # get predictions based on input & weights learned so far
            if use_amp:
                with torch.cuda.amp.autocast():
                    y_hat = model(x)
                    # compute metrics
                    batch_train_loss = self._eval_loss_function(y_hat, y_truth)
            else:
                y_hat = model(x)
                # compute metrics
                batch_train_loss = self._eval_loss_function(y_hat, y_truth)

            sum_batchmean_train_loss += batch_train_loss.item()

            running_train_acc, train_n_total, train_n_correct = _running_eval_acc(y_hat, y_truth,
                                                                                  n_total=train_n_total,
                                                                                  n_correct=train_n_correct,
                                                                                  soft_to_hard_fn=self.soft_to_hard_fn,
                                                                                  soft_to_hard_fn_kwargs=self.soft_to_hard_fn_kwargs)

            # if np.isnan(sum_batchmean_train_loss) or np.isnan(running_train_acc):
            #     _save_nandata(x, y_hat, y_truth, batch_train_loss, sum_batchmean_train_loss, running_train_acc,
            #                   train_n_total, train_n_correct, model)

            # compute gradient
            if use_amp:
                # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
                # Backward passes under autocast are not recommended.
                # Backward ops run in the same dtype autocast chose for corresponding forward ops.
                scaler.scale(batch_train_loss).backward()
            else:
                if np.isnan(sum_batchmean_train_loss) or np.isnan(running_train_acc):
                    _save_nandata(x, y_hat, y_truth, batch_train_loss, sum_batchmean_train_loss, running_train_acc,
                                  train_n_total, train_n_correct, model)

                batch_train_loss.backward()

            # perform gradient clipping if configured
            if self.optimizer_cfg.training_cfg.clip_grad:
                if use_amp:
                    # Unscales the gradients of optimizer's assigned params in-place
                    scaler.unscale_(self.optimizer)

                if self.optimizer_cfg.training_cfg.clip_type == 'norm':
                    # clip_grad_norm_ modifies gradients in place
                    #  see: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
                    torch_clip_grad.clip_grad_norm_(model.parameters(), self.optimizer_cfg.training_cfg.clip_val,
                                                    **self.optimizer_cfg.training_cfg.clip_kwargs)
                elif self.optimizer_cfg.training_cfg.clip_type == 'val':
                    # clip_grad_val_ modifies gradients in place
                    #  see: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
                    torch_clip_grad.clip_grad_value_(model.parameters(), self.optimizer_cfg.training_cfg.clip_val)
                else:
                    msg = "Unknown clipping type for gradient clipping!"
                    logger.error(msg)
                    raise ValueError(msg)

            if use_amp:
                # scaler.step() first unscales the gradients of the optimizer's assigned params.
                # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                scaler.step(self.optimizer)
                # Updates the scale for next iteration.
                scaler.update()
            else:
                self.optimizer.step()

            loop.set_description('Epoch {}/{}'.format(epoch_num + 1, self.num_epochs))
            loop.set_postfix(avg_train_loss=batch_train_loss.item())

            # report batch statistics to tensorboard
            if self.tb_writer:
                try:
                    batch_num = int(epoch_num * num_batches + batch_idx)
                    self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name + '-train_loss',
                                              batch_train_loss.item(), global_step=batch_num)
                    self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name + '-running_train_acc',
                                              running_train_acc, global_step=batch_num)
                except:
                    # TODO: catch specific exceptions
                    pass

            if batch_idx % self.num_batches_per_logmsg == 0:
                logger.info('{}\tTrain Epoch: {} [{}/{} ({:.0f}%)]\tTrainLoss: {:.6f}\tTrainAcc: {:.6f}'.format(
                    pid, epoch_num, batch_idx * len(x), train_dataset_len,
                                    100. * batch_idx / num_batches, batch_train_loss.item(), running_train_acc))

        train_stats = EpochTrainStatistics(running_train_acc, sum_batchmean_train_loss / float(num_batches))

        # if we have validation data, we compute on the validation dataset
        num_val_batches_clean = len(val_clean_loader)
        if num_val_batches_clean > 0:
            logger.info('Running Validation on Clean Data')
            running_val_clean_acc, _, _, val_clean_loss = \
                _eval_acc(val_clean_loader, model, self.device,
                          self.soft_to_hard_fn, self.soft_to_hard_fn_kwargs, self._eval_loss_function)
        else:
            logger.info("No dataset computed for validation on clean dataset!")
            running_val_clean_acc = None
            val_clean_loss = None

        num_val_batches_triggered = len(val_triggered_loader)
        if num_val_batches_triggered > 0:
            logger.info('Running Validation on Triggered Data')
            running_val_triggered_acc, _, _, val_triggered_loss = \
                _eval_acc(val_triggered_loader, model, self.device,
                          self.soft_to_hard_fn, self.soft_to_hard_fn_kwargs, self._eval_loss_function)
        else:
            logger.info("No dataset computed for validation on triggered dataset!")
            running_val_triggered_acc = None
            val_triggered_loss = None

        validation_stats = EpochValidationStatistics(running_val_clean_acc, val_clean_loss,
                                                     running_val_triggered_acc, val_triggered_loss)
        if num_val_batches_clean > 0:
            logger.info('{}\tTrain Epoch: {} \tCleanValLoss: {:.6f}\tCleanValAcc: {:.6f}'.format(
                pid, epoch_num, val_clean_loss, running_val_clean_acc))
        if num_val_batches_triggered > 0:
            logger.info('{}\tTrain Epoch: {} \tTriggeredValLoss: {:.6f}\tTriggeredValAcc: {:.6f}'.format(
                pid, epoch_num, val_triggered_loss, running_val_triggered_acc))

        if self.tb_writer:
            try:
                batch_num = int((epoch_num + 1) * num_batches)
                if num_val_batches_clean > 0:
                    self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name +
                                              '-clean-val-loss', val_clean_loss, global_step=batch_num)
                    self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name +
                                              '-clean-val_acc', running_val_clean_acc, global_step=batch_num)
                if num_val_batches_triggered > 0:
                    self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name +
                                              '-triggered-val-loss', val_triggered_loss, global_step=batch_num)
                    self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name +
                                              '-triggered-val_acc', running_val_triggered_acc, global_step=batch_num)
            except:
                pass

        # update the lr-scheduler if necessary
        if self.lr_scheduler is not None:
            if self.optimizer_cfg.training_cfg.lr_scheduler_call_arg is None:
                self.lr_scheduler.step()
            elif self.optimizer_cfg.training_cfg.lr_scheduler_call_arg.lower() == 'val_acc':
                val_acc = validation_stats.get_val_acc()
                if val_acc is not None:
                    self.lr_scheduler.step(val_acc)
                else:
                    msg = "val_clean_acc not defined b/c validation dataset is not defined! Ignoring LR step!"
                    logger.warning(msg)
            elif self.optimizer_cfg.training_cfg.lr_scheduler_call_arg.lower() == 'val_loss':
                val_loss = validation_stats.get_val_loss()
                if val_loss is not None:
                    self.lr_scheduler.step(val_loss)
                else:
                    msg = "val_clean_loss not defined b/c validation dataset is not defined! Ignoring LR step!"
                    logger.warning(msg)
            else:
                msg = "Unknown mode for calling lr_scheduler!"
                logger.error(msg)
                raise ValueError(msg)

        return train_stats, validation_stats
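
When use_amp is set, the epoch above has to unscale before it clips: scaler.scale(loss).backward() produces gradients multiplied by the loss scale, so scaler.unscale_(optimizer) must run before clipping or the threshold would apply to the scaled values. The ordering in isolation (a sketch using the torch.cuda.amp API; model, loss_fn, x, y, and clip_val are placeholders):

scaler = torch.cuda.amp.GradScaler()
optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = loss_fn(model(x), y)
scaler.scale(loss).backward()    # gradients are scaled by the loss scale here
scaler.unscale_(optimizer)       # restore true gradient magnitudes
clip_grad_value_(model.parameters(), clip_val)  # now the threshold is meaningful
scaler.step(optimizer)           # skips optimizer.step() on inf/NaN gradients
scaler.update()                  # adjust the scale for the next iteration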
Example #15
    def train_epoch(self, model: nn.Module, train_loader: TextDataIterator, val_loader: TextDataIterator,
                    epoch_num: int, progress_bar_disable: bool = False):
        """
        Runs one epoch of training on the specified model

        :param model: the model to train for one epoch
        :param train_loader: a DataLoader object pointing to the training dataset
        :param val_loader: a DataLoader object pointing to the validation dataset
        :param epoch_num: the epoch number that is being trained
        :param progress_bar_disable: if True, disables the progress bar
        :return: a list of statistics for batches where statistics were computed
        """

        pid = os.getpid()
        train_dataset_len = len(train_loader.dataset)
        loop = tqdm(train_loader, disable=progress_bar_disable)

        train_n_correct, train_n_total = None, None
        val_n_correct, val_n_total = None, None
        sum_batchmean_train_loss = 0
        running_train_acc = 0
        num_batches = len(train_loader)
        # put network into training mode
        model.train()
        for batch_idx, batch in enumerate(loop):
            # zero out previous gradient computations
            self.optimizer.zero_grad()

            # get predictions based on input & weights learned so far
            if model.packed_padded_sequences:
                text, text_lengths = batch.text
                x = (text, text_lengths)
                predictions = model(text, text_lengths).squeeze(1)
            else:
                x = batch.text
                predictions = model(batch.text).squeeze(1)

            # compute metrics
            batch_train_loss = self._eval_loss_function(predictions, batch.label)
            sum_batchmean_train_loss += batch_train_loss.item()
            running_train_acc, train_n_total, train_n_correct = \
                _running_eval_acc(predictions, batch.label, n_total=train_n_total, n_correct=train_n_correct,
                                  soft_to_hard_fn=self.optimizer_cfg.training_cfg.soft_to_hard_fn,
                                  soft_to_hard_fn_kwargs=self.optimizer_cfg.training_cfg.soft_to_hard_fn_kwargs)

            if np.isnan(sum_batchmean_train_loss) or np.isnan(running_train_acc):
                _save_nandata(x, predictions, batch.label, batch_train_loss, sum_batchmean_train_loss, running_train_acc,
                              train_n_total, train_n_correct, model)

            # compute gradient
            batch_train_loss.backward()

            # perform gradient clipping if configured
            if self.optimizer_cfg.training_cfg.clip_grad:
                if self.optimizer_cfg.training_cfg.clip_type == 'norm':
                    # clip_grad_norm_ modifies gradients in place
                    #  see: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
                    torch_clip_grad.clip_grad_norm_(model.parameters(), self.optimizer_cfg.training_cfg.clip_val,
                                                    **self.optimizer_cfg.training_cfg.clip_kwargs)
                elif self.optimizer_cfg.training_cfg.clip_type == 'val':
                    # clip_grad_val_ modifies gradients in place
                    #  see: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
                    torch_clip_grad.clip_grad_value_(model.parameters(), self.optimizer_cfg.training_cfg.clip_val)
                else:
                    msg = "Unknown clipping type for gradient clipping!"
                    logger.error(msg)
                    raise ValueError(msg)

            self.optimizer.step()

            loop.set_description('Epoch {}/{}'.format(epoch_num + 1, self.num_epochs))
            loop.set_postfix(avg_train_loss=batch_train_loss.item())

            # report batch statistics to tensorboard
            if self.tb_writer:
                try:
                    batch_num = int(epoch_num * num_batches + batch_idx)
                    self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name + '-train_loss',
                                              batch_train_loss.item(), global_step=batch_num)
                    self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name + '-running_train_acc',
                                              running_train_acc, global_step=batch_num)
                except:
                    # TODO: catch specific exceptions!
                    pass

            if batch_idx % self.num_batches_per_logmsg == 0:
                logger.info('{}\tTrain Epoch: {} [{}/{} ({:.0f}%)]\tTrainLoss: {:.6f}\tTrainAcc: {:.6f}'.format(
                    pid, epoch_num, batch_idx * len(batch), train_dataset_len,
                                    100. * batch_idx / num_batches, batch_train_loss.item(), running_train_acc))
        train_stats = EpochTrainStatistics(running_train_acc, sum_batchmean_train_loss / float(num_batches))

        # if we have validation data, we compute on the validation dataset
        validation_stats = None
        num_val_batches = len(val_loader)
        if num_val_batches > 0:
            logger.info('Running validation')
            val_acc, _, _, val_loss = TorchTextOptimizer._eval_acc(val_loader, model, device=self.device,
                                                                   soft_to_hard_fn=self.optimizer_cfg.training_cfg.soft_to_hard_fn,
                                                                   soft_to_hard_fn_kwargs=self.optimizer_cfg.training_cfg.soft_to_hard_fn_kwargs,
                                                                   loss_fn=self._eval_loss_function)
            validation_stats = EpochValidationStatistics(val_acc, val_loss, None, None)

            logger.info('{}\tTrain Epoch: {} \tValLoss: {:.6f}\tValAcc: {:.6f}'.format(
                pid, epoch_num, val_loss, val_acc))

            if self.tb_writer:
                try:
                    batch_num = int((epoch_num + 1) * num_batches)
                    self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name +
                                              '-validation_loss', val_loss, global_step=batch_num)
                    self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name +
                                              '-validation_acc', val_acc, global_step=batch_num)
                except:
                    # TODO: catch specific exceptions!
                    pass

        # update the lr-scheduler if necessary
        if self.lr_scheduler is not None:
            if self.optimizer_cfg.training_cfg.lr_scheduler_call_arg is None:
                self.lr_scheduler.step()
            elif self.optimizer_cfg.training_cfg.lr_scheduler_call_arg.lower() == 'val_acc':
                if num_val_batches > 0:  # this check ensures that this variable is defined
                    self.lr_scheduler.step(val_acc)
                else:
                    msg = "val_acc not defined b/c validation dataset is not defined! Ignoring LR step!"
                    logger.warning(msg)
            elif self.optimizer_cfg.training_cfg.lr_scheduler_call_arg.lower() == 'val_loss':
                if num_val_batches > 0:
                    self.lr_scheduler.step(val_loss)
                else:
                    msg = "val_loss not defined b/c validation dataset is not defined! Ignoring LR step!"
                    logger.warning(msg)
            else:
                msg = "Unknown mode for calling lr_scheduler!"
                logger.error(msg)
                raise ValueError(msg)

        return train_stats, validation_stats
Example #16
def train():

    input_text, input_rems_text = get_data(train=True)

    dec_idx_to_word, dec_word_to_idx, dec_tok_text, dec_bias = tokenize_text(
        input_text, lower_case=True, vsize=20000)
    dec_padded_text = pad_text(dec_tok_text)
    dec_vocab_size = len(dec_idx_to_word)

    enc_idx_to_word, enc_word_to_idx, enc_tok_text, _ = tokenize_text(
        input_rems_text)
    enc_padded_text = pad_text(enc_tok_text)
    enc_vocab_size = len(enc_idx_to_word)

    dec_text_tensor = torch.tensor(dec_padded_text, requires_grad=False)
    if cuda:
        dec_text_tensor = dec_text_tensor.cuda(device=device)

    enc, dec = build_model(enc_vocab_size, dec_vocab_size, dec_bias=dec_bias)
    enc_optim, dec_optim, lossfunc = build_trainers(enc, dec)

    num_batches = enc_padded_text.shape[0] // BATCH_SIZE

    sm_loss = None
    enc.train()
    dec.train()
    for epoch in range(0, 13):
        print("Starting New Epoch: %d" % epoch)

        order = np.arange(enc_padded_text.shape[0])
        np.random.shuffle(order)
        enc_padded_text = enc_padded_text[order]
        dec_text_tensor.data = dec_text_tensor.data[order]

        for i in range(num_batches):
            s = i * BATCH_SIZE
            e = (i + 1) * BATCH_SIZE

            _, enc_pp, dec_pp, enc_lengths = make_packpadded(
                s, e, enc_padded_text, dec_text_tensor)

            enc.zero_grad()
            dec.zero_grad()

            hid = enc.initHidden(BATCH_SIZE)

            out_enc, hid_enc = enc.forward(enc_pp, hid, enc_lengths)

            hid_enc = torch.cat([hid_enc[0, :, :], hid_enc[1, :, :]],
                                dim=1).unsqueeze(0)
            out_dec, hid_dec, attn = dec.forward(dec_pp[:, :-1], hid_enc,
                                                 out_enc)

            out_perm = out_dec.permute(0, 2, 1)
            loss = lossfunc(out_perm, dec_pp[:, 1:])

            if sm_loss is None:
                sm_loss = loss.data
            else:
                sm_loss = sm_loss * 0.95 + 0.05 * loss.data

            loss.backward()
            clip_grad_value_(enc_optim.param_groups[0]['params'], 5.0)
            clip_grad_value_(dec_optim.param_groups[0]['params'], 5.0)
            enc_optim.step()
            dec_optim.step()

            #del loss
            if i % 100 == 0:
                print("Epoch: %.3f" % (i / float(num_batches) + epoch),
                      "Loss:", sm_loss)
                print("GEN:", untokenize(
                    torch.argmax(out_dec, dim=2)[0, :], dec_idx_to_word))
                #print("GEN:", untokenize(torch.argmax(out_dec, dim=2)[1, :], dec_idx_to_word))
                print("GT:", untokenize(dec_pp[0, :], dec_idx_to_word))
                print("IN:", untokenize(enc_pp[0, :], enc_idx_to_word))

                print(torch.argmax(attn[0], dim=1))
                print("--------------")
        save_state(enc, dec, enc_optim, dec_optim, dec_idx_to_word,
                   dec_word_to_idx, enc_idx_to_word, enc_word_to_idx, epoch)
Example #17
    def train_epoch(self,
                    model: nn.Module,
                    train_loader: TextDataIterator,
                    val_clean_loader: TextDataIterator,
                    val_triggered_loader: TextDataIterator,
                    epoch_num: int,
                    progress_bar_disable: bool = False,
                    use_amp: bool = False):
        """
        Runs one epoch of training on the specified model

        :param model: the model to train for one epoch
        :param train_loader: a DataLoader object pointing to the training dataset
        :param val_clean_loader: a DataLoader object pointing to the validation dataset that is clean
        :param val_triggered_loader: a DataLoader object pointing to the validation dataset that is triggered
        :param epoch_num: the epoch number that is being trained
        :param progress_bar_disable: if True, disables the progress bar
        :param use_amp: if True, use automated mixed precision for FP16 training
        :return: a list of statistics for batches where statistics were computed
        """

        pid = os.getpid()
        train_dataset_len = len(train_loader.dataset)

        scaler = None
        if use_amp:
            scaler = torch.cuda.amp.GradScaler()

        train_n_correct, train_n_total = None, None
        val_n_correct, val_n_total = None, None
        sum_batchmean_train_loss = 0
        running_train_acc = 0
        num_batches = len(train_loader)

        # put network and embedding into training mode
        model.train()

        loop = tqdm(
            train_loader,
            disable=self.optimizer_cfg.reporting_cfg.disable_progress_bar)

        for batch_idx, (input_ids, attention_mask, labels,
                        label_mask) in enumerate(loop):
            # zero out previous gradient computations
            self.optimizer.zero_grad()

            # batch = tuple(t.to(self.device) for t in batch)
            # input_ids, attention_mask, labels, label_mask, valid_ids, token_type_ids = batch
            input_ids = input_ids.to(self.device)
            attention_mask = attention_mask.to(self.device)
            labels = labels.to(self.device)
            label_mask = label_mask.to(self.device)

            if use_amp:
                with torch.cuda.amp.autocast():
                    batch_train_loss, predictions = model(
                        input_ids,
                        attention_mask=attention_mask,
                        labels=labels)
            else:
                batch_train_loss, predictions = model(
                    input_ids, attention_mask=attention_mask, labels=labels)

            sum_batchmean_train_loss += batch_train_loss.item()
            running_train_acc, train_n_total, train_n_correct = \
                _running_eval_acc(predictions, labels, label_mask, n_total=train_n_total, n_correct=train_n_correct,
                                  soft_to_hard_fn=self.optimizer_cfg.training_cfg.soft_to_hard_fn,
                                  soft_to_hard_fn_kwargs=self.optimizer_cfg.training_cfg.soft_to_hard_fn_kwargs)

            # backward pass
            if use_amp:
                # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
                # Backward passes under autocast are not recommended.
                # Backward ops run in the same dtype autocast chose for corresponding forward ops.
                scaler.scale(batch_train_loss).backward()
            else:
                if np.isnan(sum_batchmean_train_loss) or np.isnan(
                        running_train_acc):
                    # TODO: Figure out how to pass the original text ... input ids is tokenized
                    _save_nandata(input_ids, predictions, labels,
                                  batch_train_loss, sum_batchmean_train_loss,
                                  running_train_acc, train_n_total,
                                  train_n_correct, model)
                batch_train_loss.backward()

            # perform gradient clipping if configured
            if self.optimizer_cfg.training_cfg.clip_grad:
                if use_amp:
                    # Unscales the gradients of optimizer's assigned params in-place
                    scaler.unscale_(self.optimizer)

                if self.optimizer_cfg.training_cfg.clip_type == 'norm':
                    # clip_grad_norm_ modifies gradients in place
                    #  see: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
                    torch_clip_grad.clip_grad_norm_(
                        model.parameters(),
                        self.optimizer_cfg.training_cfg.clip_val,
                        **self.optimizer_cfg.training_cfg.clip_kwargs)
                elif self.optimizer_cfg.training_cfg.clip_type == 'val':
                    # clip_grad_val_ modifies gradients in place
                    #  see: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
                    torch_clip_grad.clip_grad_value_(
                        model.parameters(),
                        self.optimizer_cfg.training_cfg.clip_val)
                else:
                    msg = "Unknown clipping type for gradient clipping!"
                    logger.error(msg)
                    raise ValueError(msg)
            if use_amp:
                # scaler.step() first unscales the gradients of the optimizer's assigned params.
                # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                scaler.step(self.optimizer)
                # Updates the scale for next iteration.
                scaler.update()
            else:
                self.optimizer.step()

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            # report batch statistics to TensorBoard
            if self.tb_writer:
                try:
                    batch_num = int(epoch_num * num_batches + batch_idx)
                    self.tb_writer.add_scalar(
                        self.optimizer_cfg.reporting_cfg.experiment_name +
                        '-train_loss',
                        batch_train_loss.item(),
                        global_step=batch_num)
                    self.tb_writer.add_scalar(
                        self.optimizer_cfg.reporting_cfg.experiment_name +
                        '-running_train_acc',
                        running_train_acc,
                        global_step=batch_num)
                except Exception:
                    # TODO: catch specific exceptions
                    pass

            loop.set_description('Epoch {}/{}'.format(epoch_num + 1,
                                                      self.num_epochs))
            loop.set_postfix(avg_train_loss=batch_train_loss.item())

            if batch_idx % self.num_batches_per_logmsg == 0:
                # TODO: Determine best way to get text (possibly from tokenizer if needed...)
                # acc_per_label = "Accuracy Per item: "
                # for k in train_n_total.keys():
                #      acc_per_label += "{}: {}, ".format(k, 0 if train_n_total[k] == 0 else float(train_n_correct[k]) / float(train_n_total[k]))

                logger.info(
                    '{}\tTrain Epoch: {} [{}/{} ({:.0f}%)]\tTrainLoss: {:.6f}\tTrainAcc: {:.6f}'
                    .format(
                        pid,
                        epoch_num,
                        batch_idx * input_ids.shape[0],
                        train_dataset_len,
                        # pid, epoch_num, batch_idx * len(text), train_dataset_len,
                        100. * batch_idx / num_batches,
                        batch_train_loss.item(),
                        running_train_acc))

        train_stats = EpochTrainStatistics(
            running_train_acc, sum_batchmean_train_loss / float(num_batches))
        clean_counts = None
        triggered_counts = None
        # if we have validation data, we compute on the validation dataset
        num_val_batches_clean = len(val_clean_loader)
        if num_val_batches_clean > 0:
            logger.info('Running Validation on Clean Data')
            running_val_clean_acc, clean_n_total, clean_n_correct, val_clean_loss, clean_counts = \
                self._eval_acc(val_clean_loader, model)
        else:
            logger.info("No dataset computed for validation on clean dataset!")
            running_val_clean_acc = None
            val_clean_loss = None

        num_val_batches_triggered = len(val_triggered_loader)
        if num_val_batches_triggered > 0:
            logger.info('Running Validation on Triggered Data')
            running_val_triggered_acc, triggered_n_total, triggered_n_correct, val_triggered_loss, triggered_counts = \
                self._eval_acc(val_triggered_loader, model)
        else:
            logger.info(
                "No dataset computed for validation on triggered dataset!")
            running_val_triggered_acc = None
            val_triggered_loss = None

        validation_stats = EpochValidationStatistics(
            running_val_clean_acc, val_clean_loss, running_val_triggered_acc,
            val_triggered_loss)
        if num_val_batches_clean > 0:
            acc_per_label = "{"
            for k in clean_n_total.keys():
                acc_per_label += "{}: {}, ".format(
                    k,
                    0 if clean_n_total[k] == 0 else float(clean_n_correct[k]) /
                    float(clean_n_total[k]))
            acc_per_label += "}"
            logger.info(
                '{}\tTrain Epoch: {} \tCleanValLoss: {:.6f}\tCleanValAcc: {:.6f}\nCleanPerLabelAcc: {}\nclean_total: {}\nclean_correct: {}'
                .format(pid, epoch_num, val_clean_loss, running_val_clean_acc,
                        acc_per_label, clean_n_total, clean_n_correct))
        if num_val_batches_triggered > 0:
            acc_per_label = "{"
            for k in triggered_n_total.keys():
                acc_per_label += "{}: {}, ".format(
                    k, 0 if triggered_n_total[k] == 0 else
                    float(triggered_n_correct[k]) /
                    float(triggered_n_total[k]))
            acc_per_label += "}"
            logger.info(
                '{}\tTrain Epoch: {} \tTriggeredValLoss: {:.6f}\tTriggeredValAcc: {:.6f}\tTriggeredPerLabelAcc: {}'
                .format(pid, epoch_num, val_triggered_loss,
                        running_val_triggered_acc, acc_per_label))

        if self.tb_writer:
            try:
                batch_num = int((epoch_num + 1) * num_batches)
                if num_val_batches_clean > 0:
                    self.tb_writer.add_scalar(
                        self.optimizer_cfg.reporting_cfg.experiment_name +
                        '-clean-val-loss',
                        val_clean_loss,
                        global_step=batch_num)
                    self.tb_writer.add_scalar(
                        self.optimizer_cfg.reporting_cfg.experiment_name +
                        '-clean-val_acc',
                        running_val_clean_acc,
                        global_step=batch_num)
                if num_val_batches_triggered > 0:
                    self.tb_writer.add_scalar(
                        self.optimizer_cfg.reporting_cfg.experiment_name +
                        '-triggered-val-loss',
                        val_triggered_loss,
                        global_step=batch_num)
                    self.tb_writer.add_scalar(
                        self.optimizer_cfg.reporting_cfg.experiment_name +
                        '-triggered-val_acc',
                        running_val_triggered_acc,
                        global_step=batch_num)
            except Exception:
                pass

        # Add epoch to stats
        self.ner_metrics.add_epoch_stats(epoch_num, clean_counts,
                                         triggered_counts)

        return train_stats, validation_stats
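A note on the AMP path above: it follows PyTorch's standard torch.cuda.amp recipe of scale -> backward -> unscale -> clip -> step -> update. A minimal, self-contained sketch of that recipe (the model, data, and optimizer here are placeholders, and a CUDA device is assumed):

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

model = nn.Linear(16, 4).cuda()              # placeholder model
optimizer = torch.optim.Adam(model.parameters())
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    x = torch.randn(8, 16, device='cuda')
    y = torch.randint(0, 4, (8,), device='cuda')
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():          # mixed precision for the forward pass only
        loss = nn.functional.cross_entropy(model(x), y)
    scaler.scale(loss).backward()            # backward on the scaled loss
    scaler.unscale_(optimizer)               # unscale before clipping the real gradients
    clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)                   # skips the step if grads contain inf/NaN
    scaler.update()                          # adjust the scale for the next iteration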
Example #18
    def train_gp(self,
                 seed=0,
                 n_iter=300,
                 pred_interval=5,
                 test=False,
                 verbose=True,
                 n_epoch=20):
        set_seed(seed)
        print('SEED=', seed)
        optimizer = opt.Adam(self.model.parameters())
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        pred_start = torch.cuda.Event(enable_timing=True)
        pred_end = torch.cuda.Event(enable_timing=True)
        pred_time = 0.0
        start.record()
        for i in range(n_iter):
            batch_nll = 0.0
            epoch = 0
            for (X, Y) in self.data:
                if epoch > n_epoch:
                    break
                epoch += 1
                self.model.train()
                optimizer.zero_grad()
                nll = self.gp.NLL(X, Y)
                batch_nll += nll * X.shape[0]
                nll.backward()
                clip_grad_value_(self.model.parameters(), 10)
                optimizer.step()
                torch.cuda.empty_cache()
            if i % pred_interval == 0:
                record = {
                    'nll': batch_nll.item() / self.train['X'].shape[0],
                    'iter': i
                }
                if test:
                    pred_start.record()
                    Ypred_full = []
                    Yt_full = []
                    count = 0
                    for (Xt, Yt) in self.test:
                        count += 1
                        Yt_full.append(Yt)
                        Ypred = self.gp(Xt,
                                        self.train['X'],
                                        self.train['Y'],
                                        var=False)
                        Ypred_full.append(Ypred)

                    Ypred_full = torch.cat(Ypred_full)
                    Yt_full = torch.cat(Yt_full)
                    record['rmse'] = rmse(Ypred_full, Yt_full).item()
                    pred_end.record()
                    torch.cuda.synchronize()
                    pred_time += pred_start.elapsed_time(pred_end)
                end.record()
                torch.cuda.synchronize()
                record['time'] = start.elapsed_time(end) - pred_time
                if verbose:
                    print(record)
                self.history.append(record)
        return self.history
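Example #18 above times training with CUDA events rather than time.time(), and subtracts the accumulated prediction time so the recorded 'time' covers optimization only. A minimal sketch of that event-timing pattern (the matrix multiply stands in for real work, and a CUDA device is assumed):

import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
x = torch.randn(1024, 1024, device='cuda')
y = x @ x                        # placeholder GPU work
end.record()
torch.cuda.synchronize()         # elapsed_time is only valid after a synchronize
print('elapsed ms:', start.elapsed_time(end))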
def train_acn(info, model_dict, data_buffers, phase='train'):
    encoder_model = model_dict['encoder_model']
    prior_model = model_dict['prior_model']
    pcnn_decoder = model_dict['pcnn_decoder']
    opt = model_dict['opt']

    # add one to the rewards so that they are all positive
    # use next_states because that is the t-1 action

    if len(info['model_train_cnts']):
        train_cnt = info['model_train_cnts'][-1]
    else:
        train_cnt = 0

    num_batches = data_buffers['train'].count // info['MODEL_BATCH_SIZE']
    while train_cnt < 10000000:
        if phase == 'valid':
            encoder_model.eval()
            prior_model.eval()
            pcnn_decoder.eval()
        else:
            encoder_model.train()
            prior_model.train()
            pcnn_decoder.train()

        batch_num = 0
        data_buffers[phase].reset_unique()
        print('-------------new epoch %s------------------' % phase)
        print('num batches', num_batches)
        while data_buffers[phase].unique_available:
            opt.zero_grad()
            batch = data_buffers[phase].get_unique_minibatch(
                info['MODEL_BATCH_SIZE'])
            relative_indices = batch[-1]
            states, actions, rewards, next_states = make_state(
                batch[:-1], info['DEVICE'], info['NORM_BY'])
            next_state = next_states[:, -1:]
            bs = states.shape[0]
            #states, actions, rewards, next_states, terminals, is_new_epoch, relative_indexes = train_data_loader.get_unique_minibatch()
            z, u_q = encoder_model(states)

            # add the predicted codes to the input
            #yhat_batch = torch.sigmoid(pcnn_decoder(x=next_state,
            #                                        class_condition=actions,
            #                                        float_condition=z))
            yhat_batch = encoder_model.decode(z)
            prior_model.codes[relative_indices] = u_q.detach().cpu().numpy()

            mix, u_ps, s_ps = prior_model(u_q)

            # track losses
            kl_loss, rec_loss = acn_gmp_loss_function(yhat_batch, next_state,
                                                      u_q, mix, u_ps, s_ps)
            loss = kl_loss + rec_loss
            # batch size, because it hasn't been added to train_cnt yet

            if not phase == 'valid':
                loss.backward()
                #parameters = list(encoder_model.parameters()) + list(prior_model.parameters()) + list(pcnn_decoder.parameters())
                parameters = list(encoder_model.parameters()) + list(
                    prior_model.parameters())
                clip_grad_value_(parameters, 10)
                opt.step()
                train_cnt += bs

            if not batch_num % info['MODEL_LOG_EVERY_BATCHES']:
                print(phase, train_cnt, batch_num, kl_loss.item(),
                      rec_loss.item())
                info = add_losses(info, train_cnt, phase, kl_loss.item(),
                                  rec_loss.item())
            batch_num += 1

        if (((train_cnt - info['model_last_save']) >=
             info['MODEL_SAVE_EVERY'])):
            info = add_losses(info, train_cnt, phase, kl_loss.item(),
                              rec_loss.item())
            if phase == 'train':
                # run as valid phase and get back to here
                phase = 'valid'
            else:
                model_dict = {
                    'encoder_model': encoder_model,
                    'prior_model': prior_model,
                    'pcnn_decoder': pcnn_decoder,
                    'opt': opt
                }
                info = save_model(info, model_dict)
                phase = 'train'

    model_dict = {
        'encoder_model': encoder_model,
        'prior_model': prior_model,
        'pcnn_decoder': pcnn_decoder,
        'opt': opt
    }

    info = save_model(info, model_dict)
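train_acn clips the gradients of several sub-models with a single clip_grad_value_ call by concatenating their parameter lists. A minimal sketch of that pattern with two placeholder modules and a toy loss:

import torch
import torch.nn as nn
from torch.nn.utils.clip_grad import clip_grad_value_

encoder_model = nn.Linear(8, 4)   # placeholder sub-models
prior_model = nn.Linear(4, 4)
opt = torch.optim.Adam(list(encoder_model.parameters()) +
                       list(prior_model.parameters()))

opt.zero_grad()
loss = prior_model(encoder_model(torch.randn(2, 8))).pow(2).mean()
loss.backward()
parameters = list(encoder_model.parameters()) + list(prior_model.parameters())
clip_grad_value_(parameters, 10)  # clamp every gradient element to [-10, 10]
opt.step()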
Example #20
def train(latent_dim, datasource, num_actions, num_rewards, encoder, decoder,
          reward_predictor, discriminator, transition):
    batch_size = args.batch_size
    train_iters = args.train_iters
    td_lambda_coef = args.td_lambda
    td_steps = args.td_steps
    truncate_bptt = args.truncate_bptt
    enable_td = args.latent_td
    enable_latent_overshooting = args.latent_overshooting
    learning_rate = args.learning_rate
    min_prediction_horizon = args.horizon_min
    max_prediction_horizon = args.horizon_max
    finetune_reward = args.finetune_reward
    REWARD_COEF = args.reward_coef
    ACTIVATION_L1_COEF = args.activation_l1_coef
    TRANSITION_L1_COEF = args.transition_l1_coef
    counterfactual_horizon = args.counterfactual_horizon
    start_iter = args.start_iter

    opt_enc = torch.optim.Adam(encoder.parameters(), lr=learning_rate)
    opt_dec = torch.optim.Adam(decoder.parameters(), lr=learning_rate)
    opt_trans = torch.optim.Adam(transition.parameters(), lr=learning_rate)
    opt_disc = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)
    opt_pred = torch.optim.Adam(reward_predictor.parameters(),
                                lr=learning_rate)
    ts = TimeSeries('Training Model', train_iters, tensorboard=True)

    for train_iter in range(start_iter, train_iters + 1):
        if train_iter % ITERS_PER_VIDEO == 0:
            print('Evaluating networks...')
            evaluate(datasource,
                     encoder,
                     decoder,
                     transition,
                     discriminator,
                     reward_predictor,
                     latent_dim,
                     train_iter=train_iter)
            print('Saving networks to filesystem...')
            torch.save(transition.state_dict(), 'model-transition.pth')
            torch.save(encoder.state_dict(), 'model-encoder.pth')
            torch.save(decoder.state_dict(), 'model-decoder.pth')
            torch.save(discriminator.state_dict(), 'model-discriminator.pth')
            torch.save(reward_predictor.state_dict(),
                       'model-reward_predictor.pth')

        theta = train_iter / train_iters
        pred_delta = max_prediction_horizon - min_prediction_horizon
        prediction_horizon = min_prediction_horizon + int(pred_delta * theta)

        train_mode(
            [encoder, decoder, transition, discriminator, reward_predictor])

        # Train encoder/transition/decoder
        opt_enc.zero_grad()
        opt_dec.zero_grad()
        opt_trans.zero_grad()
        opt_pred.zero_grad()

        states, rewards, dones, actions = datasource.get_trajectories(
            batch_size, prediction_horizon)
        states = torch.Tensor(states).cuda()
        rewards = torch.Tensor(rewards).cuda()
        dones = torch.Tensor(dones.astype(int)).cuda()

        # Encode the initial state (using the first 3 frames)
        # Given t, t+1, t+2, encoder outputs the state at time t+1
        z = encoder(states[:, 0:3])
        z_orig = z.clone()

        # But wait, here's the problem: We can't use the encoded initial state as
        # an initial state of the dynamical system and expect the system to work
        # The dynamical system needs to have something like the Echo State Property
        # So the dynamical parts need to run long enough to reach a steady state

        # Keep track of "done" states to stop a training trajectory at the final time step
        active_mask = torch.ones(batch_size).cuda()

        loss = 0
        lo_loss = 0
        lo_z_set = {}
        # Given the state encoded at t=2, predict state at t=3, t=4, ...
        for t in range(1, prediction_horizon - 1):
            active_mask = active_mask * (1 - dones[:, t])

            # Predict reward
            expected_reward = reward_predictor(z)
            actual_reward = rewards[:, t]
            reward_difference = torch.mean(
                torch.mean(
                    (expected_reward - actual_reward)**2, dim=1) * active_mask)
            ts.collect('Rd Loss t={}'.format(t), reward_difference)
            loss += theta * REWARD_COEF * reward_difference  # Normalize by height * width

            # Reconstruction loss
            target_pixels = states[:, t]
            predicted = torch.sigmoid(decoder(z))
            rec_loss_batch = decoder_pixel_loss(target_pixels, predicted)

            if truncate_bptt and t > 1:
                z.detach_()

            rec_loss = torch.mean(rec_loss_batch * active_mask)
            ts.collect('Reconstruction t={}'.format(t), rec_loss)
            loss += rec_loss

            # Apply activation L1 loss
            #l1_values = z.abs().mean(-1).mean(-1).mean(-1)
            #l1_loss = ACTIVATION_L1_COEF * torch.mean(l1_values * active_mask)
            #ts.collect('L1 t={}'.format(t), l1_loss)
            #loss += theta * l1_loss

            # Predict transition to the next state
            onehot_a = torch.eye(num_actions)[actions[:, t]].cuda()
            new_z = transition(z, onehot_a)

            # Apply transition L1 loss
            #t_l1_values = ((new_z - z).abs().mean(-1).mean(-1).mean(-1))
            #t_l1_loss = TRANSITION_L1_COEF * torch.mean(t_l1_values * active_mask)
            #ts.collect('T-L1 t={}'.format(t), t_l1_loss)
            #loss += theta * t_l1_loss

            z = new_z

            if enable_latent_overshooting:
                # Latent Overshooting, Hafner et al.
                lo_z_set[t] = encoder(states[:, t - 1:t + 2])

                # For each previous t_left, step forward to t
                for t_left in range(1, t):
                    a = torch.eye(num_actions)[actions[:, t - 1]].cuda()
                    lo_z_set[t_left] = transition(lo_z_set[t_left], a)
                for t_a in range(2, t - 1):
                    # It's like TD but only N:1 for all N
                    predicted_activations = lo_z_set[t_a]
                    target_activations = lo_z_set[t].detach()
                    lo_loss_batch = latent_state_loss(target_activations,
                                                      predicted_activations)
                    lo_loss += td_lambda_coef * torch.mean(
                        lo_loss_batch * active_mask)

        if enable_latent_overshooting:
            ts.collect('LO total', lo_loss)
            loss += theta * lo_loss

        # COUNTERFACTUAL DISENTANGLEMENT REGULARIZATION
        # Suppose that our representation is ideally, perfectly disentangled
        # Then the PGM has no edges, the causal graph is just nodes with no relationships
        # In this case, it should be true that intervening on any one factor has no effect on the others
        # One fun way of intervening is swapping factors, a la FactorVAE
        # If we intervene on some dimensions, the other dimensions should be unaffected
        if enable_cf_shuffle_loss and train_iter % CF_REGULARIZATION_RATE == 0:
            # Counterfactual scenario A: our memory of what really happened
            z_cf_a = z.clone()
            # Counterfactual scenario B: a bizarro world where two dimensions are swapped
            z_cf_b = z_orig
            unswapped_factor_map = torch.ones((batch_size, latent_dim)).cuda()
            for i in range(batch_size):
                idx_a = np.random.randint(latent_dim)
                idx_b = np.random.randint(latent_dim)
                unswapped_factor_map[i, idx_a] = 0
                unswapped_factor_map[i, idx_b] = 0
                z_cf_b[i, idx_a], z_cf_b[i, idx_b] = \
                    z_cf_b[i, idx_b], z_cf_b[i, idx_a]
            # But we take the same actions
            for t in range(1, counterfactual_horizon):
                onehot_a = torch.eye(num_actions)[actions[:, t]].cuda()
                z_cf_b = transition(z_cf_b, onehot_a)
            # Every UNSWAPPED dimension should be as similar as possible to its bizarro-world equivalent
            cf_loss = torch.abs(z_cf_a - z_cf_b).mean(-1).mean(
                -1) * unswapped_factor_map
            cf_loss = CF_REGULARIZATION_LAMBDA * torch.mean(
                cf_loss.mean(-1) * active_mask)
            loss += cf_loss
            ts.collect('CF Disentanglement Loss', cf_loss)

        # COUNTERFACTUAL ACTION-CONTROL REGULARIZATION
        # In difficult POMDPs, deep neural networks can suffer from learned helplessness
        # They learn, rationally, that their actions have no causal influence on the reward
        # This is undesirable: the learned model should assume that outcomes are controllable
        if enable_control_bias_loss and train_iter % CF_REGULARIZATION_RATE == 0:
            # Counterfactual scenario A: our memory of what really happened
            z_cf_a = z.clone()
            # Counterfactual scenario B: our imagination of what might have happened
            z_cf_b = z_orig
            # Instead of the regular actions, apply an alternate policy
            cf_actions = actions.copy()
            np.random.shuffle(cf_actions)
            for t in range(1, counterfactual_horizon):
                onehot_a = torch.eye(num_actions)[cf_actions[:, t]].cuda()
                z_cf_b = transition(z_cf_b, onehot_a)
            eps = .001  # for numerical stability
            cf_loss = -torch.log(
                torch.abs(z_cf_a - z_cf_b).mean(-1).mean(-1).mean(-1) + eps)
            cf_loss = CF_REGULARIZATION_LAMBDA * torch.mean(
                cf_loss * active_mask)
            loss += cf_loss
            ts.collect('CF Control Bias Loss', cf_loss)

        loss.backward()

        from torch.nn.utils.clip_grad import clip_grad_value_
        clip_grad_value_(encoder.parameters(), 0.1)
        clip_grad_value_(transition.parameters(), 0.1)
        clip_grad_value_(decoder.parameters(), 0.1)

        opt_pred.step()
        if not args.finetune_reward:
            opt_enc.step()
            opt_dec.step()
            opt_trans.step()
        ts.print_every(10)
    print(ts)
    print('Finished')
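Example #20 truncates backpropagation through time by detaching the rolled-out latent (z.detach_()), so gradients only flow through the most recent transition instead of the whole rollout. A minimal sketch of that truncation with a placeholder transition and loss:

import torch
import torch.nn as nn

transition = nn.Linear(4, 4)     # placeholder latent dynamics
z = torch.zeros(1, 4)
loss = 0
for t in range(8):
    z = transition(z)
    loss = loss + z.pow(2).mean()
    z = z.detach()               # cut the graph: BPTT is truncated to one step
loss.backward()                  # each step's loss reaches only that step's transition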
Example #21
    def train_gp(self, seed=0, burn_iter=100, n_iter=100, lmbda=1.0, pred_interval=5, test=None, verbose=True, n_epoch=50):
        set_seed(seed)
        print('SEED=', seed)
        optimizer = opt.Adam(self.model.parameters())
        for i in range(n_iter):
            batch_nll = 0.0
            batch_elbo = 0.0
            batch_loss = 0.0
            #batch_validation = 0.0
            epoch = 0
            for (X, Y) in self.data:
                if epoch > n_epoch:
                    break
                epoch += 1
                X = X.to(self.device)
                Y = Y.to(self.device)
                '''
                Z = self.vae(X)
                self.gp.data['X'] = Z
                self.gp.data['Y'] = Y
                self.gp.n_data = Z.shape[0]
                self.gp.n_dim = Z.shape[1]
                '''
                Xr = self.vae(self.vae(X), encode=False)
                self.gp.data['X'] = Xr
                #print(self.gp.data['X'].shape)
                self.gp.data['Y'] = Y
                self.gp.n_data = Xr.shape[0]
                self.gp.n_dim = Xr.shape[1]
                self.model.train()
                optimizer.zero_grad()
                delbo = self.vae.dsELBO(X, alpha=1.0, beta=1.0, gamma=1.0, verbose=False)
                nll = self.gp.NLL_batch(Xr, Y)
                #validation = rmse(self.gp(Xr,  grad=True)[0], Y)
                loss = - delbo + lmbda * nll
                #batch_validation += validation * X.shape[0]
                batch_nll += nll * X.shape[0]
                batch_elbo += delbo * X.shape[0]
                batch_loss += loss * X.shape[0]
                loss.backward()
                clip_grad_value_(self.model.parameters(), 20)
                optimizer.step()
                torch.cuda.empty_cache()
            if i % pred_interval == 0:
                record = {'nll': batch_nll.item() / self.train['X'].shape[0],
                          'elbo': batch_elbo.item() / self.train['X'].shape[0],
                          'loss': batch_loss.item() / self.train['X'].shape[0],
                          'iter': i}
                if test is not None:
                    #self.gp.data['X'] = self.vae(self.original['X'], grad=False)
                    self.gp.data['X'] = self.vae(self.vae(self.original['X'], grad=False), encode=False, grad=False)
                    #print(self.gp.data['X'].shape)
                    self.gp.data['Y'] = self.original['Y']
                    self.gp.n_data = self.gp.data['X'].shape[0]
                    self.gp.n_dim = self.gp.data['X'].shape[1]
                    Xtr = self.vae(self.vae(test['X'], grad=False), encode=False, grad=False)
                    err = rmse(self.gp(Xtr)[0], test['Y'])
                    record['rmse'] = err.item()

                if verbose:
                    print(record)

                self.history.append(record)

        '''
        for i in range(n_iter):
            batch_nll = 0.0
            epoch = 0
            for (X, Y) in self.data:
                if epoch > n_epoch:
                    break
                epoch += 1
                X.to(self.device)
                Y.to(self.device)
                Xr = self.vae(self.vae(X), encode=False)
           
                self.gp.data['X'] = Z
                self.gp.data['Y'] = Y
                self.gp.n_data = Z.shape[0]
                self.gp.n_dim = Z.shape[1]
           
                self.gp.data['X'] = Xr
                self.gp.data['Y'] = Y
                self.gp.n_data = Xr.shape[0]
                self.gp.n_dim = Xr.shape[1]
                self.gp_model.train()
                optimizer.zero_grad()
                nll = self.gp.NLL_batch(Xr, Y)
                batch_nll += nll * X.shape[0]
                nll.backward()
                optimizer.step()
                torch.cuda.empty_cache()

            if i % pred_interval == 0:
                record = {'nll': batch_nll.item() / self.train['X'].shape[0],
                          'iter': i}

                if test is not None:
                    self.gp.data['X'] = self.vae(self.original['X'])
                    self.gp.data['Y'] = self.original['Y']
                    self.gp.n_data = self.gp.data['X'].shape[0]
                    self.gp.n_dim = self.gp.data['X'].shape[1]
                    err = rmse(self.gp(self.vae(test['X']))[0], test['Y'])
                    record['rmse'] = err.item()

                if verbose:
                    print(record)

                self.history.append(record)
        '''

        return self.history
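Example #21 trains the VAE and the GP jointly by weighting the two objectives, loss = -delbo + lmbda * nll, so one backward pass updates the shared parameters for both. A minimal sketch of weighting two differentiable objectives this way (both loss terms here are placeholders):

import torch
import torch.nn as nn
from torch.nn.utils.clip_grad import clip_grad_value_

model = nn.Linear(8, 8)              # placeholder shared parameters
optimizer = torch.optim.Adam(model.parameters())
lmbda = 1.0

x = torch.randn(4, 8)
optimizer.zero_grad()
elbo = -model(x).pow(2).mean()       # placeholder ELBO term (to be maximized)
nll = (model(x) - x).pow(2).mean()   # placeholder NLL term (to be minimized)
loss = -elbo + lmbda * nll           # maximizing the ELBO = minimizing its negative
loss.backward()
clip_grad_value_(model.parameters(), 20)
optimizer.step()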
Example #22
    def _train_step(self) -> None:
        """Run a training step over the training data."""
        self.model.train()
        metrics_with_states: List[Tuple] = [
            (metric, {}) for metric in self.training_metrics
        ]
        self._last_train_log_step = 0

        log_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""
        log_prefix += 'Training'

        with torch.enable_grad():
            for i in range(self.iter_per_step):
                # Zero the gradients and clear the accumulated loss
                self.optimizer.zero_grad()
                accumulated_loss = 0.0
                for _ in range(self.batches_per_iter):
                    # Get next batch
                    try:
                        batch = next(self._train_iterator)
                    except StopIteration:
                        self._create_train_iterator()
                        batch = next(self._train_iterator)
                    batch = self._batch_to_device(batch)

                    # Compute loss
                    _, _, loss = self._compute_batch(batch,
                                                     metrics_with_states)
                    accumulated_loss += loss.item() / self.batches_per_iter
                    loss.backward()

                # Log loss
                global_step = (self.iter_per_step * self._step) + i

                # Clip gradients if necessary
                if self.max_grad_norm:
                    clip_grad_norm_(self.model.parameters(),
                                    self.max_grad_norm)
                if self.max_grad_abs_val:
                    clip_grad_value_(self.model.parameters(),
                                     self.max_grad_abs_val)

                log(f'{log_prefix}/Loss', accumulated_loss, global_step)
                log(f'{log_prefix}/Gradient_Norm', self.model.gradient_norm,
                    global_step)
                log(f'{log_prefix}/Parameter_Norm', self.model.parameter_norm,
                    global_step)

                # Optimize
                self.optimizer.step()

                # Update iter scheduler
                if self.iter_scheduler is not None:
                    learning_rate = self.iter_scheduler.get_lr()[
                        0]  # type: ignore
                    log(f'{log_prefix}/LR', learning_rate, global_step)
                    self.iter_scheduler.step()  # type: ignore

                # Zero the gradients when exiting a train step
                self.optimizer.zero_grad()
                # logging train metrics
                if (i - self._last_train_log_step >=
                        self.extra_training_metrics_log_interval):
                    self._log_metrics(log_prefix, metrics_with_states,
                                      global_step)
                    self._last_train_log_step = i
            if self._last_train_log_step != i:
                # log again at end of step, if not logged at the end of
                # step before
                self._log_metrics(log_prefix, metrics_with_states, global_step)
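The inner batches_per_iter loop in Example #22 is standard gradient accumulation: each micro-batch calls backward() and the gradients add up until the single optimizer step. A minimal sketch of the pattern (model and data are placeholders; dividing each loss by batches_per_iter, as Example #24 below does, keeps the accumulated gradient an average rather than a sum):

import torch
import torch.nn as nn

model = nn.Linear(8, 1)              # placeholder model
optimizer = torch.optim.Adam(model.parameters())
batches_per_iter = 4

optimizer.zero_grad()
accumulated_loss = 0.0
for _ in range(batches_per_iter):
    x, y = torch.randn(2, 8), torch.randn(2, 1)
    loss = nn.functional.mse_loss(model(x), y) / batches_per_iter
    accumulated_loss += loss.item()
    loss.backward()                  # gradients accumulate across micro-batches
optimizer.step()                     # one update for the whole effective batch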
Example #23
def clip_grad_value(params, clip_value=10):
    # clipping gradient values can slow down training
    clip_grad.clip_grad_value_(filter(lambda p: p.requires_grad, params),
                               clip_value=clip_value)
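A short usage sketch for the wrapper above; the requires_grad filter makes it explicit that frozen parameters are skipped (the tiny model and frozen bias are placeholders):

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad

model = nn.Linear(8, 2)              # placeholder model
model.bias.requires_grad_(False)     # a frozen parameter, filtered out above

loss = model(torch.randn(4, 8)).sum()
loss.backward()
clip_grad_value(model.parameters(), clip_value=10)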
Example #24
    def _train_step(self) -> None:
        """Run a training step over the training data."""
        self.model.train()

        tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""

        with torch.enable_grad():
            for i in range(self.iter_per_step):
                # Zero the gradients and clear the accumulated loss
                self.optimizer.zero_grad()
                self.optimizer_alphas.zero_grad()
                self.optimizer_lambdas.zero_grad()

                accumulated_loss = 0.0
                for _ in range(self.batches_per_iter):
                    # Get next batch
                    batch = next(self._train_iterator)
                    batch = self._batch_to_device(batch)

                    # Compute loss
                    loss = self._compute_loss(batch) / self.batches_per_iter
                    accumulated_loss += loss.item()
                    loss.backward()

                # Log loss
                global_step = (self.iter_per_step * self._step) + i

                # Clip gradients if necessary
                if self.max_grad_norm:
                    clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                if self.max_grad_abs_val:
                    clip_grad_value_(self.model.parameters(), self.max_grad_abs_val)

                log(f'{tb_prefix}Training/Loss', accumulated_loss, global_step)
                log(f'{tb_prefix}Training/Gradient_Norm', self.model.gradient_norm, global_step)
                log(f'{tb_prefix}Training/Parameter_Norm', self.model.parameter_norm, global_step)

                if global_step >= self.iter_before_pruning:

                    pruning_step = global_step - self.iter_before_pruning

                    num_parameters = get_num_params(self.hard_concrete_modules, train=True)
                    expected_sparsity = 1. - (num_parameters / self.max_prunable)

                    if self.target_sparsity_warmup > 0:
                        factor = min(1.0, pruning_step / self.target_sparsity_warmup)
                        target_sparsity = self.target_sparsity * factor
                    else:
                        target_sparsity = self.target_sparsity

                    lagrangian_loss = self.lambda_1 * (target_sparsity - expected_sparsity)
                    lagrangian_loss += self.lambda_2 * (target_sparsity - expected_sparsity) ** 2
                    lagrangian_loss.backward()
                    log("Expected_sparsity", float(expected_sparsity), global_step)
                    log("Lagrangian_loss", lagrangian_loss.item(), global_step)
                    log("Target_sparsity", target_sparsity, global_step)
                    log("lambda_1", self.lambda_1.item(), global_step)
                    log("lambda_2", self.lambda_2.item(), global_step)

                    self.optimizer_lambdas.step()
                    self.lambdas_scheduler.step(pruning_step)

                    self.optimizer_alphas.step()
                    self.alphas_scheduler.step(pruning_step)

                # Optimize
                self.optimizer.step()
                self.lr_scheduler.step(global_step)

            # Zero the gradients when exiting a train step
            self.optimizer.zero_grad()
            self.optimizer_lambdas.zero_grad()
            self.optimizer_alphas.zero_grad()
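Example #24's pruning term is a Lagrangian relaxation: lagrangian_loss = lambda_1 * (t - s) + lambda_2 * (t - s)**2, where t is the (possibly warmed-up) target sparsity and s the expected sparsity, and the multipliers are trained adversarially to tighten the constraint. A minimal sketch under the assumption that the multipliers ascend on the same loss the gates descend on; here the ascent is done by flipping the multiplier gradients before their step, and the sigmoid is a stand-in for the hard-concrete gate count:

import torch

alpha = torch.zeros(1, requires_grad=True)        # stand-in for hard-concrete gate params
lambda_1 = torch.zeros(1, requires_grad=True)
lambda_2 = torch.zeros(1, requires_grad=True)

opt_alphas = torch.optim.Adam([alpha], lr=1e-2)
opt_lambdas = torch.optim.Adam([lambda_1, lambda_2], lr=1e-2)

target_sparsity = 0.5
for step in range(200):
    opt_alphas.zero_grad()
    opt_lambdas.zero_grad()
    expected_sparsity = torch.sigmoid(alpha)      # stand-in for 1 - n_params / max_prunable
    gap = target_sparsity - expected_sparsity
    lagrangian_loss = lambda_1 * gap + lambda_2 * gap ** 2
    lagrangian_loss.backward()
    opt_alphas.step()                             # gates descend on the penalty
    for lam in (lambda_1, lambda_2):
        lam.grad.neg_()                           # flip the sign so the multipliers ascend
    opt_lambdas.step()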
Example #25
 def train_gp(self,
              seed=0,
              n_iter=100,
              n_epoch=50,
              lmbda=1.0,
              pred_interval=5,
              test=False,
              verbose=True):
     set_seed(seed)
     print('SEED=', seed)
     optimizer = opt.Adam(self.model.parameters())
     start = torch.cuda.Event(enable_timing=True)
     end = torch.cuda.Event(enable_timing=True)
     pred_start = torch.cuda.Event(enable_timing=True)
     pred_end = torch.cuda.Event(enable_timing=True)
     pred_time = 0.0
     start.record()
     for i in range(n_iter):
         batch_nll = 0.0
         batch_elbo = 0.0
         batch_loss = 0.0
         epoch = 0
         for (X, Y) in self.data:
             if epoch > n_epoch:
                 break
             epoch += 1
             X = X.to(self.device)
             Y = Y.to(self.device)
             Xa = self.vae(self.vae(X), encode=False)
             self.model.train()
             optimizer.zero_grad()
             delbo = self.vae.dsELBO(X,
                                     alpha=1.0,
                                     beta=1.0,
                                     gamma=1.0,
                                     verbose=False)
             nll = self.gp.NLL(Xa, Y)
             loss = -0.01 * delbo + lmbda * nll
             batch_nll += nll * X.shape[0]
             batch_elbo += delbo * X.shape[0]
             batch_loss += loss * X.shape[0]
             loss.backward()
             clip_grad_value_(self.model.parameters(), 10)
             optimizer.step()
             torch.cuda.empty_cache()
         if i % pred_interval == 0:
             torch.cuda.synchronize()
             record = {
                 'iter': i,
                 'nll': batch_nll.item() / self.train['X'].shape[0],
                 'elbo': batch_elbo.item() / self.train['X'].shape[0],
                 'loss': batch_loss.item() / self.train['X'].shape[0],
             }
             if test:
                 pred_start.record()
                 Ypred_full = []
                 Yt_full = []
                 Xr = self.vae(self.vae(self.original['X'], grad=False),
                               encode=False,
                               grad=False)
                 for (Xt, Yt) in self.test:
                     Xtr = self.vae(self.vae(Xt, grad=False),
                                    encode=False,
                                    grad=False)
                     Ypred = self.gp(Xtr, Xr, self.original['Y'], var=False)
                     Ypred_full.append(Ypred)
                     Yt_full.append(Yt)
                     del Xtr
                     torch.cuda.empty_cache()
                 Ypred_full = torch.cat(Ypred_full)
                 Yt_full = torch.cat(Yt_full)
                 record['rmse'] = rmse(Ypred_full, Yt_full).item()
                 pred_end.record()
                 torch.cuda.synchronize()
                 pred_time += pred_start.elapsed_time(pred_end)
                 del Xr
                 torch.cuda.empty_cache()
             end.record()
             torch.cuda.synchronize()
             record['time'] = start.elapsed_time(end) - pred_time
             if verbose:
                 print(record)
             self.history.append(record)
     return self.history
def clip_parameters(model_dict, clip_val=10):
    for name, model in model_dict.items():
        if 'model' in name:
            clip_grad_value_(model.parameters(), clip_val)
    return model_dict
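A short usage sketch for clip_parameters above: only dict entries whose key contains 'model' are clipped, so an optimizer stored under 'opt' (as in train_acn's model_dict) passes through untouched:

import torch
import torch.nn as nn
from torch.nn.utils.clip_grad import clip_grad_value_

encoder_model = nn.Linear(8, 4)      # placeholder models
prior_model = nn.Linear(4, 2)
opt = torch.optim.Adam(list(encoder_model.parameters()) +
                       list(prior_model.parameters()))

model_dict = {'encoder_model': encoder_model, 'prior_model': prior_model,
              'opt': opt}
loss = prior_model(encoder_model(torch.randn(2, 8))).sum()
loss.backward()
model_dict = clip_parameters(model_dict, clip_val=10)  # clips the two models only
opt.step()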
Example #27
def train_vqvae(train_cnt, vqvae_model, opt, info, train_data_loader,
                valid_data_loader):
    st = time.time()
    #for batch_idx, (data, label, data_index) in enumerate(train_loader):
    batches = 0
    while train_cnt < info['VQ_NUM_EXAMPLES_TO_TRAIN']:
        vqvae_model.train()
        opt.zero_grad()
        states, actions, rewards, values, pred_states, terminals, is_new_epoch, relative_indexes = train_data_loader.get_framediff_minibatch(
        )
        # because we have 4 layers in vqvae, need to be divisible by 2, 4 times
        states = (2 * reshape_input(torch.FloatTensor(states)) - 1).to(
            info['DEVICE'])
        rec = (
            2 * reshape_input(torch.FloatTensor(pred_states)[:, 0][:, None]) -
            1).to(info['DEVICE'])
        actions = torch.LongTensor(actions).to(info['DEVICE'])
        rewards = torch.LongTensor(rewards).to(info['DEVICE'])
        # dont normalize diff
        diff = (reshape_input(
            torch.FloatTensor(pred_states)[:, 1][:, None])).to(info['DEVICE'])
        x_d, z_e_x, z_q_x, latents, pred_actions, pred_rewards = vqvae_model(
            states)
        z_q_x.retain_grad()
        rec_est = x_d[:, :info['nmix']]
        diff_est = x_d[:, info['nmix']:]
        loss_rec = info['ALPHA_REC'] * discretized_mix_logistic_loss(
            rec_est,
            rec,
            nr_mix=info['NR_LOGISTIC_MIX'],
            DEVICE=info['DEVICE'])
        loss_diff = discretized_mix_logistic_loss(
            diff_est,
            diff,
            nr_mix=info['NR_LOGISTIC_MIX'],
            DEVICE=info['DEVICE'])

        loss_act = info['ALPHA_ACT'] * F.nll_loss(
            pred_actions, actions, weight=info['actions_weight'])
        loss_rewards = info['ALPHA_REW'] * F.nll_loss(
            pred_rewards, rewards, weight=info['rewards_weight'])
        loss_2 = F.mse_loss(z_q_x, z_e_x.detach())

        loss_act.backward(retain_graph=True)
        loss_rec.backward(retain_graph=True)
        loss_diff.backward(retain_graph=True)

        loss_3 = info['BETA'] * F.mse_loss(z_e_x, z_q_x.detach())
        vqvae_model.embedding.zero_grad()
        z_e_x.backward(z_q_x.grad, retain_graph=True)
        loss_2.backward(retain_graph=True)
        loss_3.backward()

        parameters = list(vqvae_model.parameters())
        clip_grad_value_(parameters, 10)
        opt.step()
        bs = float(x_d.shape[0])
        avg_train_losses = [
            loss_rewards.item() / bs,
            loss_act.item() / bs,
            loss_rec.item() / bs,
            loss_diff.item() / bs,
            loss_2.item() / bs,
            loss_3.item() / bs
        ]
        if batches > info['VQ_MIN_BATCHES_BEFORE_SAVE']:
            if ((train_cnt - info['vq_last_save']) >= info['VQ_SAVE_EVERY']):
                print("Saving model at cnt:%s cnt since last saved:%s" %
                      (train_cnt, train_cnt - info['vq_last_save']))
                info['vq_last_save'] = train_cnt
                info['vq_save_times'].append(time.time())
                avg_valid_losses = valid_vqvae(train_cnt, vqvae_model, info,
                                               valid_data_loader)
                handle_plot_ckpt(train_cnt, info, avg_train_losses,
                                 avg_valid_losses)
                filename = info[
                    'vq_model_base_filepath'] + "_%010dex.pt" % train_cnt
                print("SAVING MODEL:%s" % filename)
                state = {
                    'vqvae_state_dict': vqvae_model.state_dict(),
                    'vq_optimizer': opt.state_dict(),
                    'vq_embedding': vqvae_model.embedding,
                    'vq_info': info,
                }
                save_checkpoint(state, filename=filename)

        train_cnt += len(states)
        batches += 1
        if not batches % 1000:
            print("finished %s epoch after %s seconds at cnt %s" %
                  (batches, time.time() - st, train_cnt))
    return train_cnt
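The backward sequence in Example #27 is the VQ-VAE gradient estimator: reconstruction gradients are copied straight through the non-differentiable quantization via z_e_x.backward(z_q_x.grad), while F.mse_loss(z_q_x, z_e_x.detach()) moves the codebook toward the encoder outputs and BETA * F.mse_loss(z_e_x, z_q_x.detach()) commits the encoder to the codebook. A minimal sketch of those three pieces with placeholder encoder, decoder, and codebook:

import torch
import torch.nn as nn
import torch.nn.functional as F

encoder = nn.Linear(8, 4)            # placeholder encoder
decoder = nn.Linear(4, 8)            # placeholder decoder
codebook = nn.Embedding(16, 4)       # placeholder codebook
beta = 0.25

x = torch.randn(2, 8)
z_e_x = encoder(x)
with torch.no_grad():                # nearest-neighbour lookup is not differentiated
    idx = torch.cdist(z_e_x, codebook.weight).argmin(dim=1)
z_q_x = codebook(idx)
z_q_x.retain_grad()                  # keep its grad so it can be copied to z_e_x

rec_loss = F.mse_loss(decoder(z_q_x), x)
rec_loss.backward(retain_graph=True)
codebook.zero_grad()                 # the codebook is trained by its own loss below
z_e_x.backward(z_q_x.grad, retain_graph=True)  # straight-through copy to the encoder

codebook_loss = F.mse_loss(z_q_x, z_e_x.detach())
commit_loss = beta * F.mse_loss(z_e_x, z_q_x.detach())
(codebook_loss + commit_loss).backward()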