def main():
    """Train the robot and human SAC agents on ``FeedingCooperation-v0``.

    Parses the run arguments, opens a TensorFlow session with on-demand GPU
    memory growth, builds one actor-critic per agent from the same set of
    hyper-parameters, runs ``train`` and finally stores the resulting state
    under ``my_model_sac_cop/final/sacmodel``.

    The environment is closed on the way out, also when training raises.
    """
    args = parse_arg()
    tf.reset_default_graph()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # Both agents receive exactly the same hyper-parameters, in this
    # positional order, right after the session argument.
    hyperparam_keys = (
        'actor_lr', 'critic_lr', 'value_lr', 'reg_factor',
        'gamma', 'tau',
        'value_weight', 'critic_weight', 'actor_weight', 'all_lr',
        'max_steps', 'minibatch_size',
    )

    with tf.Session(config=config) as sess:
        env = gym.make('FeedingCooperation-v0')
        try:
            set_global_seeds(int(args['random_seed']))
            hyperparams = [float(args[key]) for key in hyperparam_keys]
            robot_actor_critic_entropy = Robot_Actor_Critic(
                sess, *hyperparams)
            human_actor_critic_entropy = Human_Actor_Critic(
                sess, *hyperparams)
            train(sess, env, args, robot_actor_critic_entropy,
                  human_actor_critic_entropy)

            # Save while the session is still open.
            savepath = osp.join("my_model_sac_cop/", 'final')
            os.makedirs(savepath, exist_ok=True)
            savepath = osp.join(savepath, 'sacmodel')
            save_state(savepath)
        finally:
            # Release the environment even if training or saving failed.
            env.close()
def main():
    """Parse the command line, train, save the final state and return the model.

    The parsed arguments are pretty-printed before training starts. After
    ``train`` returns, the state is saved to ``my_model_cop/final/ppomodel``
    and the environment handed back by ``train`` is closed.
    """
    tf.reset_default_graph()
    args = common_arg_parser().parse_args()
    pp.pprint(vars(args))

    model, env = train(args)

    final_dir = osp.join("my_model_cop/", 'final')
    os.makedirs(final_dir, exist_ok=True)
    save_state(osp.join(final_dir, 'ppomodel'))
    env.close()

    return model
