    def train(self, epoch):
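        """Run one training epoch with the semi-hard triplet loss.

        Batches the dataset failed to assemble (X is None) are skipped;
        average accuracy, accuracy ratio, and loss are reported at the end.
        """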
        self.tcn.switch_mode('train')

        epoch_loss = 0.
        cum_accuracy = 0.
        cum_ratio = 0.
        n = 0
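        # NOTE: `poem` is assumed to be a tqdm-style progress wrapper defined
        # elsewhere in this codebase.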
        for idx in poem(range(len(self.train_set)), "train " + str(epoch)):
            X, labels, perspectives, paths = self.train_set[idx]
            if X is None:
                continue
            n += 1
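            # two-device split: inputs go to the network's device, labels and
            # perspectives to the device where the loss is computed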
            X = X.to(self.devices[0])
            labels = labels.to(self.devices[1])
            perspectives = perspectives.to(self.devices[1])

            self.optimizer.zero_grad()
            y = self.tcn(X)

            assert not Tools.contains_nan(y)

            # semi-hard mining: prefer negatives that are farther from the
            # anchor than the positive but still inside the margin
            loss = Loss.triplet_semihard_loss(
                y, labels, perspectives,
                margin=Config.TCN_MARGIN, device=self.devices[1])
            loss.backward()
            self.optimizer.step()

            epoch_loss += loss.item()

            cum_accuracy += Loss.embedding_accuracy(
                y, labels, perspectives, device=self.devices[1]).item()
            cum_ratio += Loss.embedding_accuracy_ratio(
                y, labels, perspectives)

        denom = max(n, 1)  # avoid division by zero if every batch was skipped
        Tools.pyout('Train Epoch: {} [{}/{} ({:.0f}%)]\tAccuracy: {:.6f}\t'
                    'Ratio: {:.6f}\tLoss: {:.6f}'.format(
                        epoch, n, len(self.train_set),
                        100. * n / len(self.train_set),
                        cum_accuracy / denom, cum_ratio / denom,
                        epoch_loss / denom))
        Tools.log('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, n, len(self.train_set), 100. * n / len(self.train_set),
            epoch_loss / denom))

    def test(self, epoch):
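        """Evaluate embedding accuracy on the validation set.

        Saves a checkpoint whenever the accuracy ratio improves (lower is
        better here) and records the epoch of the last improvement.
        """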
        self.tcn.switch_mode('eval')
        cum_accuracy = 0.
        cum_ratio = 0.
        n = 0
        for idx in poem(range(len(self.val_set)), "eval " + str(epoch)):
            with torch.no_grad():
                X, labels, perspectives, _ = self.val_set[idx]

                if X is None:
                    continue
                n += 1

                X = X.to(self.devices[0])
                labels = labels.to(self.devices[1])
                perspectives = perspectives.to(self.devices[1])

                y = self.tcn(X)

                assert not Tools.contains_nan(y)

                cum_accuracy += Loss.embedding_accuracy(
                    y, labels, perspectives, device=self.devices[1]).item()
                cum_ratio += Loss.embedding_accuracy_ratio(
                    y, labels, perspectives)

        denom = max(n, 1)  # avoid division by zero if every batch was skipped
        Tools.pyout('Test Epoch: {} [{}/{} ({:.0f}%)]\tAccuracy: {:.6f}\t'
                    'Ratio: {:.6f}'.format(
                        epoch, n, len(self.val_set),
                        100. * n / len(self.val_set),
                        cum_accuracy / denom, cum_ratio / denom))
        Tools.log('Test Epoch: {} [{}/{} ({:.0f}%)]\tAccuracy: {:.6f}\t'
                  'Ratio: {:.6f}'.format(
                      epoch, n, len(self.val_set),
                      100. * n / len(self.val_set),
                      cum_accuracy / denom, cum_ratio / denom))
        # checkpoint whenever the accuracy ratio improves (lower is better)
        if cum_ratio / denom < self.best_ratio:
            self.best_ratio = cum_ratio / denom
            self.save_state_dict(self.save_loc)
            self.last_improvement = epoch

    def __getitem__(self, idx):
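        """Assemble one batch of (anchor, positive) frame pairs.

        Each pair shows the same trial from two different camera
        perspectives, presumably time-aligned via the anchor_idx passed to
        _sample_frame (the time-contrastive setup). Returns
        (X, labels, perspectives, paths), or (None, None, None, trial_folder)
        when no valid pair could be assembled.
        """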
        # if val set, use same random seed
        if not self.augment:
            random.seed(self.seeds[idx])
        trial_folder = self.trial_names[idx]
        if 'fake' in trial_folder:
            return (None, None, None, trial_folder)

        X = np.zeros((self.batch_size,) + self.input_size)
        labels = np.zeros((self.batch_size,))
        perspectives = np.zeros((self.batch_size,))
        paths = []
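        # sentinel: _sample_frame presumably appends chosen frame indices here
        # and uses the list to enforce temporal separation between samples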
        frames_used = [-float("inf")]
        n = 0
        fails = 0
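        # batch layout: even indices hold anchors, odd indices their
        # positives; labels[2 * i] == labels[2 * i + 1] == i pairs them up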
        while n < self.batch_size // 2:
            # sample two perspectives
            samples_pos = random.sample(Tools.list_dirs(trial_folder), 2)

            # sample anchor frame
            a_val, a_pth, a_idx = self._sample_frame(
                samples_pos[0], frames_used)
            # sample positive frame
            p_val, p_pth, p_idx = self._sample_frame(
                samples_pos[1], frames_used, anchor_idx=a_idx)

            # deal with failing to find a valid pair
            if not a_val or not p_val:
                fails += 1
                if fails > self.batch_size:  # give up
                    break
            else:
                # add anchor frame to batch
                paths.append(a_pth)
                img_a = Transformer.transform(cv2.imread(a_pth), BGR=False)
                X[n * 2] = self.transform(Image.fromarray(img_a)).numpy()
                labels[n * 2] = n
                perspectives[n * 2] = self.pos2num[a_pth.split('/')[-2]]

                # add positive frame to batch
                paths.append(p_pth)
                img_p = Transformer.transform(cv2.imread(p_pth), BGR=False)
                X[n * 2 + 1] = self.transform(Image.fromarray(img_p)).numpy()
                labels[n * 2 + 1] = n
                perspectives[n * 2 + 1] = self.pos2num[p_pth.split('/')[-2]]

                n += 1

        # if batch is not entirely full, cut off zero padding
        X = X[:n * 2]
        labels = labels[:n * 2]
        perspectives = perspectives[:n * 2]
        if X.shape[0] == 0:
            return (None, None, None, trial_folder)
        else:
            X = torch.FloatTensor(X)
            labels = torch.FloatTensor(labels)
            perspectives = torch.FloatTensor(perspectives)

            assert not Tools.contains_nan(X)
            assert not Tools.contains_nan(labels)
            assert not Tools.contains_nan(perspectives)

            return (X, labels, perspectives, paths)
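

# ---------------------------------------------------------------------------
# Illustrative sketch only: Loss.triplet_semihard_loss is defined elsewhere in
# this project and is not shown in this file. The helper below is a minimal,
# simplified take on semi-hard triplet mining for the batch layout built in
# __getitem__ above (even index = anchor, odd index = its positive). It is
# NOT the project's implementation; in particular it ignores `perspectives`.
def _triplet_semihard_sketch(y, labels, margin=0.2):
    import torch  # torch is already imported by this module; local for clarity

    d = torch.cdist(y, y)                            # pairwise L2 distances
    eye = torch.eye(len(y), dtype=torch.bool, device=y.device)
    pos = (labels[:, None] == labels[None, :]) & ~eye
    neg = labels[:, None] != labels[None, :]

    # each row has exactly one positive under this batch layout
    d_ap = (d * pos).sum(dim=1)

    # semi-hard: farther than the positive, but still inside the margin
    semihard = neg & (d > d_ap[:, None]) & (d < (d_ap + margin)[:, None])
    inf = torch.full_like(d, float('inf'))
    d_an = torch.where(semihard, d, inf).min(dim=1).values

    # anchors with no semi-hard negative fall back to the closest negative
    fallback = torch.where(neg, d, inf).min(dim=1).values
    d_an = torch.where(torch.isinf(d_an), fallback, d_an)

    return torch.relu(d_ap - d_an + margin).mean()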