Example #1
# Assumed context for Examples 1-7: torch, geoopt (PoincareBall, ManifoldParameter),
# and a Poincare-ball math module `pmath` exposing expmap0 / logmap0 / mobius_* /
# project / dist (e.g. geoopt.manifolds.poincare.math, as used in hyrnn).
def mobius_linear(
    input,
    weight,
    bias=None,
    hyperbolic_input=True,
    hyperbolic_bias=True,
    nonlin=None,
    c=1.0,
):
    if hyperbolic_input:
        output = pmath.mobius_matvec(weight, input, c=c)
    else:
        # Euclidean input: apply the linear map first, then lift onto the ball.
        output = torch.nn.functional.linear(input, weight)
        output = pmath.expmap0(output, c=c)
    if bias is not None:
        if not hyperbolic_bias:
            bias = pmath.expmap0(bias, c=c)
        output = pmath.mobius_add(output, bias, c=c)
    if nonlin is not None:
        output = pmath.mobius_fn_apply(nonlin, output, c=c)
    output = pmath.project(output, c=c)
    return output
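A minimal usage sketch for the layer above, assuming geoopt is installed; geoopt's PoincareBall methods stand in for the snippet's `pmath` module (same operations, different namespace):

import torch
import geoopt

ball = geoopt.PoincareBall(c=1.0)
x = ball.expmap0(torch.randn(5, 4) * 0.1)   # batch of points on the ball
weight = torch.randn(3, 4) * 1e-2
bias = ball.expmap0(torch.randn(3) * 1e-2)  # hyperbolic bias, as in the layer
# The same steps mobius_linear takes with hyperbolic_input=hyperbolic_bias=True:
out = ball.mobius_matvec(weight, x)
out = ball.mobius_add(out, bias)
out = ball.projx(out)                       # numerical projection back onto the ball
print(out.shape)                            # torch.Size([5, 3])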
Example #2
def mobius_gru_loop(
    input: torch.Tensor,
    h0: torch.Tensor,
    weight_ih: torch.Tensor,
    weight_hh: torch.Tensor,
    bias: torch.Tensor,
    c: torch.Tensor,
    batch_sizes=None,
    hyperbolic_input: bool = False,
    hyperbolic_hidden_state0: bool = False,
    nonlin=None,
):
    if not hyperbolic_hidden_state0:
        hx = pmath.expmap0(h0, c=c)
    else:
        hx = h0
    if not hyperbolic_input:
        input = pmath.expmap0(input, c=c)
    outs = []
    if batch_sizes is None:
        input_unbound = input.unbind(0)
        for t in range(input.size(0)):
            hx = mobius_gru_cell(
                input=input_unbound[t],
                hx=hx,
                weight_ih=weight_ih,
                weight_hh=weight_hh,
                bias=bias,
                nonlin=nonlin,
                c=c,
            )
            outs.append(hx)
        outs = torch.stack(outs)
        h_last = hx
    else:
        h_last = []
        T = len(batch_sizes) - 1  # index of the last time step
        for t in range(batch_sizes.size(0)):
            ix, input = input[:batch_sizes[t]], input[batch_sizes[t]:]
            hx = mobius_gru_cell(
                input=ix,
                hx=hx,
                weight_ih=weight_ih,
                weight_hh=weight_hh,
                bias=bias,
                nonlin=nonlin,
                c=c,
            )
            outs.append(hx)
            if t < T:
                hx, ht = hx[:batch_sizes[t + 1]], hx[batch_sizes[t + 1]:]
                h_last.append(ht)
            else:
                h_last.append(hx)
        h_last.reverse()
        h_last = torch.cat(h_last)
        outs = torch.cat(outs)
    return outs, h_last
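The `batch_sizes` branch above walks a packed sequence step by step, peeling `batch_sizes[t]` rows off the flat data tensor at each time step. A minimal sketch of that layout, using only standard torch utilities:

import torch
from torch.nn.utils.rnn import pack_sequence

seqs = [torch.randn(3, 2), torch.randn(2, 2), torch.randn(1, 2)]  # lengths 3, 2, 1
packed = pack_sequence(seqs)
print(packed.batch_sizes)  # tensor([3, 2, 1]): active sequences per time step
data = packed.data         # flat (sum of lengths, features) tensor
for t in range(packed.batch_sizes.size(0)):
    ix, data = data[:packed.batch_sizes[t]], data[packed.batch_sizes[t]:]
    print(t, ix.shape)     # step t consumes batch_sizes[t] rows, like `ix` above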
Example #3
 def __init__(self, in_features, out_features, c=1.0):
     super().__init__()
     self.in_features = in_features
     self.out_features = out_features
     self.ball = ball = PoincareBall(c=c)
     self.sphere = sphere = geoopt.manifolds.Sphere()
     self.scale = torch.nn.Parameter(torch.zeros(out_features))
     # Start points near the origin, then map them onto the ball.
     point = torch.randn(out_features, in_features) / 4
     point = pmath.expmap0(point, c=c)
     tangent = torch.randn(out_features, in_features)
     self.point = ManifoldParameter(point, manifold=ball)
     with torch.no_grad():
         # Project the random directions onto the unit sphere, in place.
         self.tangent = ManifoldParameter(tangent, manifold=sphere).proj_()
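A minimal sketch of the in-place sphere projection in the last line, assuming geoopt is installed:

import torch
import geoopt

sphere = geoopt.Sphere()
tangent = geoopt.ManifoldParameter(torch.randn(3, 4), manifold=sphere)
with torch.no_grad():
    tangent.proj_()          # rescales each row to unit norm, in place
print(tangent.norm(dim=-1))  # ~1.0 for every row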
Example #4
 def __init__(self,
              *args,
              hyperbolic_input=True,
              hyperbolic_bias=True,
              nonlin=None,
              c=1.0,
              **kwargs):
     super().__init__(*args, **kwargs)
     if self.bias is not None:
         if hyperbolic_bias:
             self.ball = manifold = PoincareBall(c=c)
             self.bias = ManifoldParameter(self.bias, manifold=manifold)
             with torch.no_grad():
                 # Initialize the hyperbolic bias as a small random point on the ball.
                 self.bias.set_(pmath.expmap0(self.bias.normal_() / 4, c=c))
     with torch.no_grad():
         self.weight.normal_(std=1e-2)
     self.hyperbolic_bias = hyperbolic_bias
     self.hyperbolic_input = hyperbolic_input
     self.nonlin = nonlin
Example #5
 def __init__(
     self,
     input_size,
     hidden_size,
     num_layers=1,
     bias=True,
     nonlin=None,
     hyperbolic_input=True,
     hyperbolic_hidden_state0=True,
     c=1.0,
 ):
     super().__init__()
     self.ball = PoincareBall(c=c)
     self.input_size = input_size
     self.hidden_size = hidden_size
     self.num_layers = num_layers
     self.bias = bias
     self.weight_ih = torch.nn.ParameterList([
         torch.nn.Parameter(
             torch.Tensor(3 * hidden_size,
                          input_size if i == 0 else hidden_size))
         for i in range(num_layers)
     ])
     self.weight_hh = torch.nn.ParameterList([
         torch.nn.Parameter(torch.Tensor(3 * hidden_size, hidden_size))
         for _ in range(num_layers)
     ])
     if bias:
         biases = []
         for i in range(num_layers):
             # One (3, hidden_size) hyperbolic bias per layer, one row per gate.
             bias_i = torch.randn(3, hidden_size) * 1e-5
             bias_i = ManifoldParameter(pmath.expmap0(bias_i, c=self.ball.c),
                                        manifold=self.ball)
             biases.append(bias_i)
         self.bias = torch.nn.ParameterList(biases)
     else:
         self.register_buffer("bias", None)
     self.nonlin = nonlin
     self.hyperbolic_input = hyperbolic_input
     self.hyperbolic_hidden_state0 = hyperbolic_hidden_state0
     self.reset_parameters()
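The `3 * hidden_size` leading dimension follows the usual GRU convention of stacking one weight block per gate. A minimal sketch of how such a stacked matrix splits (the reset/update/new gate naming is the standard convention, assumed here):

import torch

hidden_size, input_size = 4, 6
weight_ih = torch.randn(3 * hidden_size, input_size)
w_r, w_z, w_n = weight_ih.chunk(3)  # one (hidden_size, input_size) block per gate
print(w_r.shape, w_z.shape, w_n.shape)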
Example #6
    def forward(self, input):

        source_input = input[0][0]
        target_input = input[0][1]
        alignment = input[1]
        batch_size = alignment.shape[0]

        source_input_data = self.embedding(source_input.data)
        target_input_data = self.embedding(target_input.data)

        zero_hidden = torch.zeros(self.num_layers,
                                  batch_size,
                                  self.hidden_dim,
                                  device=self.device or source_input_data.device,
                                  dtype=source_input_data.dtype)
        if self.embedding_type == "eucl" and "hyp" in self.cell_type:
            # Euclidean embeddings feeding a hyperbolic cell: map inputs onto the ball.
            source_input_data = pmath.expmap0(source_input_data, c=self.c)
            target_input_data = pmath.expmap0(target_input_data, c=self.c)
        elif self.embedding_type == "hyp" and "eucl" in self.cell_type:
            # Hyperbolic embeddings feeding a Euclidean cell: map back to the tangent space.
            source_input_data = pmath.logmap0(source_input_data, c=self.c)
            target_input_data = pmath.logmap0(target_input_data, c=self.c)
        # ht: (num_layers * num_directions, batch, hidden_size)

        source_input = torch.nn.utils.rnn.PackedSequence(
            source_input_data, source_input.batch_sizes)
        target_input = torch.nn.utils.rnn.PackedSequence(
            target_input_data, target_input.batch_sizes)

        _, source_hidden = self.cell_source(source_input, zero_hidden)
        _, target_hidden = self.cell_target(target_input, zero_hidden)

        # Take hidden states from the last layer; reorder targets by alignment.
        source_hidden = source_hidden[-1]
        target_hidden = target_hidden[-1][alignment]

        if self.decision_type == "hyp":
            if "eucl" in self.cell_type:
                source_hidden = pmath.expmap0(source_hidden, c=self.c)
                target_hidden = pmath.expmap0(target_hidden, c=self.c)
            source_projected = self.projector_source(source_hidden)
            target_projected = self.projector_target(target_hidden)
            projected = pmath.mobius_add(source_projected,
                                         target_projected,
                                         c=self.ball.c)
            if self.use_distance_as_feature:
                dist = (pmath.dist(source_hidden,
                                   target_hidden,
                                   dim=-1,
                                   keepdim=True,
                                   c=self.ball.c)**2)
                bias = pmath.mobius_scalar_mul(dist,
                                               self.dist_bias,
                                               c=self.ball.c)
                projected = pmath.mobius_add(projected, bias, c=self.ball.c)
        else:
            if "hyp" in self.cell_type:
                source_hidden = pmath.logmap0(source_hidden, c=self.c)
                target_hidden = pmath.logmap0(target_hidden, c=self.c)
            projected = self.projector(
                torch.cat((source_hidden, target_hidden), dim=-1))
            if self.use_distance_as_feature:
                dist = torch.sum((source_hidden - target_hidden).pow(2),
                                 dim=-1,
                                 keepdim=True)
                bias = self.dist_bias * dist
                projected = projected + bias

        logits = self.logits(projected)
        # CrossEntropy accepts logits
        return logits
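The eucl/hyp bridging in this forward pass relies on expmap0 and logmap0 being mutual inverses at the origin. A minimal round-trip sketch, assuming geoopt:

import torch
import geoopt

ball = geoopt.PoincareBall(c=1.0)
e = torch.randn(2, 5) * 0.1                # Euclidean features
h = ball.expmap0(e)                        # "eucl" -> "hyp": lift onto the ball
back = ball.logmap0(h)                     # "hyp" -> "eucl": back to tangent space
print(torch.allclose(e, back, atol=1e-5))  # True, up to numerical error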
Example #7
    def __init__(
        self,
        vocab_size,
        embedding_dim,
        hidden_dim,
        project_dim,
        cell_type="eucl_rnn",
        embedding_type="eucl",
        decision_type="eucl",
        use_distance_as_feature=True,
        device=None,
        num_layers=1,
        num_classes=1,
        c=1.0,
    ):
        super(RNNBase, self).__init__()
        (cell_type, embedding_type,
         decision_type) = map(str.lower,
                              [cell_type, embedding_type, decision_type])
        if embedding_type == "eucl":
            self.embedding = LookupEmbedding(vocab_size,
                                             embedding_dim,
                                             manifold=Euclidean())
            with torch.no_grad():
                self.embedding.weight.normal_()
        elif embedding_type == "hyp":
            self.embedding = LookupEmbedding(
                vocab_size,
                embedding_dim,
                manifold=PoincareBall(c=c),
            )
            with torch.no_grad():
                self.embedding.weight.set_(
                    pmath.expmap0(self.embedding.weight.normal_() / 10,
                                  c=c)
                )
        else:
            raise NotImplementedError(
                "Unsupported embedding type: {0}".format(embedding_type))
        self.embedding_type = embedding_type
        if decision_type == "eucl":
            self.projector = nn.Linear(hidden_dim * 2, project_dim)
            self.logits = nn.Linear(project_dim, num_classes)
        elif decision_type == "hyp":
            self.projector_source = hyrnn.MobiusLinear(
                hidden_dim, project_dim, c=c)
            self.projector_target = hyrnn.MobiusLinear(
                hidden_dim, project_dim, c=c)
            self.logits = hyrnn.MobiusDist2Hyperplane(project_dim,
                                                      num_classes)
        else:
            raise NotImplementedError(
                "Unsupported decision type: {0}".format(decision_type))
        self.ball = PoincareBall(c)
        if use_distance_as_feature:
            if decision_type == "eucl":
                self.dist_bias = nn.Parameter(torch.zeros(project_dim))
            else:
                self.dist_bias = geoopt.ManifoldParameter(
                    torch.zeros(project_dim), manifold=self.ball)
        else:
            self.register_buffer("dist_bias", None)
        self.decision_type = decision_type
        self.use_distance_as_feature = use_distance_as_feature
        self.device = device  # stored explicitly because training runs under catalyst
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.c = c

        if cell_type == "eucl_rnn":
            self.cell = nn.RNN
        elif cell_type == "eucl_gru":
            self.cell = nn.GRU
        elif cell_type == "hyp_gru":
            self.cell = functools.partial(MobiusGRU, c=c)
        else:
            raise NotImplementedError(
                "Unsupported cell type: {0}".format(cell_type))
        self.cell_type = cell_type

        self.cell_source = self.cell(embedding_dim, self.hidden_dim,
                                     self.num_layers)
        self.cell_target = self.cell(embedding_dim, self.hidden_dim,
                                     self.num_layers)
Example #8
# Assumed imports for this script: numpy, matplotlib, scikit-learn, torch, and the
# project's Poincare math module `pm` (not shown here; e.g. geoopt's). parse_option,
# get_train_val_data, set_model, and set_optimizer come from the surrounding project.
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import f1_score, precision_recall_fscore_support


def plot_metrics(counter, train_losses, test_losses, train_f1s, test_f1s, args):
    # Plot running train/val loss and f1 curves and save them to disk.
    plt.title('Cos_Loss')
    plt.xlabel('number of training steps seen')
    plt.ylabel('Cos_Loss')
    plt.plot(counter, train_losses, 'b', label='train')
    plt.plot(counter, test_losses, 'r', label='val')
    plt.legend(['Train Loss', 'Val Loss'], loc='upper right')
    plt.grid()
    plt.savefig(args.loss_path)

    plt.cla()
    plt.title('f1_score')
    plt.xlabel('number of training steps seen')
    plt.ylabel('f1_score')
    plt.plot(counter, train_f1s, 'b', label='train')
    plt.plot(counter, test_f1s, 'r', label='val')
    plt.legend(['Train f1_score', 'Val f1_score'], loc='upper right')
    plt.grid()
    plt.savefig('./pictures/f1.jpg')
    plt.cla()  # clear the axes so later calls start from a clean figure


def main():
    args = parse_option()

    train_words_embd, train_img_mat, val_words_embd, val_img_mat = get_train_val_data(
        args)
    print('Get Data!')
    # Train
    model, criterion = set_model(args)
    optimizer = set_optimizer(args, model)

    model = model.cuda()
    criterion = criterion.cuda()
    val_words_embd = val_words_embd.cuda()
    val_img_mat = val_img_mat.cuda()

    total_step = 0
    best_f1 = best_recall = best_precision = 0.0  # best validation metrics so far

    train_losses = []
    train_f1s = []
    test_losses = []
    test_f1s = []
    counter = []

    for epoch in range(args.epochs):
        model.train()
        losses = []
        f1s = []
        avg_loss = 0
        f1 = 0


        for j in range(0, len(train_img_mat), args.batch_size):
            total_step += 1
            batch_train_img = train_img_mat[j:j + args.batch_size, :]
            batch_train_words = train_words_embd[j:j + args.batch_size, :]
            batch_train_img = torch.from_numpy(batch_train_img).type(
                torch.FloatTensor).cuda()
            batch_train_words = torch.from_numpy(batch_train_words).type(
                torch.FloatTensor).cuda()

            batch_train_img = pm.expmap0(batch_train_img, c=1)
            img_tranf = model(batch_train_img)

            loss = criterion(img_tranf, batch_train_words)
            losses.append(loss.item())
            avg_loss = np.average(losses)  # running average
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            f1 = f1_score(batch_train_words.cpu().detach().numpy(),
                          img_tranf.cpu().detach().numpy().round(),
                          average='macro')
            f1s.append(f1)
            avg_f1 = np.average(f1s)  # running average

            if total_step % args.print_freq == 0:
                print("Epoch {}, Step {}, Loss {}, f1_score {}".format(
                    epoch, j, avg_loss, avg_f1))

        print('############# Evaluation ##############')
        model.eval()
        train_f1s.append(avg_f1)
        train_losses.append(avg_loss)

        with torch.no_grad():
            # Map onto the ball without overwriting val_img_mat, so the raw
            # Euclidean validation data is not re-mapped on the next epoch.
            val_img_hyp = pm.expmap0(val_img_mat, c=1)
            img_tranf = model(val_img_hyp)
            loss = criterion(img_tranf, val_words_embd)

        pre_recall_f1 = precision_recall_fscore_support(
            val_words_embd.cpu().detach().numpy(),
            img_tranf.cpu().detach().numpy().round(),
            average='macro')

        if test_losses and loss.item() < min(test_losses):
            torch.save(model, args.model_path)

        torch.cuda.empty_cache()

        test_losses.append(loss.item())
        test_f1s.append(pre_recall_f1[2])  # record validation f1 so the f1 plot has data
        counter.append(total_step)

        print("Loss {}, pre {}, recall {}, f1_score {}".format(
            loss.item(), pre_recall_f1[0], pre_recall_f1[1], pre_recall_f1[2]))

        if pre_recall_f1[2] > best_f1:
            best_f1 = pre_recall_f1[2]
            best_recall = pre_recall_f1[1]
            best_precision = pre_recall_f1[0]

        print('#######################################')

        if epoch % args.save_freq == 0:
            print('==> Saving...')
            state = {
                'opt': args,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
            }
            save_file = os.path.join(
                args.model_folder,
                'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
            # help release GPU memory
            del state

        if epoch % 20 == 0:
            plot_metrics(counter, train_losses, test_losses, train_f1s,
                         test_f1s, args)

    print("pre {}, recall {}, f1_score {}".format(best_precision, best_recall,
                                                  best_f1))

    plot_metrics(counter, train_losses, test_losses, train_f1s, test_f1s, args)