def train(sess, env, args, robot_actor_critic, human_actor_critic):
    """Run the joint robot/human training loop for ``args['max_steps']`` steps.

    Every environment step both agents pick an action from their own slice of
    the observation, the two actions are concatenated and applied to ``env``,
    and the transition is stored in a shared replay buffer. Once more than
    ``100 * minibatch_size`` steps have been collected, each step also trains
    both agents on one sampled minibatch and updates their target networks.

    Losses and per-episode statistics are written as TensorBoard summaries
    under ``summaries/``; a checkpoint is saved every 1,000,000 steps.

    Args:
        sess: TensorFlow session used to initialise variables and whose graph
            is attached to the summary writer.
        env: Environment providing ``reset``, ``step``, ``success_time`` and
            ``fall_times``.
        args: Mapping with at least ``max_steps``, ``minibatch_size`` and
            ``buffer_size`` (values convertible with ``int``).
        robot_actor_critic: Agent fed the first 24 observation entries.
        human_actor_critic: Agent fed the remaining observation entries.
    """
    # Observation split point: [:24] goes to the robot, [24:] to the human.
    robot_obs_dim = 24

    sess.run(tf.global_variables_initializer())
    global_summary = tf.summary.FileWriter(
        'summaries/' + 'feeding_sac_all' +
        datetime.datetime.now().strftime('%d-%m-%y%H%M'), sess.graph)
    robot_actor_critic.update_target_network()
    human_actor_critic.update_target_network()

    replay_buffer = ReplayBuffer(int(args['buffer_size']))
    max_steps = int(args['max_steps'])
    minibatch_size = int(args['minibatch_size'])
    pbar = tqdm(total=max_steps, dynamic_ncols=True)
    tfirststart = time.perf_counter()
    total_step = 0

    try:
        while total_step < max_steps:
            state = env.reset()
            episode_reward = 0
            end_step = 0
            while True:
                # actor_predict returns (action, greedy_action) batches; only
                # the first returned action of the single-state batch is used.
                robot_actions, _ = robot_actor_critic.actor_predict(
                    [state[:robot_obs_dim]])
                human_actions, _ = human_actor_critic.actor_predict(
                    [state[robot_obs_dim:]])
                robot_action = robot_actions[0]
                human_action = human_actions[0]
                cop_action = np.concatenate([robot_action, human_action],
                                            axis=0)
                state2, reward, done, info = env.step(cop_action)
                episode_reward += reward
                end_step += 1
                total_step += 1

                replay_buffer.add(state, robot_action, human_action, reward,
                                  state2, done)

                state = state2

                # Start learning only after a warm-up of 100 minibatches'
                # worth of collected steps.
                if total_step > 100 * minibatch_size:
                    (batch_state, batch_robot_actions, batch_human_actions,
                     batch_rewards, batch_state2,
                     batch_dones) = replay_buffer.sample(minibatch_size)
                    batch_state = np.array(batch_state)
                    batch_state2 = np.array(batch_state2)

                    (robot_actor_loss, robot_critic_loss, robot_value_loss,
                     robot_all_loss, _) = robot_actor_critic.all_train(
                         batch_state[:, :robot_obs_dim],
                         batch_state2[:, :robot_obs_dim],
                         batch_robot_actions, batch_rewards, batch_dones)
                    robot_actor_critic.update_target_network()

                    (human_actor_loss, human_critic_loss, human_value_loss,
                     human_all_loss, _) = human_actor_critic.all_train(
                         batch_state[:, robot_obs_dim:],
                         batch_state2[:, robot_obs_dim:],
                         batch_human_actions, batch_rewards, batch_dones)
                    human_actor_critic.update_target_network()

                    summary = tf.Summary()
                    loss_scalars = (
                        ('robot_loss/value_loss', robot_value_loss),
                        ('robot_loss/critic_loss', robot_critic_loss),
                        ('robot_loss/actor_loss', robot_actor_loss),
                        ('robot_loss/total_loss', robot_all_loss),
                        ('human_loss/value_loss', human_value_loss),
                        ('human_loss/critic_loss', human_critic_loss),
                        ('human_loss/actor_loss', human_actor_loss),
                        ('human_loss/total_loss', human_all_loss),
                    )
                    for tag, value in loss_scalars:
                        summary.value.add(tag=tag, simple_value=value)
                    global_summary.add_summary(summary, total_step)
                    global_summary.flush()

                if total_step % 1000000 == 0 and total_step != 0:
                    tnow = time.perf_counter()
                    print('consume time', tnow - tfirststart)
                    savepath = osp.join("my_model_sac_cop/",
                                        '%.5i' % total_step)
                    os.makedirs(savepath, exist_ok=True)
                    savepath = osp.join(savepath, 'sacmodel')
                    print('Saving to', savepath)
                    save_state(savepath)

                if done:
                    success_time = env.success_time()
                    fall_time = env.fall_times()
                    msg = 'step: {},episode reward: {},episode len: {},success_time: {},fall_time: {}'
                    # tqdm.update() takes an increment, not an absolute
                    # position: advance by the steps of this episode only.
                    pbar.update(end_step)
                    pbar.set_description(
                        msg.format(total_step, episode_reward, end_step,
                                   success_time, fall_time))
                    summary = tf.Summary()
                    summary.value.add(tag='Perf/Reward',
                                      simple_value=episode_reward)
                    summary.value.add(tag='Perf/episode_len',
                                      simple_value=end_step)
                    summary.value.add(tag='Perf/success_time',
                                      simple_value=success_time)
                    summary.value.add(tag='Perf/fall_time',
                                      simple_value=fall_time)
                    global_summary.add_summary(summary, total_step)
                    global_summary.flush()
                    break
    finally:
        # Release the progress bar and the event file even on failure.
        pbar.close()
        global_summary.close()
