Exemplo n.º 1
0
    def __init__(self):
        """Configure the "class" run: classification loss, no augmentation.

        Expects the parent ``__init__`` to have populated ``self.epoch_args``
        (including the ``aug_loss_fn``/``aug_optimizer``/``aug_scheduler``
        keys, which are deleted here -- a missing key raises ``KeyError``)
        and to have set ``momentum``, ``weight_decay``, ``step_size`` and
        ``gamma``, which this method reads but does not set.
        """
        super(TrainingSMNBNM, self).__init__()
        self.name = "class"

        # Replace whatever model the parent configured.
        net = PseudoMultiTaskNet(no_svd=True)
        self.model = net
        if CUDA:
            net.cuda(0)

        self.lr = 0.01
        args = self.epoch_args

        # NOTE(review): the optimizer runs at 10x ``self.lr`` -- confirm the
        # factor is intended.
        sgd = optim.SGD(self.model.parameters(),
                        lr=self.lr * 10,
                        momentum=self.momentum,
                        weight_decay=self.weight_decay)
        args["optimizer"] = sgd
        args["scheduler"] = StepLR(sgd,
                                   step_size=self.step_size,
                                   gamma=self.gamma)

        def classification_loss(outputs, labels):
            """Return ``(loss, predicted indices along dim 1)`` for a batch."""
            # Log of the mean of the first three outputs. It is fed to
            # nll_loss, so presumably those outputs are probabilities --
            # TODO confirm.
            log_likelihoods = torch.log(
                (outputs[0] + outputs[1] + outputs[2]) / 3)
            # Summed absolute difference between outputs[0] and outputs[1].
            head_gap = torch.abs(outputs[0] - outputs[1]).sum()
            predicted = torch.max(log_likelihoods.data, 1)[1]
            total = F.nll_loss(log_likelihoods, labels) + 0.01 * head_gap
            return total, predicted

        args["loss_fn"] = classification_loss
        self.epoch_function = epoch_classification

        # Drop the augmentation entries set up by the parent.
        for unused_key in ("aug_loss_fn", "aug_optimizer", "aug_scheduler"):
            del args[unused_key]
Exemplo n.º 2
0
    def __init__(self):
        """Configure the "class_sp_inter_ba" run.

        Builds the model, a ``SummaryWriter``, the hyper-parameters, the main
        and augmentation optimizers/schedulers, the loss functions and the
        MNIST train/validation loaders. It then writes every instance
        attribute to the writer as text under ``config/<attribute name>``.
        """
        self.name = "class_sp_inter_ba"
        self.model = PseudoMultiTaskNet()
        if CUDA:
            self.model.cuda(0)
        self.writer = SummaryWriter()
        self.epoch_function = epoch_mixed
        self.epoch_args = {}
        self.epochs = 20
        self.lr = 0.1
        self.momentum = 0.9
        self.weight_decay = 1e-5
        # StepLR documents ``step_size`` as an int, but np.ceil returns a
        # numpy float. Cast so the scheduler (and the logged config) get an
        # integer.
        self.step_size = int(np.ceil(self.epochs / 3))
        self.gamma = 0.1

        def loss(outputs, labels):
            """Return ``(loss, predicted indices along dim 1)`` for a batch."""
            # Log of the mean of the first three outputs. It is fed to
            # nll_loss, so presumably those outputs are probabilities --
            # TODO confirm.
            likelihoods = torch.log((outputs[0] + outputs[1] + outputs[2]) / 3)
            # Summed absolute difference between outputs[0] and outputs[1].
            disagreement = torch.sum(torch.abs(outputs[0] - outputs[1]))
            # L1 norm of outputs[3], taken over all of its elements.
            sparsity = torch.norm(outputs[3], 1)
            _, y_pred = torch.max(likelihoods.data, 1)

            return (F.nll_loss(likelihoods, labels) + 0.01 * disagreement +
                    0.0001 * sparsity, y_pred)

        self.epoch_args["loss_fn"] = loss
        # Pad x by 2 on each side of its last two dims, then take the L1
        # loss against y.
        self.epoch_args["aug_loss_fn"] = lambda x, y: F.l1_loss(
            F.pad(x, (2, 2, 2, 2)), y)

        # NOTE(review): the optimizer runs at 10x ``self.lr`` (1.0 here) --
        # confirm the factor is intended.
        self.epoch_args["optimizer"] = optim.SGD(
            self.model.parameters(),
            lr=10 * self.lr,
            momentum=self.momentum,
            weight_decay=self.weight_decay)

        self.epoch_args["scheduler"] = StepLR(self.epoch_args["optimizer"],
                                              step_size=self.step_size,
                                              gamma=self.gamma)

        self.aug_batch_size = 16
        self.batch_size = 16 * self.aug_batch_size
        self.aug_lr = 0.01

        # NOTE(review): this second optimizer updates the same parameters as
        # "optimizer" above, with its own independent momentum buffers.
        self.epoch_args["aug_optimizer"] = optim.SGD(
            self.model.parameters(),
            lr=10 * self.aug_lr,
            momentum=self.momentum,
            weight_decay=self.weight_decay)
        self.epoch_args["aug_scheduler"] = StepLR(
            self.epoch_args["aug_optimizer"],
            step_size=int(np.ceil(self.epochs / 2)),
            gamma=self.gamma)

        # self.train_loader = load_augmented(self.batch_size,
        #                                    self.aug_batch_size)
        self.train_loader = load_mnist(self.batch_size, train=True)
        self.val_loader = load_mnist(self.batch_size, train=False)

        # Record the full configuration (every attribute set above).
        for key, value in self.__dict__.items():
            self.writer.add_text(f"config/{key}", str(value))
