Code example #1
0
def _make_enqueued_generator(generator,
                             workers=1,
                             use_multiprocessing=False,
                             max_queue_size=10,
                             shuffle=False):
    """Wrap *generator* behind a background enqueuer when workers are requested.

    Returns a ``(output_generator, enqueuer)`` pair.  ``enqueuer`` is ``None``
    when ``workers <= 0`` (no background prefetching was started); otherwise it
    is the started enqueuer, which the caller is responsible for stopping.
    """
    sequence_input = isinstance(generator, data_utils.Sequence)

    if workers <= 0:
        # No background workers: hand back the source directly.  A Sequence is
        # wrapped so that iterating over it never exhausts.
        if sequence_input:
            return data_utils.iter_sequence_infinite(generator), None
        return generator, None

    # Sequences get the order-preserving enqueuer (with optional shuffling);
    # plain generators get the simpler GeneratorEnqueuer.
    if sequence_input:
        enqueuer = data_utils.OrderedEnqueuer(
            generator,
            use_multiprocessing=use_multiprocessing,
            shuffle=shuffle)
    else:
        enqueuer = data_utils.GeneratorEnqueuer(
            generator, use_multiprocessing=use_multiprocessing)
    enqueuer.start(workers=workers, max_queue_size=max_queue_size)
    return enqueuer.get(), enqueuer
Code example #2
0
File: networks.py  Project: DableUTeeF/whale_oneshot
    def train_generator(self, generator, num_worker):
        """Run one training epoch over *generator* with background prefetching.

        Args:
            generator: Keras-style data source exposing ``batch_size`` and
                ``__len__``; each fetched batch unpacks (via
                ``self.numpy2tensor``) into
                ``(x_support_set, y_support_set, x_target, y_target)``.
            num_worker: number of enqueuer worker threads/processes.

        Returns:
            Tuple ``(mean_loss, mean_accuracy)`` averaged over all batches.
        """
        total_c_loss = 0.0
        total_accuracy = 0.0
        train_enqueuer = data_utils.GeneratorEnqueuer(generator)
        train_enqueuer.start(workers=num_worker,
                             max_queue_size=generator.batch_size * 2)
        train_generator = train_enqueuer.get()
        total_train_batches = len(generator)
        try:
            with tqdm.tqdm(total=total_train_batches) as pbar:
                for i in range(total_train_batches):
                    batch = next(train_generator)
                    x_support_set, y_support_set, x_target, y_target = \
                        self.numpy2tensor(batch)
                    # NHWC -> NCHW for the conv net (assumed input layout —
                    # TODO confirm against the data pipeline).
                    x_support_set = x_support_set.permute(0, 1, 4, 2, 3).float()
                    x_target = x_target.permute(0, 3, 1, 2).float()
                    y_support_set = y_support_set.float()
                    y_target = y_target.long()
                    acc, c_loss = self.matchNet(x_support_set, y_support_set,
                                                x_target, y_target)

                    # Standard optimize step.
                    self.optimizer.zero_grad()
                    c_loss.backward()
                    self.optimizer.step()

                    # .item() replaces the pre-0.4 `.data[0]` idiom, which
                    # raises IndexError on 0-dim tensors in modern PyTorch.
                    total_c_loss += c_loss.item()
                    total_accuracy += acc.item()
                    seen = i + 1  # batches processed so far (running average)
                    iter_out = (f"loss: {total_c_loss / seen:.3}, "
                                f"acc: {total_accuracy / seen:.3}")
                    pbar.set_description(iter_out)
                    pbar.update(1)
        finally:
            # Always release enqueuer workers; the original leaked them
            # (and would also leak on any exception mid-epoch).
            train_enqueuer.stop()

        total_c_loss = total_c_loss / total_train_batches
        total_accuracy = total_accuracy / total_train_batches
        self.scheduler.step(total_c_loss)
        return total_c_loss, total_accuracy
Code example #3
0
File: networks.py  Project: DableUTeeF/whale_oneshot
    def validate_generator(self, generator, num_worker):
        """Run one validation epoch over *generator* (no gradient updates).

        Args:
            generator: Keras-style data source exposing ``batch_size`` and
                ``__len__``; each fetched batch unpacks (via
                ``self.numpy2tensor``) into
                ``(x_support_set, y_support_set, x_target, y_target)``.
            num_worker: number of enqueuer worker threads/processes.

        Returns:
            Tuple ``(mean_loss, mean_accuracy)`` averaged over all batches.
        """
        total_c_loss = 0.0
        total_accuracy = 0.0
        val_enqueuer = data_utils.GeneratorEnqueuer(generator)
        val_enqueuer.start(workers=num_worker,
                           max_queue_size=generator.batch_size * 2)
        val_generator = val_enqueuer.get()
        total_val_batches = len(generator)
        try:
            with tqdm.tqdm(total=total_val_batches) as pbar:
                # Disable autograd for the whole epoch: faster and avoids
                # accumulating graph state during evaluation.
                with torch.no_grad():
                    for i in range(total_val_batches):
                        batch = next(val_generator)
                        x_support_set, y_support_set, x_target, y_target = \
                            self.numpy2tensor(batch)
                        # NHWC -> NCHW for the conv net (assumed input layout —
                        # TODO confirm against the data pipeline).
                        x_support_set = x_support_set.permute(
                            0, 1, 4, 2, 3).float()
                        x_target = x_target.permute(0, 3, 1, 2).float()
                        y_support_set = y_support_set.float()
                        y_target = y_target.long()
                        acc, c_loss = self.matchNet(x_support_set,
                                                    y_support_set,
                                                    x_target, y_target)
                        # .item() replaces the pre-0.4 `.data[0]` idiom, which
                        # raises IndexError on 0-dim tensors in modern PyTorch.
                        total_c_loss += c_loss.item()
                        total_accuracy += acc.item()
                        seen = i + 1  # batches processed (running average)
                        iter_out = (f"v_loss: {total_c_loss / seen:.3}, "
                                    f"v_acc: {total_accuracy / seen:.3}")
                        pbar.set_description(iter_out)
                        pbar.update(1)
        finally:
            # Always release enqueuer workers; the original leaked them
            # (and would also leak on any exception mid-epoch).
            val_enqueuer.stop()

        total_c_loss = total_c_loss / total_val_batches
        total_accuracy = total_accuracy / total_val_batches
        return total_c_loss, total_accuracy