def learn(env,
          total_timesteps,
          seed=None,
          nsteps=1024,
          ent_coef=0.01,
          lr=0.01,
          vf_coef=0.5,
          p_coef=1.0,
          max_grad_norm=None,
          gamma=0.99,
          lam=0.95,
          nminibatches=15,
          noptepochs=4,
          cliprange=0.2,
          save_interval=100,
          copeoperation=False,
          human_ent_coef=0.01,
          human_vf_coef=0.5,
          human_p_coef=1.0):
    """Train one model, or a human/robot pair of models, on ``env``.

    Each update collects one batch of ``env.num_envs * nsteps`` samples from
    the runner, then performs ``noptepochs`` shuffled passes over it in
    minibatches of ``nbatch // nminibatches`` samples. Performance and loss
    statistics are written as TensorBoard summaries under ``summaries/`` and
    the state is saved every ``save_interval`` updates.

    Args:
        env: Vectorised environment exposing ``num_envs``.
        total_timesteps: Total number of samples to collect; the number of
            updates is ``total_timesteps // nbatch``.
        seed: Passed to ``set_global_seeds``.
        nsteps: Steps per environment per update (forwarded to the runner).
        ent_coef, vf_coef, p_coef, max_grad_norm: Forwarded to ``Model``
            (to the robot model in cooperative mode).
        lr: Float, or callable receiving the remaining-progress fraction
            (1.0 at the first update) and returning the learning rate.
        gamma, lam: Forwarded to the runner.
        nminibatches: Number of minibatches each batch is split into.
        noptepochs: Number of passes over each collected batch.
        cliprange: Float, or callable like ``lr`` returning the clip range.
        save_interval: Save the state every this many updates.
        copeoperation: When true, train separate human and robot models; the
            robot is trained on observation columns ``[:24]`` and the human
            on columns ``[24:]``.
        human_ent_coef, human_vf_coef, human_p_coef: Forwarded to the human
            model in cooperative mode.

    Returns:
        The trained ``Model`` in single-model mode, or the tuple
        ``(human_model, robot_model)`` in cooperative mode.
    """
    set_global_seeds(seed)
    sess = get_session()
    global_summary = tf.summary.FileWriter(
        'summaries/' + 'feeding' +
        datetime.datetime.now().strftime('%d-%m-%y%H%M'), sess.graph)

    if isinstance(lr, float):
        lr = constfn(lr)
    else:
        assert callable(lr)
    if isinstance(cliprange, float):
        cliprange = constfn(cliprange)
    else:
        assert callable(cliprange)

    # Number of parallel environments and the resulting batch sizes.
    nenvs = env.num_envs
    nbatch = nenvs * nsteps
    nbatch_train = nbatch // nminibatches

    if copeoperation:
        human_model = Model(env=env,
                            nbatch_act=nenvs,
                            nbatch_train=nbatch_train,
                            ent_coef=human_ent_coef,
                            vf_coef=human_vf_coef,
                            p_coef=human_p_coef,
                            max_grad_norm=max_grad_norm,
                            human=True,
                            robot=False)
        robot_model = Model(env=env,
                            nbatch_act=nenvs,
                            nbatch_train=nbatch_train,
                            ent_coef=ent_coef,
                            vf_coef=vf_coef,
                            p_coef=p_coef,
                            max_grad_norm=max_grad_norm,
                            human=False,
                            robot=True)
    else:
        model = Model(env=env,
                      nbatch_act=nenvs,
                      nbatch_train=nbatch_train,
                      ent_coef=ent_coef,
                      vf_coef=vf_coef,
                      p_coef=p_coef,
                      max_grad_norm=max_grad_norm)
    initialize()

    # Instantiate the runner object
    if copeoperation:
        runner = Runner(env=env,
                        model=None,
                        nsteps=nsteps,
                        gamma=gamma,
                        lam=lam,
                        human_model=human_model,
                        robot_model=robot_model)
    else:
        runner = Runner(env=env,
                        model=model,
                        nsteps=nsteps,
                        gamma=gamma,
                        lam=lam)

    epinfobuf = deque(maxlen=10)  # statistics of the 10 most recent episodes
    pbar = tqdm(total=total_timesteps, dynamic_ncols=True)

    tfirststart = time.perf_counter()

    nupdates = total_timesteps // nbatch
    for update in range(1, nupdates + 1):
        assert nbatch % nminibatches == 0
        # Remaining-progress fraction: 1.0 on the first update, approaching
        # 0 on the last; drives the learning-rate and clip-range schedules.
        frac = 1.0 - (update - 1.0) / nupdates
        lrnow = lr(frac)
        cliprangenow = cliprange(frac)

        # Collect one batch of experience.
        if copeoperation:
            (obs, human_returns, robot_returns, masks, human_actions,
             robot_actions, human_values, robot_values, human_neglogpacs,
             robot_neglogpacs, epinfos) = runner.coop_run()
        else:
            (obs, returns, masks, actions, values, neglogpacs,
             epinfos) = runner.run()
        epinfobuf.extend(epinfos)

        mblossvals = []
        human_mblossvals = []
        robot_mblossvals = []
        inds = np.arange(nbatch)
        for _ in range(noptepochs):
            # Randomize the indexes
            np.random.shuffle(inds)
            for start in range(0, nbatch, nbatch_train):
                mbinds = inds[start:start + nbatch_train]
                if copeoperation:
                    human_slices = [arr[mbinds]
                                    for arr in (obs[:, 24:], human_returns,
                                                human_actions, human_values,
                                                human_neglogpacs)]
                    robot_slices = [arr[mbinds]
                                    for arr in (obs[:, :24], robot_returns,
                                                robot_actions, robot_values,
                                                robot_neglogpacs)]
                    human_mblossvals.append(
                        human_model.train(lrnow, cliprangenow, *human_slices))
                    robot_mblossvals.append(
                        robot_model.train(lrnow, cliprangenow, *robot_slices))
                else:
                    slices = [arr[mbinds]
                              for arr in (obs, returns, actions, values,
                                          neglogpacs)]
                    mblossvals.append(
                        model.train(lrnow, cliprangenow, *slices))

        # Average the per-minibatch losses of this update.
        if copeoperation:
            human_lossvals = np.mean(human_mblossvals, axis=0)
            robot_lossvals = np.mean(robot_mblossvals, axis=0)
            human_ev = explained_variance(human_values, human_returns)
            robot_ev = explained_variance(robot_values, robot_returns)
        else:
            lossvals = np.mean(mblossvals, axis=0)
            ev = explained_variance(values, returns)

        performance_r = np.mean([epinfo['r'] for epinfo in epinfobuf])
        performance_len = np.mean([epinfo['l'] for epinfo in epinfobuf])
        success_time = np.mean(
            [epinfo['success_time'] for epinfo in epinfobuf])
        fall_time = np.mean([epinfo['fall_time'] for epinfo in epinfobuf])

        summary = tf.Summary()
        summary.value.add(tag='Perf/Reward', simple_value=performance_r)
        summary.value.add(tag='Perf/episode_len', simple_value=performance_len)
        summary.value.add(tag='Perf/success_time', simple_value=success_time)
        summary.value.add(tag='Perf/fall_time', simple_value=fall_time)
        if copeoperation:
            summary.value.add(tag='Perf/human_explained_variance',
                              simple_value=float(human_ev))
            summary.value.add(tag='Perf/robot_explained_variance',
                              simple_value=float(robot_ev))
            loss_groups = (
                ('human_loss/', human_lossvals, human_model.loss_names),
                ('robot_loss/', robot_lossvals, robot_model.loss_names),
            )
        else:
            summary.value.add(tag='Perf/explained_variance',
                              simple_value=float(ev))
            loss_groups = (('loss/', lossvals, model.loss_names), )

        # NOTE(review): in cooperative mode both models log their gradient
        # norm under the same 'grad/grad_norm' tag (kept as in the original).
        for prefix, group_vals, group_names in loss_groups:
            for lossval, lossname in zip(group_vals, group_names):
                tag_prefix = 'grad/' if lossname == 'grad_norm' else prefix
                summary.value.add(tag=tag_prefix + lossname,
                                  simple_value=lossval)

        global_summary.add_summary(summary, int(update * nbatch))
        global_summary.flush()
        print('finish one update')

        # tqdm.update() takes an increment, so advance by the samples of
        # this update rather than by the cumulative count.
        pbar.update(nbatch)
        if update % 10 == 0:
            msg = 'step: {},episode reward: {},episode len: {},success_time: {},fall_time: {}'
            pbar.set_description(
                msg.format(update * nbatch, performance_r, performance_len,
                           success_time, fall_time))

        if update % save_interval == 0:
            tnow = time.perf_counter()
            print('consume time', tnow - tfirststart)
            save_root = "my_model_cop/" if copeoperation else "my_model/"
            savepath = osp.join(save_root, '%.5i' % update)
            os.makedirs(savepath, exist_ok=True)
            savepath = osp.join(savepath, 'ppomodel')
            print('Saving to', savepath)
            save_state(savepath)
    pbar.close()
    global_summary.close()

    # `model` only exists in single-model mode; returning it unconditionally
    # raised UnboundLocalError after a full cooperative training run.
    if copeoperation:
        return human_model, robot_model
    return model