Exemplo n.º 3
0
def load_checkpoints(epoch_short, epoch_long, basepath=None):
    """Load the saved model of every training run found on disk.

    Each ``checkpoints`` directory found recursively under *basepath* must
    have a sibling ``records.json`` whose ``"name"`` field names the run.

    - Checkpoint choice: runs whose name contains ``"long"`` load
      ``checkpoint_{epoch_long}.dat``; all others load
      ``checkpoint_{epoch_short}.dat``.
    - Architecture choice, by name: ``"mult"`` selects
      ``PseudoMultiTaskNetMult()``; ``"small"`` selects
      ``PseudoMultiTaskNet(small=True)``; otherwise ``PseudoMultiTaskNet()``.

    Args:
        epoch_short: Epoch of the checkpoint to load for regular runs.
        epoch_long: Epoch of the checkpoint to load for "long" runs.
        basepath: Root directory to search (``str`` or ``Path``). Defaults
            to ``~/Projects/pytorch-pseudomultitasknet/runs``.

    Returns:
        dict mapping run name to model.

        - Every name in the built-in list below is present, in order, with
          value ``None`` if no checkpoint was found.
        - Names found on disk but not in the list are appended.
        - In every key the substring ``"sp"`` is replaced by ``"tr"``.
        - Models are moved to ``cuda:0`` when CUDA is available; otherwise
          they stay on the CPU.
    """
    if basepath is None:
        basepath = Path.home() / "Projects" / "pytorch-pseudomultitasknet" / "runs"
    basepath = Path(basepath)

    models = dict.fromkeys((
        "sm_nb",
        "class",
        "class_sp",
        "class_inter",
        "class_inter_ba",
        "class_sp_inter",
        "class_sp_inter_ba",
        "class_sp_inter_ba_small",
        "class_sp_inter_ba_long",
        "class_sp_inter_ba_mult",
        "class_sp_inter_ba_mult_long",
    ))

    use_cuda = torch.cuda.is_available()
    # Sorted so that, if two runs share a name, the one kept is determined
    # by path order rather than by arbitrary filesystem listing order.
    for path in sorted(basepath.glob("**/checkpoints")):
        with open(path.parent / "records.json", encoding="utf-8") as f:
            name = json.load(f)["name"]

        epoch = epoch_long if "long" in name else epoch_short

        if "mult" in name:
            model = PseudoMultiTaskNetMult()
        elif "small" in name:
            model = PseudoMultiTaskNet(small=True)
        else:
            model = PseudoMultiTaskNet()
        # Load onto the CPU first so this works whatever device the
        # checkpoint was saved from; move to the GPU afterwards if present.
        state = torch.load(path / f"checkpoint_{epoch}.dat", map_location="cpu")
        model.load_state_dict(state)
        if use_cuda:
            model.cuda(0)
        models[name] = model

    return {k.replace("sp", "tr"): v for k, v in models.items()}