Example #1
    def get_masked_lm_loss(
        logit_blob,
        masked_lm_positions,
        masked_lm_labels,
        label_weights,
        max_prediction_per_seq,
    ):
        # gather valid position indices
        logit_blob = flow.gather(
            logit_blob,
            index=masked_lm_positions.unsqueeze(2).repeat(
                1, 1, args.vocab_size),
            dim=1,
        )

        logit_blob = flow.reshape(logit_blob, [-1, args.vocab_size])
        label_id_blob = flow.reshape(masked_lm_labels, [-1])

        # The `positions` tensor might be zero-padded (if the sequence is too
        # short to have the maximum number of predictions). The `label_weights`
        # tensor has a value of 1.0 for every real prediction and 0.0 for the
        # padding predictions.
        per_example_loss = mlm_criterion(logit_blob, label_id_blob)
        per_example_loss = flow.reshape(per_example_loss,
                                        [-1, max_prediction_per_seq])
        numerator = flow.sum(per_example_loss * label_weights)
        denominator = flow.sum(label_weights) + 1e-5
        loss = numerator / denominator
        return loss
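The gather step above is easiest to see on toy shapes. A minimal sketch (batch, seq_len, and vocab are illustrative names, not from the original): repeating the position index along the vocabulary axis makes flow.gather select one full logit row per masked position.

import oneflow as flow

batch, seq_len, vocab = 2, 8, 11
logits = flow.randn(batch, seq_len, vocab)
# two masked positions per sequence
positions = flow.tensor([[1, 4], [0, 7]], dtype=flow.int64)
# expand to [batch, num_preds, vocab] so every vocab slot reads the same position
index = positions.unsqueeze(2).repeat(1, 1, vocab)
picked = flow.gather(logits, index=index, dim=1)
print(picked.shape)  # (2, 2, 11): one logit row per masked position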
Example #2
def get_masked_lm_loss(
    logit_blob,
    masked_lm_positions,
    masked_lm_labels,
    label_weights,
    max_prediction_per_seq=20,
):
    # gather valid position indices
    logit_blob = flow.gather(
        logit_blob,
        index=masked_lm_positions.unsqueeze(2).repeat(1, 1, 30522),
        dim=1,
    )
    logit_blob = flow.reshape(logit_blob, [-1, 30522])
    label_id_blob = flow.reshape(masked_lm_labels, [-1])

    # The `positions` tensor might be zero-padded (if the sequence is too
    # short to have the maximum number of predictions). The `label_weights`
    # tensor has a value of 1.0 for every real prediction and 0.0 for the
    # padding predictions.
    per_example_loss = nn.CrossEntropyLoss(reduction="none")(logit_blob,
                                                             label_id_blob)
    per_example_loss = flow.reshape(per_example_loss,
                                    [-1, max_prediction_per_seq])
    numerator = flow.sum(per_example_loss * label_weights)
    denominator = flow.sum(label_weights) + 1e-5
    loss = numerator / denominator
    return logit_blob, loss
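A hedged usage sketch for this variant (all shapes made up; 30522 is the BERT-base vocabulary size the example hardcodes, and oneflow is assumed to be imported as flow with nn taken from it):

import oneflow as flow
from oneflow import nn

logit_blob = flow.randn(2, 128, 30522)     # [batch, seq_len, vocab]
positions = flow.randint(0, 128, (2, 20))  # 20 masked positions per sequence
labels = flow.randint(0, 30522, (2, 20))
weights = flow.ones(2, 20)                 # 1.0 for real predictions, 0.0 for padding
gathered_logits, loss = get_masked_lm_loss(logit_blob, positions, labels, weights)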
Example #3
def _test_sum_impl(test_case, device):
    input = flow.tensor(
        np.random.randn(2, 3), dtype=flow.float32, device=flow.device(device)
    )
    of_out = flow.sum(input, dim=0)
    np_out = np.sum(input.numpy(), axis=0)
    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))
    input = flow.tensor(
        np.random.randn(2, 3), dtype=flow.float32, device=flow.device(device)
    )
    of_out = flow.sum(input, dim=1)
    of_out2 = input.sum(dim=1)
    np_out = np.sum(input.numpy(), axis=1)
    test_case.assertTrue(np.allclose(of_out2.numpy(), of_out.numpy(), 1e-05, 1e-05))
    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))
    input = flow.tensor(
        np.random.randn(4, 5, 6),
        dtype=flow.float32,
        device=flow.device(device),
        requires_grad=True,
    )
    of_out = flow.sum(input, dim=(2, 1))
    np_out = np.sum(input.numpy(), axis=(2, 1))
    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))
    of_out = of_out.sum()
    of_out.backward()
    np_grad = np.ones((4, 5, 6))
    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))
Example #4
    def forward(self, logits, target, mask=None):
        """LabelSmoothing Function with Mask

        Args:
            logits ([tensor]): logits with shape [batch, length, vocab_size]
            target ([tensor]): target with shape [batch, length]
            mask ([tensor], optional): mask tensor (bool) with shape [batch, length]
        """
        assert logits.dim() == 3 and logits.size(-1) == self.size

        pad_mask = target == self.padding_idx
        if mask is not None:
            mask = (pad_mask.int() + mask.int()) > 0
        else:
            mask = pad_mask

        logits = logits.reshape(-1, self.size)
        with flow.no_grad():
            confidence = logits.clone()
            confidence.fill_(self.smoothing / (self.size - 1))
            confidence = flow.scatter(confidence, 1,
                                      target.reshape(-1).unsqueeze(1),
                                      1 - self.smoothing)

        logsoftmax = nn.LogSoftmax(dim=-1)
        KLdiv = nn.KLDivLoss(reduction="none", log_target=False)
        loss = flow.sum(KLdiv(logsoftmax(logits), confidence), dim=-1)

        total = flow.sum(mask == 0)
        denom = total if self.normalize_length else logits.size(0)
        loss = flow.masked_fill(loss, mask.reshape(-1), 0.0)
        loss = flow.sum(loss) / denom

        return loss
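A hedged numeric sketch of the smoothed target built in the no_grad block above (vocabulary size 4 and smoothing 0.1 are made-up values): every class receives smoothing / (size - 1) and the true class is overwritten with 1 - smoothing.

import oneflow as flow

size, smoothing = 4, 0.1
target = flow.tensor([[2], [0]])
confidence = flow.full((2, size), smoothing / (size - 1))
confidence = flow.scatter(confidence, 1, target, 1 - smoothing)
# confidence ~ [[0.0333, 0.0333, 0.9000, 0.0333],
#               [0.9000, 0.0333, 0.0333, 0.0333]]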
Example #5
    def get_masked_lm_loss(
        logit, masked_lm_labels, label_weights, max_predictions_per_seq,
    ):

        label_id = flow.reshape(masked_lm_labels, [-1])

        # The `positions` tensor might be zero-padded (if the sequence is too
        # short to have the maximum number of predictions). The `label_weights`
        # tensor has a value of 1.0 for every real prediction and 0.0 for the
        # padding predictions.
        per_example_loss = mlm_criterion(logit, label_id)
        per_example_loss = flow.reshape(per_example_loss, [-1, max_predictions_per_seq])
        numerator = flow.sum(per_example_loss * label_weights)
        denominator = flow.sum(label_weights) + 1e-5
        loss = numerator / denominator
        return loss
Example #6
def compare_loss(device_type, dim, reduction, cls, data_generator):
    x, y, x1, y1 = data_generator(dim, device_type, *get_sbp(device_type))
    reduce_loss_func = cls(reduction=reduction).to(device_type)
    none_loss_func = cls(reduction="none").to(device_type)

    loss_mean = reduce_loss_func(x, y)
    loss_none = (flow.mean(none_loss_func(x1, y1))
                 if reduction == "mean" else flow.sum(none_loss_func(x1, y1)))

    loss_mean.backward()
    loss_none.backward()

    assert np.allclose(
        loss_none.to_local().numpy(),
        loss_mean.to_local().numpy(),
        rtol=1e-05,
        atol=1e-05,
    )
    assert np.allclose(
        loss_none.numpy(),
        loss_mean.numpy(),
        rtol=1e-05,
        atol=1e-05,
    )
    assert np.allclose(
        x.grad.to_local().numpy(),
        x1.grad.to_local().numpy(),
        rtol=1e-05,
        atol=1e-05,
    )
Example #7
 def train_one_iter(grad):
     grad_tensor = flow.tensor(
         grad, requires_grad=False, device=flow.device(device)
     )
     loss = flow.sum(x * grad_tensor)
     loss.backward()
     adagrad.step()
     adagrad.zero_grad()
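The closure above relies on x, device, and adagrad from an enclosing test scope; a hedged, self-contained version of the same pattern (all values illustrative):

import numpy as np
import oneflow as flow

device = "cpu"
x = flow.ones(4, requires_grad=True, device=flow.device(device))
adagrad = flow.optim.Adagrad([x], lr=0.1)

def train_one_iter(grad):
    grad_tensor = flow.tensor(grad, requires_grad=False, device=flow.device(device))
    # d(sum(x * g))/dx = g, so backward() deposits exactly `grad` into x.grad
    loss = flow.sum(x * grad_tensor)
    loss.backward()
    adagrad.step()
    adagrad.zero_grad()

train_one_iter(np.ones(4, dtype=np.float32))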
Example #8
def _test_sum_impl(test_case, device, data_type):
    if device == "cpu" and data_type == flow.float16:
        return
    input = flow.tensor(np.random.randn(2, 3) - 0.5,
                        dtype=data_type,
                        device=flow.device(device))
    of_out = flow.sum(input, dim=0)
    np_out = np.sum(input.numpy(), axis=0)
    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))
    input = flow.tensor(np.random.randn(2, 3),
                        dtype=data_type,
                        device=flow.device(device))
    of_out = flow.sum(input, dim=0)
    np_out = np.sum(input.numpy(), axis=0)
    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))
    input = flow.tensor(np.random.randn(2, 3),
                        dtype=data_type,
                        device=flow.device(device))
    of_out = flow.sum(input, dim=1)
    of_out2 = input.sum(dim=1)
    np_out = np.sum(input.numpy(), axis=1)
    test_case.assertTrue(
        np.allclose(of_out2.numpy(), of_out.numpy(), 1e-05, 1e-05))
    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))
    input = flow.tensor(
        np.random.randn(4, 5, 6) - 0.5,
        dtype=data_type,
        device=flow.device(device),
        requires_grad=True,
    )
    of_out = flow.sum(input, dim=(2, 1))
    np_out = np.sum(input.numpy(), axis=(2, 1))
    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))
    of_out = of_out.sum()
    of_out.backward()
    np_grad = np.ones((4, 5, 6))
    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05,
                                     1e-05))

    # For 0-dim tensor test
    input = flow.tensor(1.0)
    of_out = input.sum()
    test_case.assertTrue(
        np.allclose(input.numpy(), of_out.numpy(), 1e-05, 1e-05))
Example #9
 def forward(self, input, target):
     prob, out = self._op(input, target, depth=input.shape[-1])
     if self.reduction == "mean":
         return flow.mean(out)
     elif self.reduction == "sum":
         return flow.sum(out)
     else:
         return out
Example #10
 def train_one_iter(grad):
     grad_tensor = flow.tensor(
         grad,
         dtype=flow.float32,
         requires_grad=False,
         device=flow.device(device),
     )
     loss = flow.sum(x * grad_tensor)
     loss.backward()
     rmsprop.step()
     rmsprop.zero_grad()
Example #11
 def forward(self, input_tensor):
     reduce_sum = flow.sum(input_tensor,
                           dim=self.axis,
                           keepdims=self.keepdims)
     reduce_count = 1
     if len(self.axes) == 0:
         for dim in input_tensor.shape:
             reduce_count *= dim
     else:
         for i in self.axes:
             reduce_count *= input_tensor.shape[i]
     return flow.mul(reduce_sum, 1.0 / reduce_count)
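The module expresses a mean reduction as sum times 1/count. A hedged sketch with made-up axes; note this snippet spells the keyword keepdims while other examples here (e.g. #13) use keepdim, which is what the sketch follows:

import oneflow as flow

x = flow.randn(2, 3, 4)
axes = [1, 2]
reduce_count = 1
for i in axes:
    reduce_count *= x.shape[i]  # 3 * 4 = 12 elements in each reduced slice
manual_mean = flow.mul(flow.sum(x, dim=axes, keepdim=True), 1.0 / reduce_count)
# should match flow.mean(x, dim=axes, keepdim=True) up to float error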
Example #12
    def inference(self, memory, memory_mask):

        if self.apply_look_ahead:
            memory = F.pad(memory, pad=(0, 0, 0, self.lookahead_steps), value=0.0)
            memory = memory.transpose(1, 2)
            memory = self.lookahead_conv(memory)
            memory = memory.transpose(1, 2)

        logits = self.output_layer(memory)
        memory_length = flow.sum(memory_mask.squeeze(1), dim=-1)
        logsoftmax = nn.LogSoftmax(dim=-1)
        return logsoftmax(logits), memory_length
Example #13
 def forward(self, dense_fields, wide_sparse_fields,
             deep_sparse_fields) -> flow.Tensor:
     wide_embedding = self.wide_embedding(wide_sparse_fields)
     wide_embedding = wide_embedding.view(
         -1, wide_embedding.shape[-1] * wide_embedding.shape[-2])
     wide_scores = flow.sum(wide_embedding, dim=1, keepdim=True)
     deep_embedding = self.deep_embedding(deep_sparse_fields)
     deep_embedding = deep_embedding.view(
         -1, deep_embedding.shape[-1] * deep_embedding.shape[-2])
     deep_features = flow.cat([deep_embedding, dense_fields], dim=1)
     deep_features = self.linear_layers(deep_features)
     deep_scores = self.deep_scores(deep_features)
     return self.sigmoid(wide_scores + deep_scores)
Example #14
        def train_one_iter(grad):
            grad_tensor = flow.tensor(
                grad,
                dtype=flow.float32,
                requires_grad=False,
                device=flow.device(device),
            )

            loss = flow.sum(x * grad_tensor)
            loss.backward()
            if clip_grad_max_norm != -1:
                lamb.clip_grad()
            lamb.step()
            lamb.zero_grad()
Example #15
    def forward(self, inputs) -> flow.Tensor:
        multi_embedded_x = self.embedding_layer(inputs)
        embedded_x = multi_embedded_x[:, :, 0:self.embedding_vec_size]
        lr_embedded_x = multi_embedded_x[:, :, -1]

        # FM
        lr_out = flow.sum(lr_embedded_x, dim=1, keepdim=True)
        dot_sum = interaction(embedded_x)
        fm_pred = lr_out + dot_sum

        # DNN
        dnn_pred = self.dnn_layer(embedded_x.flatten(start_dim=1))

        return fm_pred + dnn_pred
Example #16
    def test_sum(test_case):
        input = flow.Tensor(np.random.randn(2, 3), dtype=flow.float32)
        of_out = flow.sum(input, dim=0)
        np_out = np.sum(input.numpy(), axis=0)
        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))

        input = flow.Tensor(np.random.randn(2, 3), dtype=flow.float32)
        of_out = flow.sum(input, dim=0)
        np_out = np.sum(input.numpy(), axis=0)
        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))

        input = flow.Tensor(np.random.randn(2, 3), dtype=flow.float32)
        of_out = flow.sum(input, dim=1)
        of_out2 = input.sum(dim=1)
        np_out = np.sum(input.numpy(), axis=1)
        test_case.assertTrue(
            np.allclose(of_out2.numpy(), of_out.numpy(), 1e-4, 1e-4))
        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))

        input = flow.Tensor(np.random.randn(4, 5, 6), dtype=flow.float32)
        of_out = flow.sum(input, dim=(2, 1))
        np_out = np.sum(input.numpy(), axis=(2, 1))
        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
Example #17
    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = flow.ones(y.size()).to(self.device)

        dydx = flow.autograd.grad(outputs=y,
                                  inputs=x,
                                  out_grads=weight,
                                  retain_graph=True,
                                  create_graph=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = flow.sqrt(flow.sum(dydx**2, dim=1))

        return flow.mean((dydx_l2norm - 1)**2)
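A hedged, standalone check of the same computation, with the device handling dropped and a linear function of x so the gradient norm is known in closed form:

import oneflow as flow

x = flow.randn(4, 3, requires_grad=True)
y = (2 * x).sum(dim=1)  # dy/dx = 2 everywhere
weight = flow.ones(y.size())
dydx = flow.autograd.grad(outputs=y, inputs=x, out_grads=weight,
                          retain_graph=True, create_graph=True)[0]
dydx = dydx.view(dydx.size(0), -1)
dydx_l2norm = flow.sqrt(flow.sum(dydx ** 2, dim=1))  # 2 * sqrt(3) per row
penalty = flow.mean((dydx_l2norm - 1) ** 2)          # (2*sqrt(3) - 1)^2 ~ 6.07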
Example #18
    def ts_forward(
        self, inputs, inputs_length, targets, targets_length, return_loss=True
    ):
        memory, memory_mask = self.encoder(inputs, inputs_length)
        logits = self.assistor.output_layer(memory)

        if return_loss:
            memory_length = flow.sum(memory_mask.squeeze(1), dim=-1)
            targets_out = targets[:, 1:].clone()
            loss = self.assistor.compute_loss(
                logits, memory_length, targets_out, targets_length
            )
            return loss, logits, memory_mask

        return logits, memory_mask
Example #19
    def forward(self, inputs, targets):

        enc_inputs = inputs["inputs"]
        enc_mask = inputs["mask"]

        truth = targets["targets"]
        truth_length = targets["targets_length"]

        enc_inputs, enc_mask = self.frontend(enc_inputs, enc_mask)
        memory, memory_mask, _ = self.encoder(enc_inputs, enc_mask)

        memory_length = flow.sum(memory_mask, dim=-1)
        loss = self.assistor(
            memory, memory_length, truth[:, 1:-1], truth_length.add(-1)
        )
        return loss, None
Example #20
 def forward(self, input, target):
     input_shape_len = len(input.shape)
     if input_shape_len == 4:
         b, c, h, w = input.shape
         input = flow.tmp.transpose(input, (0, 2, 3, 1))
         input = flow.tmp.reshape(input, shape=[-1, input.shape[3]])
         target = flow.tmp.flatten(target)
     prob, out = self._op(input, target, depth=input.shape[-1])
     if self.reduction == "mean":
         return flow.mean(out)
     elif self.reduction == "sum":
         return flow.sum(out)
     else:
         if input_shape_len == 4:
             out = flow.tmp.reshape(out, (b, h, w))
         return out
Example #21
    def compute_loss(self, est, egs):
        # spks x n x S
        ests = est
        # spks x n x S
        refs = egs["ref"]
        num_spks = len(refs)

        def sisnr_loss(permute):
            # for one permute
            return sum(
                [self.sisnr(ests[s], refs[t])
                 for s, t in enumerate(permute)]) / len(permute)

        # P x N
        N = egs["mix"].size(0)
        sisnr_mat = flow.stack(
            [sisnr_loss(p) for p in permutations(range(num_spks))])
        max_perutt, _ = flow.max(sisnr_mat, dim=0)
        # si-snr
        return -flow.sum(max_perutt) / N
Example #22
    def forward(self, input, target, weight=None):
        assert (input.shape == target.shape
                ), "The Input shape must be the same as Target shape"

        _cross_entropy_loss = flow.negative(target * flow.log(input) +
                                            (1 - target) * flow.log(1 - input))

        if weight is not None:
            assert (weight.shape == input.shape
                    ), "The weight shape must be the same as Input shape"
            _weighted_loss = weight * _cross_entropy_loss
        else:
            _weighted_loss = _cross_entropy_loss

        if self.reduction == "mean":
            return flow.mean(_weighted_loss)
        elif self.reduction == "sum":
            return flow.sum(_weighted_loss)
        else:
            return _weighted_loss
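A hedged sanity check for the hand-written term above: element-wise it should agree with oneflow's built-in BCELoss (probabilities made up):

import oneflow as flow

p = flow.tensor([0.9, 0.2, 0.7])  # predicted probabilities
y = flow.tensor([1.0, 0.0, 1.0])  # binary targets
manual = flow.negative(y * flow.log(p) + (1 - y) * flow.log(1 - p))
builtin = flow.nn.BCELoss(reduction="none")(p, y)
# manual and builtin agree element-wise up to float error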
Example #23
 def forward(self, dense_fields, wide_sparse_fields,
             deep_sparse_fields) -> flow.Tensor:
     wide_sparse_fields = wide_sparse_fields.to_global(
         sbp=flow.sbp.broadcast)
     wide_embedding = self.wide_embedding(wide_sparse_fields)
     wide_embedding = wide_embedding.view(
         -1, wide_embedding.shape[-1] * wide_embedding.shape[-2])
     wide_scores = flow.sum(wide_embedding, dim=1, keepdim=True)
     wide_scores = wide_scores.to_global(sbp=flow.sbp.split(0),
                                         grad_sbp=flow.sbp.broadcast)
     deep_sparse_fields = deep_sparse_fields.to_global(
         sbp=flow.sbp.broadcast)
     deep_embedding = self.deep_embedding(deep_sparse_fields)
     deep_embedding = deep_embedding.to_global(sbp=flow.sbp.split(0),
                                               grad_sbp=flow.sbp.split(2))
     deep_embedding = deep_embedding.view(
         -1, deep_embedding.shape[-1] * deep_embedding.shape[-2])
     deep_features = flow.cat([deep_embedding, dense_fields], dim=1)
     deep_features = self.linear_layers(deep_features)
     deep_scores = self.deep_scores(deep_features)
     return self.sigmoid(wide_scores + deep_scores)
Example #24
    def lm_rescoring(self, preds, pred_lens):
        # preds [beam_size, lens]
        # preds_len [beam_size]

        if self.lm.model_type == "transformer_lm":
            log_probs = self.lm.predict(preds, last_frame=False)
        else:
            log_probs = []
            hidden = None
            for t in range(preds.size(1)):
                log_prob, hidden = self.lm.predict(preds[:, t].unsqueeze(-1),
                                                   hidden)
                log_probs.append(log_prob)

            log_probs = flow.cat(log_probs, dim=1)

        rescores = []
        max_length = log_probs.size(1)
        vocab_size = log_probs.size(-1)

        for b in range(preds.size(0)):
            base_index = flow.arange(max_length, device=preds.device)
            bias_index = preds[b].reshape(-1)

            index = base_index * vocab_size + bias_index
            score = flow.index_select(log_probs[b].reshape(-1),
                                      dim=-1,
                                      index=index)

            label_len = min(int(pred_lens[b]), score.size(0))
            score[label_len - 1:] = 0
            rescores.append(flow.sum(score) / label_len)

        rescores = flow.tensor(rescores, dtype=flow.float32)
        _, indices = flow.sort(rescores, dim=-1, descending=True)

        sorted_preds = preds[indices]
        sorted_length = pred_lens[indices]

        return sorted_preds, sorted_length
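The base_index * vocab_size + bias_index arithmetic above flattens a per-step 2-D lookup into a single 1-D index_select. A hedged toy check with a made-up five-token vocabulary:

import oneflow as flow

vocab_size, max_length = 5, 2
log_probs_b = flow.arange(max_length * vocab_size, dtype=flow.float32)
log_probs_b = log_probs_b.reshape(max_length, vocab_size)
preds_b = flow.tensor([3, 1])  # token chosen at each step
index = flow.arange(max_length) * vocab_size + preds_b
score = flow.index_select(log_probs_b.reshape(-1), dim=-1, index=index)
# score == [log_probs_b[0, 3], log_probs_b[1, 1]] == [3., 6.]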
Example #25
    def forward(self, input, target):
        assert len(input.shape) == 2 or len(input.shape) == 4
        input = flow.negative(input)
        if len(input.shape) == 2:
            res = self.nllloss_1d(input, target)
        elif len(input.shape) == 4:
            b, c, h, w = input.shape
            input = flow.tmp.transpose(input, (0, 2, 3, 1))
            input = flow.tmp.reshape(input, shape=[-1, input.shape[3]])
            target = flow.tmp.flatten(target)
            res = self.nllloss_1d(input, target)
            res = flow.tmp.reshape(res, (b, h, w))

        else:
            raise NotImplementedError

        if self.reduction == "none":
            return res
        elif self.reduction == "sum":
            return flow.sum(res)
        else:
            return flow.mean(res)
Example #26
    def sisnr(self, x, s, eps=1e-8):
        """
        Arguments:
        x: separated signal, N x S tensor
        s: reference signal, N x S tensor
        Return:
        sisnr: N tensor
        """
        def l2norm(mat, keepdim=False):
            return flow.linalg.norm(mat, dim=-1, keepdim=keepdim)

        if x.shape != s.shape:
            raise RuntimeError(
                "Dimension mismatch when calculating si-snr, {} vs {}".format(
                    x.shape, s.shape))
        x_zm = x - flow.mean(x, dim=-1, keepdim=True)
        s_zm = s - flow.mean(s, dim=-1, keepdim=True)
        t = (flow.sum(x_zm * s_zm, dim=-1, keepdim=True) * s_zm /
             (l2norm(s_zm, keepdim=True)**2 + eps))

        # flow.log is the natural log; dividing by ln(10) ~ 2.3025851 gives 20*log10(...)
        res = 20 * flow.log(eps + l2norm(t) /
                            (l2norm(x_zm - t) + eps)) / 2.3025851

        return res
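A hedged spot check of the projection step above: feeding a rescaled copy of the reference makes the distortion term vanish, which is exactly the scale invariance the metric is named for.

import oneflow as flow

eps = 1e-8
s = flow.randn(1, 100)
x = 0.5 * s  # same signal, different gain
x_zm = x - flow.mean(x, dim=-1, keepdim=True)
s_zm = s - flow.mean(s, dim=-1, keepdim=True)
t = (flow.sum(x_zm * s_zm, dim=-1, keepdim=True) * s_zm /
     (flow.linalg.norm(s_zm, dim=-1, keepdim=True) ** 2 + eps))
# here x_zm equals t, so l2norm(x_zm - t) is ~0 and the SI-SNR is very large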
Example #27
def _sum(self, dim=[], keepdim=False):
    return flow.sum(self, dim, keepdim)
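This is presumably the method form registered as Tensor.sum, so the two spellings below are interchangeable (values illustrative; the empty default dim list appears to mean reduce over every axis):

import oneflow as flow

t = flow.tensor([[1.0, 2.0], [3.0, 4.0]])
print(flow.sum(t, dim=1))          # tensor([3., 7.])
print(t.sum(dim=1, keepdim=True))  # tensor([[3.], [7.]])
print(t.sum())                     # tensor(10.)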
Example #28
 def build(self, mask_tensor):
     loss = flow.sum(self.m(mask_tensor))
     loss.backward()
     return loss
Example #29
                if count_fr > 0:
                    inp = flow.Tensor(sig_arr[0:count_fr]).to("cuda")

                    time_begin = time.time()
                    pout[count_fr_tot - count_fr:count_fr_tot, :] = MOBILENET_net(inp)
                    time_end = time.time()
                    time_batch.append([inp.shape[0], time_end - time_begin])

                pred = flow.argmax(pout, dim=1)
                loss = cost(pout, lab.long())

                err = np.mean(pred.numpy() != lab.long().numpy())

                best_class = flow.argmax(flow.sum(pout, dim=0), dim=0)
                err_sum_snt = err_sum_snt + (best_class.numpy() !=
                                             lab[0].numpy())

                loss_sum = loss_sum + loss.detach()
                err_sum = err_sum + err

            err_tot_dev_snt = err_sum_snt / snt_te
            loss_tot_dev = loss_sum / snt_te
            err_tot_dev = err_sum / snt_te

        final = time.time()
        print(
            "epoch %i, loss_tr=%f err_tr=%f loss_te=%f err_te=%f err_te_snt=%f time=%f"
            % (
                epoch,
Example #30
def train(opt):

    # Step 1: init BrainDQN
    model = DeepQNetwork()
    model.to("cuda")
    optimizer = flow.optim.Adam(model.parameters(), lr=opt.lr)
    criterion = flow.nn.MSELoss()
    criterion.to("cuda")

    # Step 2: init Flappy Bird Game
    game_state = GameState()
    # Step 3: play game
    # image.shape = (288,512,3), reward: float, terminal: boolean
    image, reward, terminal = game_state.frame_step(0)
    # image.shape = (84, 84)
    image = pre_processing(
        image[: game_state.SCREENWIDTH, : int(game_state.BASEY)],
        opt.image_size,
        opt.image_size,
    )
    image = flow.Tensor(image, dtype=flow.float32)
    image = image.to("cuda")
    state = flow.cat(tuple(image for _ in range(4))).unsqueeze(0)

    replay_memory = []
    iter = 0
    # Step 4: run the game
    while iter < opt.num_iters:
        model.train()

        prediction = model(state)[0]
        # Exploration or exploitation
        epsilon = opt.final_epsilon + (
            (opt.num_iters - iter)
            * (opt.initial_epsilon - opt.final_epsilon)
            / opt.num_iters
        )
        u = random()
        random_action = u <= epsilon
        if random_action:
            print("Perform a random action")
            action = randint(0, 1)
        else:
            action = flow.argmax(prediction).numpy()[0]

        next_image, reward, terminal = game_state.frame_step(action)
        next_image = pre_processing(
            next_image[: game_state.SCREENWIDTH, : int(game_state.BASEY)],
            opt.image_size,
            opt.image_size,
        )
        next_image = flow.Tensor(next_image)
        next_image = next_image.to("cuda")
        next_state = flow.cat((state[0, 1:, :, :], next_image)).unsqueeze(0)

        replay_memory.append([state, action, reward, next_state, terminal])
        if len(replay_memory) > opt.replay_memory_size:
            del replay_memory[0]
        batch = sample(replay_memory, min(len(replay_memory), opt.batch_size))
        state_batch, action_batch, reward_batch, next_state_batch, terminal_batch = zip(
            *batch
        )

        state_batch = flow.cat(tuple(state for state in state_batch))
        action_batch = flow.Tensor(
            np.array(
                [[1, 0] if action == 0 else [0, 1] for action in action_batch],
                dtype=np.float32,
            )
        )
        reward_batch = flow.Tensor(np.array(reward_batch, dtype=np.float32)[:, None])
        next_state_batch = flow.cat(tuple(state for state in next_state_batch))

        state_batch = state_batch.to("cuda")
        action_batch = action_batch.to("cuda")
        reward_batch = reward_batch.to("cuda")
        next_state_batch = next_state_batch.to("cuda")
        current_prediction_batch = model(state_batch)
        next_prediction_batch = model(next_state_batch)

        y_batch = flow.cat(
            tuple(
                reward_batch[i]
                if terminal_batch[i]
                else reward_batch[i] + opt.gamma * flow.max(next_prediction_batch[i])
                for i in range(reward_batch.shape[0])
            )
        )

        q_value = flow.sum(current_prediction_batch * action_batch, dim=1)

        loss = criterion(q_value, y_batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        state = next_state
        iter += 1

        print(
            "Iteration: {}/{}, Action: {}, Loss: {}, Epsilon {}, Reward: {}, Q-value: {}".format(
                iter,
                opt.num_iters,
                action,
                loss.numpy(),
                epsilon,
                reward,
                flow.max(prediction).numpy()[0],
            )
        )

        if (iter + 1) % 100000 == 0:
            flow.save(
                model.state_dict(),
                os.path.join(opt.save_checkpoint_path, "epoch_%d" % (iter + 1)),
            )
    flow.save(
        model.state_dict(),
        os.path.join(opt.save_checkpoint_path, "epoch_%d" % (iter + 1)),
    )
    print("train success!")