Code example #1
def handle_checkpointing(train_cnt, loss_list):
    # write a full checkpoint once save_every training examples have passed;
    # otherwise fall back to plotting/logging at their own intervals
    if (train_cnt - info['last_save']) >= args.save_every:
        print("Saving model at cnt:%s cnt since last saved:%s" %
              (train_cnt, train_cnt - info['last_save']))
        info['last_save'] = train_cnt
        info['save_times'].append(time.time())
        handle_plot_ckpt(True, train_cnt, loss_list)
        filename = model_base_filepath + "_%010dex.pt" % train_cnt
        print("SAVING MODEL:%s" % filename)
        state = {
            'vqvae_state_dict': vqvae_model.state_dict(),
            'optimizer': opt.state_dict(),
            'embedding': vqvae_model.embedding,
            'info': info,
        }
        save_checkpoint(state, filename=filename)
    elif not len(info['train_cnts']):
        print("Logging: %s no previous logs" % (train_cnt))
        handle_plot_ckpt(True, train_cnt, loss_list)
    elif (train_cnt - info['last_plot']) >= args.plot_every:
        print("Calling plot at cnt:%s cnt since last plotted:%s" %
              (train_cnt, train_cnt - info['last_plot']))
        handle_plot_ckpt(True, train_cnt, loss_list)
    elif (train_cnt - info['train_cnts'][-1]) >= args.log_every:
        print("Logging at cnt:%s cnt since last logged:%s" %
              (train_cnt, train_cnt - info['train_cnts'][-1]))
        handle_plot_ckpt(False, train_cnt, loss_list)
Code example #2
def handle_checkpointing(train_cnt, avg_train_loss):
    # same pattern as above, for a VAE with a PixelCNN prior and decoder
    if (train_cnt - info['last_save']) >= args.save_every:
        print("Saving Model at cnt:%s cnt since last saved:%s" %
              (train_cnt, train_cnt - info['last_save']))
        info['last_save'] = train_cnt
        info['save_times'].append(time.time())
        handle_plot_ckpt(True, train_cnt, avg_train_loss)
        filename = vae_base_filepath + "_%010dex.pkl" % train_cnt
        state = {
            'vae_state_dict': vae_model.state_dict(),
            'prior_state_dict': prior_model.state_dict(),
            'pcnn_state_dict': pcnn_decoder.state_dict(),
            'optimizer': opt.state_dict(),
            'info': info,
        }
        save_checkpoint(state, filename=filename)
    elif not len(info['train_cnts']):
        print("Logging model: %s no previous logs" % (train_cnt))
        handle_plot_ckpt(False, train_cnt, avg_train_loss)
    elif (train_cnt - info['last_plot']) >= args.plot_every:
        print("Plotting Model at cnt:%s cnt since last plotted:%s" %
              (train_cnt, train_cnt - info['last_plot']))
        handle_plot_ckpt(True, train_cnt, avg_train_loss)
    elif (train_cnt - info['train_cnts'][-1]) >= args.log_every:
        print("Logging Model at cnt:%s cnt since last logged:%s" %
              (train_cnt, train_cnt - info['train_cnts'][-1]))
        handle_plot_ckpt(False, train_cnt, avg_train_loss)
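Both functions above delegate the actual write to a save_checkpoint helper that is not shown here. A minimal sketch of such a helper, assuming it is nothing more than a thin wrapper around torch.save (the body below is an assumption, not the original implementation):

import torch

def save_checkpoint(state, filename='model.pkl'):
    # state is a plain dict of state_dicts plus bookkeeping info;
    # torch.save pickles the whole dict so it can later be restored
    # with torch.load
    print("starting save of model %s" % filename)
    torch.save(state, filename)
    print("finished save of model %s" % filename)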
Code example #3
def handle_checkpoint(last_save, cnt, epoch):
    # returns the updated last_save and the new checkpoint filename,
    # or the old last_save and '' when it is not yet time to save
    if (cnt - last_save) >= info['CHECKPOINT_EVERY_STEPS']:
        print("checkpoint")
        last_save = cnt
        state = {
            'info': info,
            'optimizer': opt.state_dict(),
            'cnt': cnt,
            'epoch': epoch,
            'policy_net_state_dict': policy_net.state_dict(),
            'target_net_state_dict': target_net.state_dict(),
        }
        filename = model_base_filepath + "_%010dq.pkl" % cnt
        save_checkpoint(state, filename)
        return last_save, filename
    else:
        return last_save, ''
Code example #4
def save_vqvae(info, train_cnt, vqvae_model, opt, avg_train_losses,
               valid_batch):
    print("Saving model at cnt:%s cnt since last saved:%s" %
          (train_cnt, train_cnt - info['model_last_save']))
    info['model_last_save'] = train_cnt
    info['model_save_times'].append(time.time())
    avg_valid_losses = valid_vqvae(train_cnt, vqvae_model, info, valid_batch)
    handle_plot_ckpt(train_cnt, info, avg_train_losses, avg_valid_losses)
    filename = info['MODEL_MODEL_BASE_FILEDIR'] + "_%010dex.pt" % train_cnt
    print("SAVING MODEL:%s" % filename)
    state = {
        'model_state_dict': vqvae_model.state_dict(),
        'model_optimizer': opt.state_dict(),
        'model_embedding': vqvae_model.embedding,
        'model_info': info,
    }
    save_checkpoint(state, filename=filename)
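A checkpoint written this way is restored by loading the pickled dict and handing each entry back to its object. A minimal sketch, assuming vqvae_model and opt are rebuilt with the same architecture and settings as at save time (the function name load_vqvae is ours, not from the original code):

import torch

def load_vqvae(filename, vqvae_model, opt):
    # map_location='cpu' keeps the load working on CPU-only machines
    state = torch.load(filename, map_location='cpu')
    vqvae_model.load_state_dict(state['model_state_dict'])
    opt.load_state_dict(state['model_optimizer'])
    return state['model_info']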
Code example #5
def save_model(info, model_dict):
    train_cnt = info['model_train_cnts'][-1]
    print("Saving model at cnt:%s cnt since last saved:%s" %
          (train_cnt, train_cnt - info['model_last_save']))
    info['model_last_save'] = train_cnt
    info['model_save_times'].append(time.time())
    # TODO - replace with a validation pass:
    #avg_valid_losses = valid_vqvae(train_cnt, model, info, valid_batch)
    #handle_plot_ckpt(train_cnt, info, avg_train_losses, avg_valid_losses)
    filename = os.path.join(info['MODEL_BASE_FILEDIR'], "_%010dex.pt" % train_cnt)
    print(filename)
    state = {
        'model_info': info,
    }
    for model_name, model in model_dict.items():
        state[model_name + '_state_dict'] = model.state_dict()
    save_checkpoint(state, filename=filename)
    return info
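The keys of model_dict become the *_state_dict keys in the checkpoint, so one function covers any combination of networks. A toy usage with hypothetical bookkeeping values and placeholder models (it relies on a save_checkpoint helper like the sketch after code example #2):

import os
from torch import nn

os.makedirs('/tmp/ckpts/run', exist_ok=True)
info = {'model_train_cnts': [1000],
        'model_save_times': [],
        'model_last_save': 0,
        'MODEL_BASE_FILEDIR': '/tmp/ckpts/run'}
# produces 'encoder_state_dict' and 'decoder_state_dict' entries
info = save_model(info, {'encoder': nn.Linear(4, 8),
                         'decoder': nn.Linear(8, 4)})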
Code example #6
def handle_checkpoint(last_save, cnt, epoch, last_mean):
    # like code example #3, but also snapshots the replay buffer
    if (cnt - last_save) >= info['CHECKPOINT_EVERY_STEPS']:
        print("checkpoint")
        last_save = cnt
        state = {
            'info': info,
            'optimizer': opt.state_dict(),
            'cnt': cnt,
            'epoch': epoch,
            'policy_net_state_dict': policy_net.state_dict(),
            'target_net_state_dict': target_net.state_dict(),
            'last_mean': last_mean,
        }
        filename = os.path.abspath(model_base_filepath + "_%010dq.pkl" % cnt)
        save_checkpoint(state, filename)
        buff_filename = os.path.abspath(model_base_filepath +
                                        "_%010dq_train_buffer.pkl" % cnt)
        rbuffer.save(buff_filename)
        return last_save
    else:
        return last_save
Code example #7
def train_vqvae(train_cnt, vqvae_model, opt, info, train_data_loader,
                valid_data_loader):
    st = time.time()
    batches = 0
    while train_cnt < info['VQ_NUM_EXAMPLES_TO_TRAIN']:
        vqvae_model.train()
        opt.zero_grad()
        # draw a minibatch of frames plus their frame-difference targets
        (states, actions, rewards, values, pred_states, terminals,
         is_new_epoch, relative_indexes) = train_data_loader.get_framediff_minibatch()
        # the vqvae downsamples 4 times, so the input sides must be
        # divisible by 2 four times; rescale frames from [0, 1] to [-1, 1]
        states = (2 * reshape_input(torch.FloatTensor(states)) - 1).to(
            info['DEVICE'])
        rec = (
            2 * reshape_input(torch.FloatTensor(pred_states)[:, 0][:, None]) -
            1).to(info['DEVICE'])
        actions = torch.LongTensor(actions).to(info['DEVICE'])
        rewards = torch.LongTensor(rewards).to(info['DEVICE'])
        # don't normalize the frame-difference target
        diff = (reshape_input(
            torch.FloatTensor(pred_states)[:, 1][:, None])).to(info['DEVICE'])
        x_d, z_e_x, z_q_x, latents, pred_actions, pred_rewards = vqvae_model(
            states)
        z_q_x.retain_grad()  # keep grad here so it can be copied onto z_e_x
        rec_est = x_d[:, :info['nmix']]   # mixture params for the reconstruction
        diff_est = x_d[:, info['nmix']:]  # mixture params for the frame difference
        loss_rec = info['ALPHA_REC'] * discretized_mix_logistic_loss(
            rec_est,
            rec,
            nr_mix=info['NR_LOGISTIC_MIX'],
            DEVICE=info['DEVICE'])
        loss_diff = discretized_mix_logistic_loss(
            diff_est,
            diff,
            nr_mix=info['NR_LOGISTIC_MIX'],
            DEVICE=info['DEVICE'])

        loss_act = info['ALPHA_ACT'] * F.nll_loss(
            pred_actions, actions, weight=info['actions_weight'])
        loss_rewards = info['ALPHA_REW'] * F.nll_loss(
            pred_rewards, rewards, weight=info['rewards_weight'])
        # codebook loss: pull the embeddings toward the encoder outputs
        loss_2 = F.mse_loss(z_q_x, z_e_x.detach())

        loss_act.backward(retain_graph=True)
        loss_rec.backward(retain_graph=True)
        loss_diff.backward(retain_graph=True)

        # commitment loss: keep the encoder close to its chosen codes
        loss_3 = info['BETA'] * F.mse_loss(z_e_x, z_q_x.detach())
        vqvae_model.embedding.zero_grad()
        # straight-through: copy the decoder's gradient at z_q_x onto z_e_x
        z_e_x.backward(z_q_x.grad, retain_graph=True)
        loss_2.backward(retain_graph=True)
        loss_3.backward()

        parameters = list(vqvae_model.parameters())
        clip_grad_value_(parameters, 10)
        opt.step()
        bs = float(x_d.shape[0])
        avg_train_losses = [
            loss_rewards.item() / bs,
            loss_act.item() / bs,
            loss_rec.item() / bs,
            loss_diff.item() / bs,
            loss_2.item() / bs,
            loss_3.item() / bs
        ]
        if batches > info['VQ_MIN_BATCHES_BEFORE_SAVE']:
            if (train_cnt - info['vq_last_save']) >= info['VQ_SAVE_EVERY']:
                print("Saving model at cnt:%s cnt since last saved:%s" %
                      (train_cnt, train_cnt - info['vq_last_save']))
                info['vq_last_save'] = train_cnt
                info['vq_save_times'].append(time.time())
                avg_valid_losses = valid_vqvae(train_cnt, vqvae_model, info,
                                               valid_data_loader)
                handle_plot_ckpt(train_cnt, info, avg_train_losses,
                                 avg_valid_losses)
                filename = info['vq_model_base_filepath'] + "_%010dex.pt" % train_cnt
                print("SAVING MODEL:%s" % filename)
                state = {
                    'vqvae_state_dict': vqvae_model.state_dict(),
                    'vq_optimizer': opt.state_dict(),
                    'vq_embedding': vqvae_model.embedding,
                    'vq_info': info,
                }
                save_checkpoint(state, filename=filename)

        train_cnt += len(states)
        batches += 1
        if not batches % 1000:
            print("finished %s batches after %s seconds at cnt %s" %
                  (batches, time.time() - st, train_cnt))
    return train_cnt
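The backward passes above spell out the VQ-VAE straight-through estimator by hand: the reconstruction losses are backpropagated only as far as the quantized codes z_q_x, and their gradient is then copied onto the encoder output z_e_x with z_e_x.backward(z_q_x.grad). The same effect is more commonly written with the detach trick, sketched below (a simplification of the idea, not the autograd wiring this code uses):

import torch

def straight_through(z_e_x, z_q_x):
    # forward: returns exactly z_q_x
    # backward: the (z_q_x - z_e_x) term is detached, so gradients flow
    # to z_e_x as if the quantization step were the identity
    return z_e_x + (z_q_x - z_e_x).detach()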