Code Example #1
def train(context_input, node_input, edge_input, type_input, aspect_input,
          aspect_output, aspect_model, aspect_optimizer):
    aspect_optimizer.zero_grad()

    context_input = context_input.to(device)
    node_input = node_input.to(device)
    edge_input = edge_input.to(device)
    type_input = type_input.to(device)
    aspect_input = aspect_input.to(device)
    aspect_output = aspect_output.to(device)

    # context and graph encoder
    context_embed, hidden = aspect_model.forward_context(context_input)

    graph_embed = aspect_model.forward_graph(node_input, edge_input,
                                             type_input)

    decoder_generate, decoder_hidden = aspect_model(context_embed, hidden,
                                                    aspect_input, graph_embed)

    mask = torch.ne(aspect_output, PAD_ID)
    loss = margin_loss(decoder_generate, aspect_output, mask)
    loss.backward()

    clip = 5.0
    mc = torch.nn.utils.clip_grad_norm_(
        filter(lambda p: p.requires_grad, aspect_model.parameters()), clip)

    aspect_optimizer.step()

    return loss.item()
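
The margin_loss called in this example (and in Example #7 below) is not part of the listing. The sketch that follows is only one plausible reconstruction matching the call signature margin_loss(decoder_generate, aspect_output, mask); the multi-class hinge formulation, the margin value, and the mask-based averaging are assumptions, not the project's actual implementation.

import torch

def margin_loss(logits, targets, mask, margin=1.0):
    """Hypothetical masked multi-class hinge loss.

    logits:  [batch, seq_len, vocab] decoder scores (decoder_generate)
    targets: [batch, seq_len] gold token ids (aspect_output)
    mask:    [batch, seq_len] True where the target is not PAD
    """
    true_scores = logits.gather(-1, targets.unsqueeze(-1))             # [B, T, 1]
    # Hinge against every other class; zero out the slot of the true class.
    violations = torch.clamp(margin + logits - true_scores, min=0.0)
    violations = violations.scatter(-1, targets.unsqueeze(-1), 0.0)
    per_token = violations.sum(-1) * mask.float()                      # [B, T]
    return per_token.sum() / mask.float().sum().clamp(min=1.0)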
Code Example #2
File: test_loss.py  Project: zx-/CapsNet
    def test_sanity_check(self):
        with self.test_session():
            np.random.seed(1)
            prediction = np.random.rand(5, 10, 16).astype(np.float32)
            target = np.eye(10)[:5].astype(np.float32)

            output = loss.margin_loss(prediction, target)

            self.assertShapeEqual(np.array(1), output)
            self.assertAllGreater(output.eval(), 100)
Code Example #3
File: test_loss.py  Project: zx-/CapsNet
    def test_false_vector(self):
        with self.test_session():
            prediction = np.zeros((1, 1, 10)).astype(np.float32)
            prediction[0, 0, 5] = 1

            target = np.zeros([1, 1]).astype(np.float32)

            output = loss.margin_loss(prediction, target)

            self.assertAllCloseAccordingToType(output.eval(), 0.405)
Code Example #4
File: test_loss.py  Project: zx-/CapsNet
    def test_true_vector(self):
        with self.test_session():
            prediction = np.zeros((1, 1, 10)).astype(np.float32)
            prediction[0, 0, 5] = 1

            target = np.ones([1, 1]).astype(np.float32)

            output = loss.margin_loss(prediction, target)

            self.assertAllEqual(output.eval(), 0.0)
Code Example #5
File: test_loss.py  Project: zx-/CapsNet
    def test_false_true_vector(self):
        with self.test_session():
            prediction = np.zeros((1, 2, 10)).astype(np.float32)
            prediction[0, 0, 5] = 0.5  # 0.16
            prediction[0, 1, 5] = 1  # 0.405 loss

            target = np.zeros([1, 2]).astype(np.float32)
            target[0, 0] = 1

            output = loss.margin_loss(prediction, target)

            self.assertAllCloseAccordingToType(output.eval(), 0.565)
Code Example #6
File: test_loss.py  Project: zx-/CapsNet
    def test_batch(self):
        with self.test_session():
            prediction = np.zeros((1, 2, 10)).astype(np.float32)
            prediction[0, 0, 5] = 0.5  # 0.16
            prediction[0, 1, 5] = 1  # 0.405 loss

            target = np.zeros([1, 2]).astype(np.float32)
            target[0, 0] = 1

            prediction = np.tile(prediction, (2, 1, 1))
            target = np.tile(target, (2, 1))
            assert prediction.shape == (2, 2, 10)
            assert target.shape == (2, 2)

            output = loss.margin_loss(prediction, target)

            self.assertAllCloseAccordingToType(output.eval(), 0.565 * 2)
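
Examples #2 through #6 pin down the behaviour of loss.margin_loss without showing it: the output is a scalar, it is summed over capsules and over the batch, and the small hand-built cases give 0.405, 0.0, 0.565 and 0.565 * 2. A NumPy reference that reproduces those values is sketched below; the constants m+ = 0.9, m- = 0.1 and lambda = 0.5 are the standard capsule-network settings inferred from the expected numbers, not read from the zx-/CapsNet repository.

import numpy as np

def margin_loss_reference(prediction, target, m_plus=0.9, m_minus=0.1, lam=0.5):
    """Capsule margin loss, summed over capsules and over the batch.

    prediction: [batch, num_capsules, dim] capsule output vectors
    target:     [batch, num_capsules] 1.0 for the true class, else 0.0
    """
    lengths = np.linalg.norm(prediction, axis=-1)                          # [B, C]
    present = target * np.maximum(0.0, m_plus - lengths) ** 2              # true-class term
    absent = lam * (1.0 - target) * np.maximum(0.0, lengths - m_minus) ** 2
    return np.sum(present + absent)

pred = np.zeros((1, 2, 10), dtype=np.float32)
pred[0, 0, 5], pred[0, 1, 5] = 0.5, 1.0
tgt = np.zeros((1, 2), dtype=np.float32)
tgt[0, 0] = 1.0
print(margin_loss_reference(pred, tgt))  # 0.565, as in test_false_true_vector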
Code Example #7
def evaluate(context_input, node_input, edge_input, type_input, aspect_input,
             aspect_output, aspect_model):
    aspect_model.eval()

    context_input = context_input.to(device)
    node_input = node_input.to(device)
    edge_input = edge_input.to(device)
    type_input = type_input.to(device)
    aspect_input = aspect_input.to(device)
    aspect_output = aspect_output.to(device)

    # context and graph encoder
    context_embed, hidden = aspect_model.forward_context(context_input)

    graph_embed = aspect_model.forward_graph(node_input, edge_input,
                                             type_input)

    decoder_generate, decoder_hidden = aspect_model(context_embed, hidden,
                                                    aspect_input, graph_embed)

    mask = torch.ne(aspect_output, PAD_ID)
    loss = margin_loss(decoder_generate, aspect_output, mask)

    return loss.item()
Code Example #8
tf.logging.info("input dimension:{}".format(X_embedding.get_shape()))

if args.tf_model_type == 'capsule-A':
    poses, activations = capsule_model_A(X_embedding, args.num_classes)
if args.tf_model_type == 'capsule-B':
    poses, activations = capsule_model_B(X_embedding, args.num_classes)
if args.tf_model_type == 'CNN':
    poses, activations = baseline_model_cnn(X_embedding, args.num_classes)
if args.tf_model_type == 'KIMCNN':
    poses, activations = baseline_model_kimcnn(X_embedding, args.max_sent,
                                               args.num_classes)

if args.tf_loss_type == 'spread_loss':
    loss = spread_loss(y, activations, margin)
if args.tf_loss_type == 'margin_loss':
    loss = margin_loss(y, activations)
if args.tf_loss_type == 'cross_entropy':
    loss = cross_entropy(y, activations)

y_pred = tf.argmax(activations, axis=1, name="y_proba")
correct = tf.equal(tf.argmax(y, axis=1), y_pred, name="correct")
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name="accuracy")
# tf.summary.scalar('accuracy', accuracy)
# merged = tf.summary.merge_all()
# writer = tf.summary.FileWriter('/tmp/writer_log')

optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(loss, name="training_op")
gradients, variables = zip(*optimizer.compute_gradients(loss))

grad_check = [tf.check_numerics(g, message='Gradient NaN Found!')
              for g in gradients if g is not None] + [tf.check_numerics(loss, message='Loss NaN Found')]
with tf.control_dependencies(grad_check):
    training_op = optimizer.apply_gradients(zip(gradients, variables), global_step=global_step)
Code Example #9
File: trainer.py  Project: david-riser/rlmp
    def pretrain(self, steps):
        """ Pretrain the online network using the expert buffer.
        """

        pretrain_loss = 0.
        for step in range(steps):

            e_transitions, e_weights, e_indices = self.expert_buffer.sample(
                self.config['expert_batch_size'], beta=1.)

            (e_states, e_actions, e_rewards, e_next_states,
             e_discounted_rewards, e_nth_states, e_dones,
             e_ns) = expand_transitions(
                 e_transitions,
                 torchify=True,
                 state_transformer=self.state_transformer)

            e_loss = ntd_loss(online_model=self.online_network,
                              target_model=self.target_network,
                              states=e_states,
                              actions=e_actions,
                              next_states=e_next_states,
                              rewards=e_rewards,
                              dones=e_dones,
                              gamma=0.99,
                              n=1)
            e_weights = torch.FloatTensor(e_weights).to(self.device)
            e_loss = e_loss * e_weights
            e_priorities = e_loss + 1e-5
            e_priorities = e_priorities.detach().cpu().numpy()
            self.expert_buffer.update_priorities(e_priorities, e_indices)
            e_loss = e_loss.mean()

            if self.config['n_steps'] > 1:
                e_nstep_loss = ntd_loss(online_model=self.online_network,
                                        target_model=self.target_network,
                                        states=e_states,
                                        actions=e_actions,
                                        next_states=e_nth_states,
                                        rewards=e_discounted_rewards,
                                        dones=e_dones,
                                        gamma=0.99,
                                        n=e_ns)
                e_nstep_loss = e_nstep_loss.mean()
                e_loss += e_nstep_loss

            q_values = self.online_network(e_states)
            e_loss += torch.mean(margin_loss(q_values, e_actions))
            pretrain_loss += e_loss.detach().cpu().numpy()

            self.optimizer.zero_grad()
            e_loss.backward()
            self.optimizer.step()

            if step % 10 == 0:
                print("Step: {0}, Loss: {1:6.4f}".format(
                    step, pretrain_loss / 10))
                wandb.log({"pretrain_loss": pretrain_loss / 10})
                pretrain_loss = 0.

        self.target_network.load_state_dict(self.online_network.state_dict())
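
The pretrain() method above adds torch.mean(margin_loss(q_values, e_actions)) on top of the TD losses, which matches the large-margin supervised loss used for expert demonstrations in DQfD. The PyTorch sketch below implements that idea under this assumption; the margin value of 0.8 and the per-transition output shape are guesses chosen to fit the call site, not the rlmp project's code.

import torch

def margin_loss(q_values, expert_actions, margin=0.8):
    """DQfD-style large-margin loss, one value per transition (assumed form).

    q_values:       [batch, num_actions] Q-values from the online network
    expert_actions: [batch] indices of the actions the expert took
    Returns max_a (Q(s, a) + l(a_E, a)) - Q(s, a_E), where l is `margin`
    for every non-expert action and 0 for the expert action.
    """
    idx = expert_actions.view(-1, 1).long()
    q_expert = q_values.gather(1, idx).squeeze(1)              # [B]
    margins = torch.full_like(q_values, margin)
    margins.scatter_(1, idx, 0.0)                              # no margin on the expert action
    return (q_values + margins).max(dim=1).values - q_expert   # [B]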
Code Example #10
File: trainer.py  Project: david-riser/rlmp
    def train(self):
        """ Train the online network using the n-step loss. 
        """
        env = self.env_builder()
        self.state = env.reset()
        self.prime_buffer(env)

        self.step = 0
        score = 0
        for epoch in range(self.config['n_epochs']):
            epoch_loss = []
            start_time = time.time()
            for batch in range(self.config['n_batches_per_epoch']):

                scores = self.evaluator.evaluate(self.step)
                if scores is not None and self.batchsize_bandit is not None:

                    if self.step > 0:
                        reward = np.median(scores)
                        self.batchsize_bandit.step(reward)

                    batch_size, expert_batch_size = self.batchsize_bandit.sample(
                    )
                    self.config['batch_size'] = batch_size
                    self.config['expert_batch_size'] = expert_batch_size

                    wandb.log({
                        "bandit_batch_size": batch_size,
                        "bandit_expert_batch_size": expert_batch_size,
                        "bandit_values": self.batchsize_bandit.values
                    })

                # Choose an action based on the current state and
                # according to an epsilon-greedy policy.
                epsilon = self.epsilon_schedule.value(self.step)
                if random.random() < epsilon:
                    action = env.action_space.sample()
                else:
                    action = self.action_transformer(
                        self.online_network(self.state_transformer(
                            self.state)))

                # Update the current state of the environment by taking
                # the action and building the current transition to be
                # added to the n-step buffer.  These states are only added
                # to the replay buffer after a delay of n-steps.
                next_state, reward, done, info = env.step(action)
                current_trans = Transition(state=self.state,
                                           action=action,
                                           next_state=next_state,
                                           reward=reward,
                                           discounted_reward=None,
                                           nth_state=None,
                                           done=done,
                                           n=None)

                # Now use the contents of the n-step buffer to construct
                # the delayed transition and add that to the prioritized
                # replay buffer to be sampled for learning.
                (delayed_states, delayed_actions, delayed_rewards,
                 delayed_next_states, delayed_discounted_rewards,
                 delayed_nth_states, delayed_dones,
                 delayed_ns) = expand_transitions(self.nstep_buffer,
                                                  torchify=False)

                # Ensure that if the current episode has ended the last
                # few transitions get added correctly to the buffer.
                if not current_trans.done:
                    delayed_trans = Transition(
                        state=delayed_states[0],
                        action=delayed_actions[0],
                        reward=delayed_rewards[0],
                        next_state=delayed_next_states[0],
                        discounted_reward=np.sum([
                            reward * self.config['gamma']**i
                            for i, reward in enumerate(delayed_rewards)
                        ]),
                        nth_state=self.state,
                        done=done,
                        n=self.config['n_steps'])
                    self.buffer.add(delayed_trans)

                else:
                    for i in range(self.config['n_steps']):
                        delayed_trans = Transition(
                            state=delayed_states[i],
                            action=delayed_actions[i],
                            reward=delayed_rewards[i],
                            next_state=delayed_next_states[i],
                            discounted_reward=np.sum([
                                reward * self.config['gamma']**j
                                for j, reward in enumerate(delayed_rewards[i:])
                            ]),
                            nth_state=self.state,
                            done=done,
                            n=self.config['n_steps'] - i)
                        self.buffer.add(delayed_trans)

                # Now that we have used the buffer, we can add the current
                # transition to the queue.  Update the current state of the
                # environment.
                self.nstep_buffer.append(current_trans)
                if len(self.nstep_buffer) > self.config['n_steps']:
                    _ = self.nstep_buffer.pop(0)
                self.state = next_state

                beta = self.beta_schedule.value(self.step)
                if len(self.buffer) >= self.config[
                        'batch_size'] and self.config['batch_size'] > 0:
                    # Sample a batch of experience from the replay buffer and
                    # train with the n-step TD loss.
                    transitions, weights, indices = self.buffer.sample(
                        self.config['batch_size'], beta)
                    (states, actions, rewards, next_states, discounted_rewards,
                     nth_states, dones, ns) = expand_transitions(
                         transitions,
                         torchify=True,
                         state_transformer=self.state_transformer)

                    # Calculate the loss per transition.  This is not
                    # aggregated so that we can make the importance sampling
                    # correction to the loss.
                    #
                    # First we calculate the loss for 1-step ahead, then if
                    # required, we look ahead n-steps and add that to our loss.
                    # Importance sampling weights are based on the 1-step loss.
                    loss = ntd_loss(online_model=self.online_network,
                                    target_model=self.target_network,
                                    states=states,
                                    actions=actions,
                                    next_states=next_states,
                                    rewards=rewards,
                                    dones=dones,
                                    gamma=0.99,
                                    n=1)
                    weights = torch.FloatTensor(weights).to(self.device)
                    loss = loss * weights
                    priorities = loss + 1e-5
                    priorities = priorities.detach().cpu().numpy()
                    self.buffer.update_priorities(priorities, indices)
                    loss = loss.mean()

                    if self.config['n_steps'] > 1:
                        nstep_loss = ntd_loss(online_model=self.online_network,
                                              target_model=self.target_network,
                                              states=states,
                                              actions=actions,
                                              next_states=nth_states,
                                              rewards=discounted_rewards,
                                              dones=dones,
                                              gamma=0.99,
                                              n=ns)
                        nstep_loss = nstep_loss.mean()
                        loss += nstep_loss

                # Maybe we have an expert buffer, if so we should train some
                # samples from that expert buffer.
                if self.expert_buffer is not None and self.config[
                        'expert_batch_size'] > 0:
                    e_transitions, e_weights, e_indices = self.expert_buffer.sample(
                        self.config['expert_batch_size'], beta)

                    (e_states, e_actions, e_rewards, e_next_states,
                     e_discounted_rewards, e_nth_states, e_dones,
                     e_ns) = expand_transitions(
                         e_transitions,
                         torchify=True,
                         state_transformer=self.state_transformer)

                    e_loss = ntd_loss(online_model=self.online_network,
                                      target_model=self.target_network,
                                      states=e_states,
                                      actions=e_actions,
                                      next_states=e_next_states,
                                      rewards=e_rewards,
                                      dones=e_dones,
                                      gamma=0.99,
                                      n=1)
                    e_weights = torch.FloatTensor(e_weights).to(self.device)
                    e_loss = e_loss * e_weights
                    e_priorities = e_loss + 1e-5
                    e_priorities = e_priorities.detach().cpu().numpy()
                    self.expert_buffer.update_priorities(
                        e_priorities, e_indices)
                    e_loss = e_loss.mean()

                    if self.config['n_steps'] > 1:
                        e_nstep_loss = ntd_loss(
                            online_model=self.online_network,
                            target_model=self.target_network,
                            states=e_states,
                            actions=e_actions,
                            next_states=e_nth_states,
                            rewards=e_discounted_rewards,
                            dones=e_dones,
                            gamma=0.99,
                            n=e_ns)
                        e_nstep_loss = e_nstep_loss.mean()
                        e_loss += e_nstep_loss

                    q_values = self.online_network(e_states)
                    e_loss += torch.mean(margin_loss(q_values, e_actions))

                # Finally, add this sucker to the loss if we do have expert samples.
                if len(self.buffer) > self.config["batch_size"]:
                    if self.config['batch_size'] > 0 and self.config[
                            'expert_batch_size'] > 0:
                        loss = loss * self.online_coef + e_loss * self.expert_coef
                    elif self.config['batch_size'] > 0:
                        loss = loss
                    elif self.config['expert_batch_size'] > 0:
                        loss = e_loss

                    # Take the step of updating online network parameters
                    # based on this batch loss.
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    # End of training step actions
                    epoch_loss.append(loss.detach().cpu().numpy())

                # End of every step actions
                if self.step % self.config['update_interval'] == 0:
                    self.target_network.load_state_dict(
                        self.online_network.state_dict())
                self.step += 1
                score += current_trans.reward

                if current_trans.done:
                    if score > self.best_episode:
                        self.best_episode = score

                    self.episodic_reward.append(score)
                    score = 0
                    self.state = env.reset()
                    self.prime_buffer(env)

                    wandb.log({"episodic_reward": self.episodic_reward[-1]})

            # End of batch actions
            self.loss.append(np.mean(epoch_loss))
            print("Epoch {0}, Score {1:6.4f}, Loss {2:6.4f}, Time {3:6.4f}".
                  format(epoch, score, self.loss[-1],
                         time.time() - start_time))

            wandb.log({
                "time": time.time() - start_time,
                "loss": self.loss[-1],
                "epsilon": epsilon,
                "beta": beta
            })

        wandb.log({"best_episode": self.best_episode})
Code Example #11
File: main.py  Project: qq345736500/wh
tf.logging.info("input dimension:{}".format(X_embedding.get_shape()))

if args.model_type == 'capsule-A':    
    poses, activations = capsule_model_A(X_embedding, args.num_classes)    
if args.model_type == 'capsule-B':    
    poses, activations = capsule_model_B(X_embedding, args.num_classes)    
if args.model_type == 'CNN':    
    poses, activations = baseline_model_cnn(X_embedding, args.num_classes)
if args.model_type == 'KIMCNN':    
    poses, activations = baseline_model_kimcnn(X_embedding, args.max_sent, args.num_classes)   
    
if args.loss_type == 'spread_loss':
    loss = spread_loss(y, activations, margin)
if args.loss_type == 'margin_loss':    
    loss = margin_loss(y, activations)
if args.loss_type == 'cross_entropy':
    loss = cross_entropy(y, activations)

y_pred = tf.argmax(activations, axis=1, name="y_proba")    
correct = tf.equal(tf.argmax(y, axis=1), y_pred, name="correct")
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name="accuracy")

optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)   
training_op = optimizer.minimize(loss, name="training_op")
gradients, variables = zip(*optimizer.compute_gradients(loss))

grad_check = [tf.check_numerics(g, message='Gradient NaN Found!')
              for g in gradients if g is not None] + [tf.check_numerics(loss, message='Loss NaN Found')]
with tf.control_dependencies(grad_check):
    training_op = optimizer.apply_gradients(zip(gradients, variables), global_step=global_step)      
Code Example #12
File: simple_model.py  Project: zx-/CapsNet
    (x_train, y_train), (x_test,
                         y_test) = mnist.load_data(f'{os.getcwd()}/data/mnist')
    x_train = x_train / 255.0  # note: only x_train is scaled here; x_test is left in [0, 255]

    return (x_train, y_train), (x_test, y_test)


if __name__ == '__main__':

    run_num = 2

    (x_train, y_train), (x_test, y_test) = prepare_data()
    p_x, p_target = create_input_placeholders()

    output = simple_caps_net(p_x)
    loss_fn = loss.margin_loss(output, tf.one_hot(p_target, NUM_CLASSES))

    opt = tf.train.AdamOptimizer(0.03)
    training_operation = slim.learning.create_train_op(
        loss_fn, opt, summarize_gradients=True)

    # summaries
    tf.summary.scalar('margin_loss', loss_fn)
    merged = tf.summary.merge_all()
    print_model_summary()

    session = tf.Session()

    writer = tf.summary.FileWriter(f'{os.getcwd()}/tmp/log/{run_num}',
                                   session.graph)
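
Example #12 stops right after the summary writer is created. A typical TF1-style continuation would initialize variables and repeatedly run training_operation in the session; the sketch below is only a guess at that continuation, and the step count, batch slicing and summary cadence are invented for illustration.

    session.run(tf.global_variables_initializer())
    for step in range(1000):                                  # hypothetical step count
        start = (step * 64) % len(x_train)                    # hypothetical batching
        batch_x, batch_y = x_train[start:start + 64], y_train[start:start + 64]
        _, summary_val = session.run([training_operation, merged],
                                     feed_dict={p_x: batch_x, p_target: batch_y})
        writer.add_summary(summary_val, step)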
Code Example #13
def train(args, model, optimizer, train_dloader, val_dloader, summary):
    # ----------- Start training ----------------------------------
    best_eval_acc = 0
    start_epoch = 0
    iter_calc = 0
    for epoch in range(start_epoch, start_epoch + args['max_epoches']):
        # ------- Train ---------------------------
        model.train()
        titer = tqdm(train_dloader, ncols=60)
        results = defaultdict(list)
        for res in titer:
            iter_calc += 1
            res_cuda = [r.cuda() for r in res[:-1]]
            (b_embs, b_toks, b_msks, b_s_msks, b_c_msks, b_imgs, b_im_msks,
             b_poses, b_pose_msks, b_i3d_rgb, b_face, b_face_msks, b_bbox_meta,
             b_dep_root_msk, b_gt_gender, b_gt_positions, b_gt_mats, b_gt_vmat,
             b_gt_vmat_msk) = res_cuda
            B, sN, L = b_s_msks.shape
            cN, hN = b_imgs.size(1), b_imgs.size(2)
            pred_ground, pred_reid = model(b_embs, b_toks, b_msks, b_s_msks,
                                           b_c_msks, b_dep_root_msk, b_imgs,
                                           b_im_msks, b_poses, b_pose_msks,
                                           b_i3d_rgb, b_face, b_face_msks,
                                           b_bbox_meta)  # B x sN x cN*hN
            # ---------- Get Loss ------------------------------------------
            # ---- Image wise ------------------------
            ground_clip_msk = b_c_msks.view(B, sN, cN,
                                            1).repeat(1, 1, 1,
                                                      hN).view(B, sN, cN * hN)
            ground_msk = ground_clip_msk * b_im_msks.view(
                B, 1, cN * hN)  # B x sN x cN*hN
            s_msks = b_s_msks.sum(-1)  # [B, sN]
            s_num = s_msks.sum(-1)  # [B] number of sN in each batch

            gt_position = b_gt_positions.view(B, sN, 1, hN).repeat(
                1, 1, cN, 1).view(B, sN, cN * hN)
            gt_position = ground_msk * gt_position  # [B x sN x cN*hN], 1 if positive
            pos_msk = 1 - gt_position  # 0 if positive
            pos_scores, _ = (pred_ground - pos_msk * 20).max(-1)  # [B x sN]
            neg_msk = (1 - ground_msk) + gt_position  # 0 if in clip negative
            neg_scores, _ = (pred_ground - neg_msk * 20).max(-1)  # [B x sN]
            if args['loss'] == 'margin':
                loss1 = loss.margin_loss(pos_scores, neg_scores, s_msks)
                # other way
                pos_scores2, _ = (pred_ground - pos_msk * 20).max(
                    1)  # [B x cN*hN]
                neg_msk2 = (1 - pos_msk * s_msks.unsqueeze(-1))
                neg_scores2, _ = (pred_ground - neg_msk2 * 20).max(
                    1)  # [B x cN*hN]
                loss2 = loss.margin_loss(pos_scores2, neg_scores2,
                                         gt_position.sum(1))
                total_loss = loss1 + loss2
                results['ground_loss'].append(total_loss.cpu().item())
            elif args['loss'] == 'ce':
                total_loss = loss.ce_loss(pred_ground, gt_position, ground_msk,
                                          s_msks)
                results['ground_loss'].append(total_loss.cpu().item())

            if args['char_reid']:
                c_loss = loss.reid_loss(pred_reid, b_gt_mats, s_msks)
                total_loss += 0.1 * c_loss
                results['reid_loss'].append(c_loss.cpu().item())

                v_loss = loss.vreid_loss(torch.sigmoid(model.vid_mat),
                                         b_gt_vmat, b_gt_vmat_msk)
                total_loss += 0.1 * v_loss
                results['vreid_loss'].append(v_loss.cpu().item())
            if args['use_gender']:
                gender_logit = model.gender_result
                gender_gt = b_gt_gender[:, :, 0]
                gender_msk = b_gt_gender[:, :, 1]
                g_loss = loss.gender_loss(gender_logit, gender_gt, gender_msk)
                total_loss += g_loss
                results['gender_loss'].append(g_loss.cpu().item())

            total_loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # ---- Calculate acc --------------------------------------------
            _, preds = torch.stack([neg_scores, pos_scores], -1).max(-1)
            preds = preds.cpu()  # [B x sN]
            for bidx, b_pred in enumerate(preds):
                for sidx in range(len(res[-1][bidx])):
                    if b_pred[sidx] == 1:  #pos is larger
                        results['ground_acc'].append(1)
                    else:
                        results['ground_acc'].append(0)
            if args['use_gender']:
                g_preds = (abs(gender_logit - gender_gt)).cpu()
                gender_msk_cpu = gender_msk.cpu()
                for bidx, g_pred in enumerate(g_preds):
                    for sidx, gp in enumerate(g_pred):
                        if gender_msk_cpu[bidx, sidx] == 1:
                            if g_pred[sidx] < 0.5:  #pos is larger
                                results['gender_acc'].append(1)
                            else:
                                results['gender_acc'].append(0)
        #lr_scheduler.step()
        # ---- Print loss, acc -------------------------------------------------
        print_str = ''
        for key in results:
            res_avg = sum(results[key]) / len(results[key])
            summary.add_scalar('train/' + key, res_avg, epoch)
            print_str += key + ': %.3f ' % res_avg
        print(args['save_dir'])
        print(print_str)
        result, eval_scores = evaluate(args, model, val_dloader, epoch)
        # Save result
        if best_eval_acc < eval_scores['grounding']:
            best_eval_acc = eval_scores['grounding']
            torch.save(
                {
                    'epoch': epoch + 1,
                    'iter_calc': iter_calc,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, os.path.join('logdir', args['save_dir'],
                                'best_ckpt.pth.tar'))
            with open(
                    os.path.join('logdir', args['save_dir'],
                                 'mvad_output.pkl'), 'wb') as f:
                pickle.dump(result, f)
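
Example #13 calls loss.margin_loss(pos_scores, neg_scores, s_msks) twice to push the best positive region score above the best negative one per query. The project's loss module is not shown here; the sketch below is merely one masked pairwise hinge that is consistent with how those scores are built, with the margin value and the mean reduction chosen arbitrarily.

import torch

def margin_loss(pos_scores, neg_scores, mask, margin=0.2):
    """Hypothetical masked pairwise hinge (ranking) loss.

    pos_scores, neg_scores: [B, sN] best positive / best negative scores
    mask:                   [B, sN] nonzero where the entry is valid
    """
    valid = (mask > 0).float()
    hinge = torch.clamp(margin - (pos_scores - neg_scores), min=0.0) * valid
    return hinge.sum() / valid.sum().clamp(min=1.0)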