Beispiel #5
0
def main(args):
    """Train a VAE for the task selected by ``args.task`` and save checkpoints.

    Supported tasks are ``'contact'`` (builds a ``VAE``), ``'goal'`` and
    ``'transformation'`` (both build a ``GoalVAE``); any other value raises
    ``ValueError``. The user is prompted for free-form experiment notes,
    which are stored next to the TensorBoard logs in
    ``<log_dir>/<experiment_name>/``. Checkpoints and the accumulated
    per-epoch loss history are written to ``<model_path>/<model_name>/``
    every ``args.save_freq`` epochs and once more after the last epoch.

    When ``args.start_epoch > 0`` the network, optimizer, seeds and arguments
    are restored from the checkpoint of that epoch; only ``start_epoch`` and
    ``num_epochs`` are kept from the command line.

    NOTE(review): the use of ``raw_input`` and of writing its result to a
    file opened in ``'wb'`` mode suggests this script targets Python 2.

    Args:
        args: Parsed command-line namespace (batch size, data and log
            locations, task description, training schedule).

    Raises:
        ValueError: If ``args.task`` is not one of the supported tasks.
    """
    # cfg_file = os.path.join(args.example_config_path, args.primitive) + ".yaml"
    cfg = get_vae_defaults()
    # cfg.merge_from_file(cfg_file)
    cfg.freeze()

    batch_size = args.batch_size
    dataset_size = args.total_data_size

    if args.experiment_name is None:
        experiment_name = args.model_name
    else:
        experiment_name = args.experiment_name

    experiment_dir = os.path.join(args.log_dir, experiment_name)
    if not os.path.exists(experiment_dir):
        os.makedirs(experiment_dir)

    description_txt = raw_input('Please enter experiment notes: \n')
    if isinstance(description_txt, str):
        with open(
                os.path.join(experiment_dir,
                             experiment_name + '_description.txt'), 'wb') as f:
            f.write(description_txt)

    writer = SummaryWriter(experiment_dir)

    # Fixed seeds; they are replaced by the checkpointed ones when resuming.
    torch_seed = 0
    np_seed = 0

    torch.manual_seed(torch_seed)
    np.random.seed(np_seed)

    trained_model_path = os.path.join(args.model_path, args.model_name)
    if not os.path.exists(trained_model_path):
        os.makedirs(trained_model_path)

    if args.task == 'contact':
        if args.start_rep == 'keypoints':
            start_dim = 24
        elif args.start_rep == 'pose':
            start_dim = 7

        if args.goal_rep == 'keypoints':
            goal_dim = 24
        elif args.goal_rep == 'pose':
            goal_dim = 7

        if args.skill_type == 'pull':
            # + 7 because single arm palm pose
            input_dim = start_dim + goal_dim + 7
        else:
            # + 14 because both arms palm pose
            input_dim = start_dim + goal_dim + 14
        output_dim = 7
        decoder_input_dim = start_dim + goal_dim

        vae = VAE(input_dim,
                  output_dim,
                  args.latent_dimension,
                  decoder_input_dim,
                  hidden_layers=cfg.ENCODER_HIDDEN_LAYERS_MLP,
                  lr=args.learning_rate)
    elif args.task == 'goal':
        if args.start_rep == 'keypoints':
            start_dim = 24
        elif args.start_rep == 'pose':
            start_dim = 7

        if args.goal_rep == 'keypoints':
            goal_dim = 24
        elif args.goal_rep == 'pose':
            goal_dim = 7

        input_dim = start_dim + goal_dim
        output_dim = goal_dim
        decoder_input_dim = start_dim
        vae = GoalVAE(input_dim,
                      output_dim,
                      args.latent_dimension,
                      decoder_input_dim,
                      hidden_layers=cfg.ENCODER_HIDDEN_LAYERS_MLP,
                      lr=args.learning_rate)
    elif args.task == 'transformation':
        input_dim = args.input_dimension
        output_dim = args.output_dimension
        decoder_input_dim = args.input_dimension - args.output_dimension
        vae = GoalVAE(input_dim,
                      output_dim,
                      args.latent_dimension,
                      decoder_input_dim,
                      hidden_layers=cfg.ENCODER_HIDDEN_LAYERS_MLP,
                      lr=args.learning_rate)
    else:
        raise ValueError('training task not recognized')

    if torch.cuda.is_available():
        vae.encoder.cuda()
        vae.decoder.cuda()

    if args.start_epoch > 0:
        # Resume: restore everything from the checkpoint, but keep the
        # requested start epoch and epoch count from the command line.
        start_epoch = args.start_epoch
        num_epochs = args.num_epochs
        fname = os.path.join(
            trained_model_path,
            args.model_name + '_epoch_%d.pt' % args.start_epoch)
        torch_seed, np_seed = load_seed(fname)
        load_net_state(vae, fname)
        load_opt_state(vae, fname)
        args = load_args(fname)
        args.start_epoch = start_epoch
        args.num_epochs = num_epochs
        torch.manual_seed(torch_seed)
        np.random.seed(np_seed)

    data_dir = args.data_dir
    data_loader = DataLoader(data_dir=data_dir)

    data_loader.create_random_ordering(size=dataset_size)

    dataset = data_loader.load_dataset(start_rep=args.start_rep,
                                       goal_rep=args.goal_rep,
                                       task=args.task)

    total_loss = []
    start_time = time.time()
    print('Saving models to: ' + trained_model_path)
    kl_weight = 1.0
    print('Starting on epoch: ' + str(args.start_epoch))

    # Last epoch that actually ran; stays None if the loop body never runs.
    last_epoch = None

    for epoch in range(args.start_epoch, args.start_epoch + args.num_epochs):
        last_epoch = epoch
        print('Epoch: ' + str(epoch))
        epoch_total_loss = 0
        epoch_kl_loss = 0
        epoch_pos_loss = 0
        epoch_ori_loss = 0
        epoch_recon_loss = 0
        num_batches = 0
        # KL annealing: the coefficient starts at 0 and moves towards 1 as
        # kl_weight is multiplied by the anneal rate every epoch.
        kl_coeff = 1 - kl_weight
        kl_weight = args.kl_anneal_rate * kl_weight
        print('KL coeff: ' + str(kl_coeff))
        for i in range(0, dataset_size, batch_size):
            vae.optimizer.zero_grad()

            input_batch, decoder_input_batch, target_batch = \
                data_loader.sample_batch(dataset, i, batch_size)
            input_batch = to_var(torch.from_numpy(input_batch))
            decoder_input_batch = to_var(torch.from_numpy(decoder_input_batch))

            z, recon_mu, z_mu, z_logvar = vae.forward(input_batch,
                                                      decoder_input_batch)
            kl_loss = vae.kl_loss(z_mu, z_logvar)

            if args.task == 'contact':
                output_r, output_l = recon_mu
                if args.skill_type == 'grasp':
                    target_batch_right = to_var(
                        torch.from_numpy(target_batch[:, 0]))
                    target_batch_left = to_var(
                        torch.from_numpy(target_batch[:, 1]))

                    pos_loss_right = vae.mse(output_r[:, :3],
                                             target_batch_right[:, :3])
                    ori_loss_right = vae.rotation_loss(
                        output_r[:, 3:], target_batch_right[:, 3:])

                    pos_loss_left = vae.mse(output_l[:, :3],
                                            target_batch_left[:, :3])
                    ori_loss_left = vae.rotation_loss(output_l[:, 3:],
                                                      target_batch_left[:, 3:])

                    pos_loss = pos_loss_left + pos_loss_right
                    ori_loss = ori_loss_left + ori_loss_right
                elif args.skill_type == 'pull':
                    target_batch = to_var(
                        torch.from_numpy(target_batch.squeeze()))

                    # TODO add flags for when we're training both arms
                    # (only the first decoder output is supervised here).
                    pos_loss_right = vae.mse(output_r[:, :3],
                                             target_batch[:, :3])
                    ori_loss_right = vae.rotation_loss(output_r[:, 3:],
                                                       target_batch[:, 3:])

                    pos_loss = pos_loss_right
                    ori_loss = ori_loss_right

            elif args.task == 'goal':
                target_batch = to_var(torch.from_numpy(target_batch.squeeze()))
                output = recon_mu
                if args.goal_rep == 'pose':
                    pos_loss = vae.mse(output[:, :3], target_batch[:, :3])
                    ori_loss = vae.rotation_loss(output[:, 3:],
                                                 target_batch[:, 3:])
                elif args.goal_rep == 'keypoints':
                    pos_loss = vae.mse(output, target_batch)
                    # Zero placeholder created alongside pos_loss so it
                    # shares its device and dtype.
                    ori_loss = torch.zeros_like(pos_loss)

            elif args.task == 'transformation':
                target_batch = to_var(torch.from_numpy(target_batch.squeeze()))
                output = recon_mu
                pos_loss = vae.mse(output[:, :3], target_batch[:, :3])
                ori_loss = vae.rotation_loss(output[:, 3:],
                                             target_batch[:, 3:])

            recon_loss = pos_loss + ori_loss

            loss = kl_coeff * kl_loss + recon_loss
            loss.backward()
            vae.optimizer.step()

            epoch_total_loss = epoch_total_loss + loss.data
            epoch_kl_loss = epoch_kl_loss + kl_loss.data
            epoch_pos_loss = epoch_pos_loss + pos_loss.data
            epoch_ori_loss = epoch_ori_loss + ori_loss.data
            epoch_recon_loss = epoch_recon_loss + recon_loss.data
            num_batches += 1

            # Use a step that keeps growing across epochs; the in-epoch
            # offset `i` alone would overwrite the same x-range every epoch.
            global_step = epoch * dataset_size + i
            writer.add_scalar('loss/train/ori_loss', ori_loss.data,
                              global_step)
            writer.add_scalar('loss/train/pos_loss', pos_loss.data,
                              global_step)
            writer.add_scalar('loss/train/kl_loss', kl_loss.data, global_step)

            if (i // batch_size) % args.batch_freq == 0:
                # Share of the epoch's samples already consumed, in percent.
                percent_done = 100.0 * i / dataset_size
                if args.skill_type == 'pull' or args.task == 'goal' or args.task == 'transformation':
                    print(
                        'Train Epoch: %d [%d/%d (%f)]\tLoss: %f\tKL: %f\tPos: %f\t Ori: %f'
                        % (epoch, i, dataset_size, percent_done, loss.item(),
                           kl_loss.item(), pos_loss.item(), ori_loss.item()))
                elif args.skill_type == 'grasp' and args.task == 'contact':
                    print(
                        'Train Epoch: %d [%d/%d (%f)]\tLoss: %f\tKL: %f\tR Pos: %f\t R Ori: %f\tL Pos: %f\tL Ori: %f'
                        % (epoch, i, dataset_size, percent_done, loss.item(),
                           kl_loss.item(), pos_loss_right.item(),
                           ori_loss_right.item(), pos_loss_left.item(),
                           ori_loss_left.item()))

        # Average over the batches that were actually processed (also
        # correct when dataset_size is not a multiple of batch_size).
        batches_in_epoch = max(num_batches, 1)
        print(' --avgerage loss: ')
        print(epoch_total_loss / batches_in_epoch)
        loss_dict = {
            'epoch_total': epoch_total_loss / batches_in_epoch,
            'epoch_kl': epoch_kl_loss / batches_in_epoch,
            'epoch_pos': epoch_pos_loss / batches_in_epoch,
            'epoch_ori': epoch_ori_loss / batches_in_epoch,
            'epoch_recon': epoch_recon_loss / batches_in_epoch
        }
        total_loss.append(loss_dict)

        if epoch % args.save_freq == 0:
            print('\n--Saving model\n')
            print('time: ' + str(time.time() - start_time))

            save_state(net=vae,
                       torch_seed=torch_seed,
                       np_seed=np_seed,
                       args=args,
                       fname=os.path.join(
                           trained_model_path,
                           args.model_name + '_epoch_' + str(epoch) + '.pt'))

            np.savez(os.path.join(
                trained_model_path,
                args.model_name + '_epoch_' + str(epoch) + '_loss.npz'),
                     loss=np.asarray(total_loss))

    print('Done!')
    writer.close()
    # Nothing to save if no epoch ran (e.g. num_epochs == 0); previously this
    # path failed on the unbound loop variable.
    if last_epoch is not None:
        save_state(net=vae,
                   torch_seed=torch_seed,
                   np_seed=np_seed,
                   args=args,
                   fname=os.path.join(
                       trained_model_path,
                       args.model_name + '_epoch_' + str(last_epoch) + '.pt'))
Beispiel #6
0
                print("\n", "--" * 20)
                output = model(images)  # b, t, a
                print("output: ")
                pred = output.argmax(-1).cpu().numpy()  # b,t
                print(label_map.decode(pred, raw=False))
                print("label: ")
                print(label_map.decode_label(labels, label_lens))
                print("--" * 20)
                model.train()

        optimizer.zero_grad()
        output = model(images)
        probs = output.transpose(0, 1).contiguous().cuda()
        label_size = label_lens
        probs_size = torch.IntTensor([probs.size(0)] * probs.size(1))
        probs.requires_grad_(True)
        loss = ctc_loss(probs, labels, probs_size, label_size)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        if step % print_every == 0:
            print("step: %d, loss: %.5f" % (step, total_loss / print_every))
            total_loss = 0

        if step % save_state_every == 0:
            save_state(ckpt_dir, step, model, optimizer)
            accuracy = test_model(test_dataloader, label_map, model)

        step += 1