Example #1
    def __init__(self, loader: DataLoader, mean: list, std: list):
        """
        Args:
            loader (DataLoader): data loader
            mean (list): mean value for each channel
            std (list): std value for each channel
        """
        super(DataPrefetcher, self).__init__()
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        # per-channel RGB statistics, shaped (1, 3, 1, 1) so they broadcast over (N, 3, H, W) batches
        self.mean = torch.tensor(mean).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor(std).cuda().view(1, 3, 1, 1)
        self.preload()
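The mean/std buffers are shaped (1, 3, 1, 1) so they broadcast over a batch of shape (N, 3, H, W). A minimal sketch of the normalization these buffers enable (the body of preload() is not shown above; the statistics below are illustrative ImageNet-style values, not taken from the source):

import torch

mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
batch = torch.rand(8, 3, 32, 32)   # stand-in for a batch of RGB images
normalized = (batch - mean) / std  # broadcasts over N, H and W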
Example #2
def maybe_copy(a, pin_memory):
    # tensors are copied (optionally into pinned host memory); everything else is passed through unchanged
    if isinstance(a, torch.Tensor):
        if pin_memory:
            return a.pin_memory()
        else:
            return a.detach().clone()  # detached copy (equivalent to a.new_tensor(a))
    else:
        return a
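A minimal usage sketch (the batch dictionary and its keys are illustrative, not from the source); pin_memory=True would additionally require a CUDA-capable build, so it is left off here:

import torch

batch = {"input": torch.arange(4.0), "target": 2}
copied = {k: maybe_copy(v, pin_memory=False) for k, v in batch.items()}
print(copied["input"] is batch["input"])  # False: the tensor was copied
print(copied["target"])                   # non-tensors are returned unchanged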
Example #3
    def train(self):
        print("GPU/CPU:", torch.cuda.get_device_name(0))
        # Start timer
        model_file = open(self.out_folder + '/model_notes.txt', "w")
        start_time = time.time()
        training_notes_file = open(self.out_folder + '/training_notes.txt',
                                   "w")
        losses_file = open(self.out_folder + '/losses_notes.txt', "w")
        if self.dataset == 'cifar10':
            data = dset.CIFAR10(root=self.dataroot,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.Resize(self.image_size),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5),
                                                         (0.5, 0.5, 0.5)),
                                ]))

        elif self.dataset == 'lsun':
            transform = transforms.Compose([
                transforms.Resize(self.image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
            data = dset.ImageFolder(root=self.dataroot, transform=transform)

            # dataset = dset.LSUN(opt.dataroot, classes=['bedroom_train'],
            #                     transform=transforms.Compose([
            #                         transforms.Resize(opt.imageSize),
            #                         transforms.CenterCrop(opt.imageSize),
            #                         transforms.ToTensor(),
            #                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            #                     ]))

        else:
            # Get dataset
            data = get_dataset(self.dataroot)

        # Data loader
        dataloader = torch.utils.data.DataLoader(data,
                                                 batch_size=self.batch_size,
                                                 shuffle=True,
                                                 num_workers=int(self.workers))

        # Set the type of GAN
        if self.type == "dcgan":
            if self.image_size == 32:
                self.generator = Generator32(self.z_noise, self.channels,
                                             self.num_gen_filters).to(
                                                 self.device)
                self.discriminator = Discriminator32(self.ngpu, self.channels,
                                                     self.num_disc_filters).to(
                                                         self.device)
            else:

                self.discriminator = DcganDiscriminator(
                    self.channels, self.num_disc_filters).to(self.device)
            criterion = nn.BCELoss()
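            # note: criterion is reassigned unconditionally below (BCE or MSE depending on self.lsgan)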

        if self.type == "can":
            self.generator = Can64Generator(self.z_noise, self.channels,
                                            self.num_gen_filters).to(
                                                self.device)
            self.discriminator = Can64Discriminator(self.channels, self.y_dim,
                                                    self.num_disc_filters).to(
                                                        self.device)
            style_criterion = nn.CrossEntropyLoss()
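            # the CAN style head is trained with cross-entropy over the y_dim style/class labels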

        # setup optimizers
        # todo : options for SGD, RMSProp
        if self.smoothing:
            self.disc_optimizer = optim.SGD(
                self.discriminator.parameters(),
                lr=self.lr)  # betas=(self.beta1, 0.999))
        else:
            self.disc_optimizer = optim.Adam(
                self.discriminator.parameters(),
                lr=self.lr)  # betas=(self.beta1, 0.999))
        # recommended in GANhacks
        self.gen_optimizer = optim.Adam(
            self.generator.parameters(),
            lr=self.lr)  #, betas=(self.beta1, 0.999))

        if self.lsgan:
            criterion = nn.MSELoss()
        else:
            criterion = nn.BCELoss()
        self.discriminator.apply(self.weights_init)
        self.generator.apply(self.weights_init)

        if self.disc_path != '':
            self.discriminator.load_state_dict(torch.load(self.disc_path))

        model_file.write("Discriminator:\n")
        model_file.write(str(self.discriminator))
        model_file.write("\nGenerator:\n")
        model_file.write(str(self.generator))

        real_label = 1.
        fake_label = 0.

        # Fixed noise, reused whenever samples are saved so generator progress is comparable across epochs
        fixed_noise = torch.randn(self.batch_size,
                                  self.z_noise,
                                  1,
                                  1,
                                  device=self.device)

        # Generator class/style labels
        gen_style_labels = torch.ones(self.batch_size, device=self.device)

        # Actual training!
        for epoch in range(self.num_epochs):
            epoch_start_time = time.time()
            training_notes_file.write("\nEpoch" + str(epoch) + ":\n")
            data_iterator = iter(dataloader)

            i = 0
            # Heavily inspired by https://github.com/pytorch/examples/blob/master/dcgan/main.py
            while i < len(dataloader):

                disc_loss_epoch = []
                gen_loss_epoch = []

                if self.type == "can":
                    disc_class_loss_epoch = []
                    gen_class_loss_epoch = []

                curr_start = time.time()
                # WGAN

                # Train Discriminator
                self.discriminator.zero_grad()

                # real
                data = next(data_iterator)
                real_images, image_labels = data
                real_images = real_images.to(self.device)
                batch_size = real_images.size(0)
                real_image_labels = image_labels.long().to(self.device)

                # label smoothing
                # rand_labels = np.random.uniform(low=0.7, high=1.2, size=(batch_size,))
                # r_labels = torch.from_numpy(rand_labels)
                # labels.copy_(r_labels)
                #print(labels)
                if self.type == 'can':
                    predicted_output_real, predicted_styles_real = self.discriminator(
                        real_images.detach())
                    disc_class_loss = style_criterion(predicted_styles_real,
                                                      real_image_labels)
                    disc_class_loss.backward(retain_graph=True)
                else:
                    predicted_output_real = self.discriminator(
                        real_images.detach())

                # label smoothing / label flipping for the real batch
                if self.smoothing:
                    # one-sided label smoothing: "real" targets near 1, "fake" targets near 0
                    labels_real = np.random.uniform(0.7, 1.3, size=(self.batch_size,))
                    labels_fake = np.random.uniform(0.0, 0.3, size=(self.batch_size,))
                    if self.flip and random.uniform(0.0, 2.0) < 0.3:
                        labels = torch.tensor(labels_fake,
                                              dtype=torch.float32,
                                              device=self.device)
                    else:
                        labels = torch.tensor(labels_real,
                                              dtype=torch.float32,
                                              device=self.device)
                else:
                    if self.flip and random.uniform(0.0, 2.0) < 0.3:
                        labels = torch.full((self.batch_size, ),
                                            fake_label,
                                            device=self.device)
                    else:
                        labels = torch.full((self.batch_size, ),
                                            real_label,
                                            device=self.device)

                disc_loss_real = criterion(predicted_output_real, labels)
                disc_loss_real.backward(retain_graph=True)
                disc_x = predicted_output_real.mean().item()

                # fake

                noise = torch.randn(batch_size,
                                    self.z_noise,
                                    1,
                                    1,
                                    device=self.device)

                fake_images = self.generator(noise)
                if self.flip and random.uniform(0.0, 2.0) < 0.3:
                    # flipped labels: present generated images as "real" to the discriminator
                    if self.smoothing:
                        labels = torch.tensor(labels_real,
                                              dtype=torch.float32,
                                              device=self.device)
                    else:
                        labels.fill_(real_label)
                else:
                    labels.fill_(fake_label)

                if self.type == 'can':
                    predicted_output_fake, predicted_styles_fake = self.discriminator(
                        fake_images.detach())

                else:
                    predicted_output_fake = self.discriminator(
                        fake_images.detach())

                disc_loss_fake = criterion(predicted_output_fake, labels)
                disc_loss_fake.backward(retain_graph=True)
                disc_gen_z_1 = predicted_output_fake.mean().item()
                disc_loss = disc_loss_real + disc_loss_fake

                if self.type == 'can':
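                    # the style-classification term also contributes to the logged discriminator loss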
                    disc_loss += disc_class_loss

                self.disc_optimizer.step()

                # train generator

                self.generator.zero_grad()
                labels.fill_(real_label)

                if self.type == 'can':
                    predicted_output_fake, predicted_styles_fake = self.discriminator(
                        fake_images)

                else:
                    predicted_output_fake = self.discriminator(fake_images)

                gen_loss = criterion(predicted_output_fake, labels)
                gen_loss.backward(retain_graph=True)
                disc_gen_z_2 = predicted_output_fake.mean().item()

                if self.type == 'can':
                    fake_batch_labels = 1.0 / self.y_dim * torch.ones_like(
                        predicted_styles_fake)
                    fake_batch_labels = torch.mean(fake_batch_labels,
                                                   1).long().to(self.device)
                    gen_class_loss = style_criterion(predicted_styles_fake,
                                                     fake_batch_labels)
                    gen_class_loss.backward()
                    gen_loss += gen_class_loss
                    #disc_loss += torch.log(gen_class_loss)

                self.gen_optimizer.step()

                disc_loss_epoch.append(disc_loss.item())
                gen_loss_epoch.append(gen_loss.item())

                if self.type == "can":
                    disc_class_loss_epoch.append(disc_class_loss.item())
                    gen_class_loss_epoch.append(gen_class_loss.item())

                if self.type == 'can':

                    print(
                        '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Class_D: %.4f Class_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                        % (epoch, self.num_epochs, i, len(dataloader),
                           disc_loss.item(), gen_loss.item(),
                           disc_class_loss.item(), gen_class_loss.item(),
                           disc_x, disc_gen_z_1, disc_gen_z_2))
                    # training_notes_file.write('\n[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Class_D: %.4f Class_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                    #     % (epoch, self.num_epochs, i, len(dataloader),
                    #         disc_loss.item(), gen_loss.item(), disc_class_loss.item(), gen_class_loss.item(), disc_x, disc_gen_z_1, disc_gen_z_2))
                else:
                    print(
                        '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f  D(x): %.4f D(G(z)): %.4f / %.4f'
                        % (epoch, self.num_epochs, i, len(dataloader),
                           disc_loss.item(), gen_loss.item(), disc_x,
                           disc_gen_z_1, disc_gen_z_2))
                    # training_notes_file.write('\n[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Class_D: %.4f Class_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                    #     % (epoch, self.num_epochs, i, len(dataloader),
                    #         disc_loss.item(), gen_loss.item(), disc_class_loss.item(), gen_class_loss.item(), disc_x, disc_gen_z_1, disc_gen_z_2))
                if (i > 0 and (i % 10000 == 0)) or i == (len(dataloader) - 1):
                    fake = self.generator(fixed_noise)
                    vutils.save_image(fake.data,
                                      '%s/fake_samples_epoch_%03d_%04d.jpg' %
                                      (self.out_folder, epoch, i),
                                      normalize=True)
                i += 1

            epoch_end_time = time.time()
            training_notes_file.write(
                "\nEpoch training time: {} seconds".format(epoch_end_time -
                                                           epoch_start_time))

            # Metrics for Floydhub
            print('{{"metric": "disc_loss", "value": {:.4f}}}'.format(
                np.mean(disc_loss_epoch)))
            print('{{"metric": "gen_loss", "value": {:.4f}}}'.format(
                np.mean(gen_loss_epoch)))

            training_notes_file.write("\nEpoch " + str(epoch) +
                                      " average losses:\n")
            training_notes_file.write("Discriminator loss: " +
                                      str(np.mean(disc_loss_epoch)))
            training_notes_file.write("\nGenerator loss: " +
                                      str(np.mean(gen_loss_epoch)))

            if self.type == 'can':
                print(
                    '{{"metric": "disc_class_loss", "value": {:.4f}}}'.format(
                        np.mean(disc_class_loss_epoch)))
                print('{{"metric": "gen_class_loss", "value": {:.4f}}}'.format(
                    np.mean(gen_class_loss_epoch)))
                training_notes_file.write(
                    "\nDiscriminator classification loss: " +
                    str(np.mean(disc_class_loss_epoch)))
                training_notes_file.write(
                    "\nGenerator `classification` loss: " +
                    str(np.mean(gen_class_loss_epoch)))

            # Get the mean of the losses over the epoch for the loss graphs
            disc_loss_epoch = np.asarray(disc_loss_epoch)
            gen_loss_epoch = np.asarray(gen_loss_epoch)
            if self.type == 'can':
                disc_class_loss_epoch = np.asarray(disc_class_loss_epoch)
                self.train_history['disc_class_loss'].append(
                    np.mean(disc_class_loss_epoch))
                gen_class_loss_epoch = np.asarray(gen_class_loss_epoch)
                self.train_history['gen_class_loss'].append(
                    np.mean(gen_class_loss_epoch))

            self.train_history['disc_loss'].append(np.mean(disc_loss_epoch))
            self.train_history['gen_loss'].append(np.mean(gen_loss_epoch))

            # do checkpointing
            if (epoch > 1 and
                (epoch % 5 == 0)) or (epoch == self.num_epochs - 1):
                torch.save(self.generator.state_dict(),
                           '%s/netG_epoch_%d.pth' % (self.out_folder, epoch))
                torch.save(self.discriminator.state_dict(),
                           '%s/netD_epoch_%d.pth' % (self.out_folder, epoch))

            training_notes_file.write(
                "\n---------------------------------------------------------------------------------\n"
            )

        training_notes_file.write(
            "\nTotal training time: {} seconds".format(time.time() -
                                                       start_time))
        model_file.close()
        training_notes_file.close()
        losses_file.write(str(self.train_history))
        losses_file.close()
Example #4
    def decode(self,
               input_word,
               input_char,
               input_pos,
               mask=None,
               beam=1,
               leading_symbolic=0):
        # reset noise for decoder
        self.decoder.reset_noise(0)

        # output_enc [batch, length, model_dim]
        # arc_c [batch, length, arc_space]
        # type_c [batch, length, type_space]
        # hn [num_direction, batch, hidden_size]
        output_enc, hn = self._get_encoder_output(input_word,
                                                  input_char,
                                                  input_pos,
                                                  mask=mask)
        enc_dim = output_enc.size(2)
        device = output_enc.device
        # output size [batch, length_encoder, arc_space]
        arc_c = self.activation(self.arc_c(output_enc))
        # output size [batch, length_encoder, type_space]
        type_c = self.activation(self.type_c(output_enc))
        type_space = type_c.size(2)
        # [decoder_layers, batch, hidden_size]
        hn = self._transform_decoder_init_state(hn)
        batch, max_len, _ = output_enc.size()

        heads = torch.zeros(batch,
                            1,
                            max_len,
                            device=device,
                            dtype=torch.int64)
        types = torch.zeros(batch,
                            1,
                            max_len,
                            device=device,
                            dtype=torch.int64)

        num_steps = 2 * max_len - 1
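        # stack-pointer style decoding: each word is pushed once and later popped, hence 2 * max_len - 1 transitions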
        stacked_heads = torch.zeros(batch,
                                    1,
                                    num_steps + 1,
                                    device=device,
                                    dtype=torch.int64)
        siblings = torch.zeros(
            batch, 1, num_steps +
            1, device=device, dtype=torch.int64) if self.sibling else None
        hypothesis_scores = output_enc.new_zeros((batch, 1))
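        # running log-probability of each beam hypothesis; decoding starts from a single empty hypothesis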

        # [batch, beam, length]
        children = torch.arange(max_len, device=device,
                                dtype=torch.int64).view(1, 1, max_len).expand(
                                    batch, beam, max_len)
        constraints = torch.zeros(batch,
                                  1,
                                  max_len,
                                  device=device,
                                  dtype=torch.bool)
        constraints[:, :, 0] = True
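        # position 0 (the root symbol) is marked as constrained from the start, so it is never chosen as a new dependent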
        # [batch, 1]
        batch_index = torch.arange(batch, device=device,
                                   dtype=torch.int64).view(batch, 1)

        # compute lengths
        if mask is None:
            steps = output_enc.new_tensor([num_steps] * batch,
                                          dtype=torch.int64,
                                          device=device)
            mask_sent = torch.ones(batch,
                                   1,
                                   max_len,
                                   dtype=torch.bool,
                                   device=device)
        else:
            steps = (mask.sum(dim=1) * 2 - 1).long()
            mask_sent = mask.unsqueeze(1).bool()

        num_hyp = 1
        mask_hyp = torch.ones(batch, 1, device=device)
        hx = hn
        for t in range(num_steps):
            # [batch, num_hyp]
            curr_heads = stacked_heads[:, :, t]
            curr_gpars = heads.gather(dim=2,
                                      index=curr_heads.unsqueeze(2)).squeeze(2)
            curr_sibs = siblings[:, :, t] if self.sibling else None
            # [batch, num_hyp, enc_dim]
            src_encoding = output_enc.gather(
                dim=1,
                index=curr_heads.unsqueeze(2).expand(batch, num_hyp, enc_dim))

            if self.sibling:
                mask_sib = curr_sibs.gt(0).float().unsqueeze(2)
                output_enc_sibling = output_enc.gather(
                    dim=1,
                    index=curr_sibs.unsqueeze(2).expand(
                        batch, num_hyp, enc_dim)) * mask_sib
                src_encoding = src_encoding + output_enc_sibling

            if self.grandPar:
                output_enc_gpar = output_enc.gather(
                    dim=1,
                    index=curr_gpars.unsqueeze(2).expand(
                        batch, num_hyp, enc_dim))
                src_encoding = src_encoding + output_enc_gpar

            # transform to decoder input
            # [batch, num_hyp, dec_dim]
            src_encoding = self.activation(self.src_dense(src_encoding))

            # output [batch * num_hyp, dec_dim]
            # hx [decoder_layer, batch * num_hyp, dec_dim]
            output_dec, hx = self.decoder.step(src_encoding.view(
                batch * num_hyp, -1),
                                               hx=hx)
            dec_dim = output_dec.size(1)
            # [batch, num_hyp, dec_dim]
            output_dec = output_dec.view(batch, num_hyp, dec_dim)

            # [batch, num_hyp, arc_space]
            arc_h = self.activation(self.arc_h(output_dec))
            # [batch, num_hyp, type_space]
            type_h = self.activation(self.type_h(output_dec))
            # [batch, num_hyp, length]
            out_arc = self.biaffine(arc_h,
                                    arc_c,
                                    mask_query=mask_hyp,
                                    mask_key=mask)
            # mask invalid position to -inf for log_softmax
            if mask is not None:
                minus_mask_enc = mask.eq(0).unsqueeze(1)
                out_arc.masked_fill_(minus_mask_enc, float('-inf'))

            # [batch]
            mask_last = steps.le(t + 1)
            mask_stop = steps.le(t)
            minus_mask_hyp = mask_hyp.eq(0).unsqueeze(2)
            # [batch, num_hyp, length]
            hyp_scores = F.log_softmax(out_arc, dim=2).masked_fill_(
                mask_stop.view(batch, 1, 1) + minus_mask_hyp, 0)
            # [batch, num_hyp, length]
            hypothesis_scores = hypothesis_scores.unsqueeze(2) + hyp_scores

            # [batch, num_hyp, length]
            mask_leaf = curr_heads.unsqueeze(2).eq(
                children[:, :num_hyp]) * mask_sent
            mask_non_leaf = (~mask_leaf) * mask_sent

            # apply constrains to select valid hyps
            # [batch, num_hyp, length]
            mask_leaf = mask_leaf * (mask_last.unsqueeze(1) +
                                     curr_heads.ne(0)).unsqueeze(2)
            mask_non_leaf = mask_non_leaf * (~constraints)

            hypothesis_scores.masked_fill_(~(mask_non_leaf + mask_leaf),
                                           float('-inf'))
            # [batch, num_hyp * length]
            hypothesis_scores, hyp_index = torch.sort(hypothesis_scores.view(
                batch, -1),
                                                      dim=1,
                                                      descending=True)

            # [batch]
            prev_num_hyp = num_hyp
            num_hyps = (mask_leaf + mask_non_leaf).long().view(batch,
                                                               -1).sum(dim=1)
            num_hyp = num_hyps.max().clamp(max=beam).item()
            # [batch, num_hyp]
            hyps = torch.arange(num_hyp, device=device,
                                dtype=torch.int64).view(1, num_hyp)
            mask_hyp = hyps.lt(num_hyps.unsqueeze(1)).float()

            # [batch, num_hyp]
            hypothesis_scores = hypothesis_scores[:, :num_hyp]
            hyp_index = hyp_index[:, :num_hyp]
            base_index = hyp_index // max_len
            child_index = hyp_index % max_len
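            # the flat index over (hypothesis, word) is split into a beam index and the word being pointed to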

            # [batch, num_hyp]
            hyp_heads = curr_heads.gather(dim=1, index=base_index)
            hyp_gpars = curr_gpars.gather(dim=1, index=base_index)

            # [batch, num_hyp, length]
            base_index_expand = base_index.unsqueeze(2).expand(
                batch, num_hyp, max_len)
            constraints = constraints.gather(dim=1, index=base_index_expand)
            constraints.scatter_(2, child_index.unsqueeze(2), True)

            # [batch, num_hyp]
            mask_leaf = hyp_heads.eq(child_index)
            # [batch, num_hyp, length]
            heads = heads.gather(dim=1, index=base_index_expand)
            heads.scatter_(
                2, child_index.unsqueeze(2),
                torch.where(mask_leaf, hyp_gpars, hyp_heads).unsqueeze(2))
            types = types.gather(dim=1, index=base_index_expand)
            # [batch, num_hyp]
            org_types = types.gather(dim=2,
                                     index=child_index.unsqueeze(2)).squeeze(2)

            # [batch, num_hyp, num_steps]
            base_index_expand = base_index.unsqueeze(2).expand(
                batch, num_hyp, num_steps + 1)
            stacked_heads = stacked_heads.gather(dim=1,
                                                 index=base_index_expand)
            stacked_heads[:, :, t + 1] = torch.where(mask_leaf, hyp_gpars,
                                                     child_index)
            if self.sibling:
                siblings = siblings.gather(dim=1, index=base_index_expand)
                siblings[:, :,
                         t + 1] = torch.where(mask_leaf, child_index,
                                              torch.zeros_like(child_index))

            # [batch, num_hyp, type_space]
            base_index_expand = base_index.unsqueeze(2).expand(
                batch, num_hyp, type_space)
            child_index_expand = child_index.unsqueeze(2).expand(
                batch, num_hyp, type_space)
            # [batch, num_hyp, num_labels]
            out_type = self.bilinear(
                type_h.gather(dim=1, index=base_index_expand),
                type_c.gather(dim=1, index=child_index_expand))
            hyp_type_scores = F.log_softmax(out_type, dim=2)
            # compute the prediction of types [batch, num_hyp]
            hyp_type_scores, hyp_types = hyp_type_scores.max(dim=2)
            hypothesis_scores = hypothesis_scores + hyp_type_scores.masked_fill_(
                mask_stop.view(batch, 1), 0)
            types.scatter_(
                2, child_index.unsqueeze(2),
                torch.where(mask_leaf, org_types, hyp_types).unsqueeze(2))

            # hx [decoder_layer, batch * num_hyp, dec_dim]
            # hack to handle LSTM
            hx_index = (base_index + batch_index * prev_num_hyp).view(batch *
                                                                      num_hyp)
            if isinstance(hx, tuple):
                hx, cx = hx
                hx = hx[:, hx_index]
                cx = cx[:, hx_index]
                hx = (hx, cx)
            else:
                hx = hx[:, hx_index]

        heads = heads[:, 0].cpu().numpy()
        types = types[:, 0].cpu().numpy()
        return heads, types
Example #5
    def decode(self,
               input_word,
               input_char,
               input_bert,
               input_pos,
               mask=None,
               beam=1,
               leading_symbolic=0):
        def creates_cycle(index, arcs):
            head = arcs[index]
            if head == 0: return False
            steps_left = len(arcs) + 1  # bound the walk; avoids shadowing the built-in iter()
            elto = arcs[head]
            while steps_left > 0:
                if elto == 0: return False
                if elto == index:
                    return True
                elto = arcs[elto]
                steps_left -= 1
            return False

        def hasCycles(A, head, dep):
            if head == dep: return True
            aux = set(A)
            aux.add((head, dep))
            if count_cycles(aux) != 0: return True
            return False

        def count_cycles(A):
            d = {}
            for a, b in A:
                if a not in d:
                    d[a] = [b]
                else:
                    d[a].append(b)
            return sum([1 for e in tarjan(d) if len(e) > 1])

        def is_nonproj(A, head_node, node):
            left_node = int(node)
            right_node = int(head_node)
            if int(node) > int(head_node):
                left_node = int(head_node)
                right_node = int(node)
            for head, index in A:
                left = int(index)
                right = int(head)
                if int(index) > int(head):
                    left = int(head)
                    right = int(index)
                if (left < left_node and left_node < right
                        and right < right_node) or (left_node < left
                                                    and left < right_node
                                                    and right_node < right):
                    return True
            return False
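        # creates_cycle is only used by the commented-out cycle-removal block at the end of this method;
        # hasCycles and is_nonproj are not referenced in the active decoding path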

        debug = False

        # reset noise for decoder
        self.decoder.reset_noise(0)

        # output_enc [batch, length, model_dim]
        # arc_c [batch, length, arc_space]
        # type_c [batch, length, type_space]
        # hn [num_direction, batch, hidden_size]
        output_enc, hn = self._get_encoder_output(input_word,
                                                  input_char,
                                                  input_bert,
                                                  input_pos,
                                                  mask=mask)
        enc_dim = output_enc.size(2)
        device = output_enc.device
        # output size [batch, length_encoder, arc_space]
        arc_c = self.activation(self.arc_c(output_enc))
        # output size [batch, length_encoder, type_space]
        type_c = self.activation(self.type_c(output_enc))
        type_space = type_c.size(2)
        # [decoder_layers, batch, hidden_size]
        hn = self._transform_decoder_init_state(hn)
        batch, max_len, _ = output_enc.size()

        heads = torch.zeros(batch,
                            1,
                            max_len,
                            device=device,
                            dtype=torch.int64)
        types = torch.zeros(batch,
                            1,
                            max_len,
                            device=device,
                            dtype=torch.int64)

        num_steps = max_len - 1
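        # max_len - 1 decoding steps: one attachment decision per non-root word, strictly left to right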

        stacked_heads = torch.ones(
            batch, 1, num_steps + 1, device=device,
            dtype=torch.int64)  #Starts in position 1, instead of 0

        hypothesis_scores = output_enc.new_zeros((batch, 1))

        # [batch, beam, length]
        children = torch.arange(max_len, device=device,
                                dtype=torch.int64).view(1, 1, max_len).expand(
                                    batch, beam, max_len)
        constraints = torch.zeros(batch,
                                  1,
                                  max_len,
                                  device=device,
                                  dtype=torch.bool)

        constraints[:, :, 0] = True

        # [batch, 1]
        batch_index = torch.arange(batch, device=device,
                                   dtype=torch.int64).view(batch, 1)

        # compute lengths
        if mask is None:
            steps = output_enc.new_tensor([num_steps] * batch,
                                          dtype=torch.int64,
                                          device=device)
            mask_sent = torch.ones(batch,
                                   1,
                                   max_len,
                                   dtype=torch.bool,
                                   device=device)
        else:

            steps = (mask.sum(dim=1) - 1).long()
            mask_sent = mask.unsqueeze(1).bool()

        num_hyp = 1
        mask_hyp = torch.ones(batch, 1, device=device)
        hx = hn
        for t in range(num_steps):
            if debug: print(t, '---------------------')
            # [batch, num_hyp]
            curr_heads = stacked_heads[:, :, t]

            if debug: print('CURHEAD', curr_heads)  #, stacked_heads)

            #NOT USED
            curr_gpars = heads.gather(dim=2,
                                      index=curr_heads.unsqueeze(2)).squeeze(2)
            #curr_sibs = siblings[:, :, t] if self.sibling else None

            # [batch, num_hyp, enc_dim]
            src_encoding = output_enc.gather(
                dim=1,
                index=curr_heads.unsqueeze(2).expand(batch, num_hyp, enc_dim))
            """
            NOT USED
            if self.sibling:
                mask_sib = curr_sibs.gt(0).float().unsqueeze(2)
                output_enc_sibling = output_enc.gather(dim=1, index=curr_sibs.unsqueeze(2).expand(batch, num_hyp, enc_dim)) * mask_sib
                src_encoding = src_encoding + output_enc_sibling

            if self.grandPar:
                output_enc_gpar = output_enc.gather(dim=1, index=curr_gpars.unsqueeze(2).expand(batch, num_hyp, enc_dim))
                src_encoding = src_encoding + output_enc_gpar
            """

            # transform to decoder input
            # [batch, num_hyp, dec_dim]
            src_encoding = self.activation(self.src_dense(src_encoding))

            # output [batch * num_hyp, dec_dim]
            # hx [decoder_layer, batch * num_hyp, dec_dim]
            output_dec, hx = self.decoder.step(src_encoding.view(
                batch * num_hyp, -1),
                                               hx=hx)
            dec_dim = output_dec.size(1)
            # [batch, num_hyp, dec_dim]
            output_dec = output_dec.view(batch, num_hyp, dec_dim)

            # [batch, num_hyp, arc_space]
            arc_h = self.activation(self.arc_h(output_dec))
            # [batch, num_hyp, type_space]
            type_h = self.activation(self.type_h(output_dec))
            # [batch, num_hyp, length]
            out_arc = self.biaffine(arc_h,
                                    arc_c,
                                    mask_query=mask_hyp,
                                    mask_key=mask)
            # mask invalid position to -inf for log_softmax
            if mask is not None:
                minus_mask_enc = mask.eq(0).unsqueeze(1)
                out_arc.masked_fill_(minus_mask_enc, float('-inf'))

            # [batch]

            mask_last = steps.le(t + 1)
            mask_stop = steps.le(t)

            minus_mask_hyp = mask_hyp.eq(0).unsqueeze(2)
            # [batch, num_hyp, length]
            hyp_scores = F.log_softmax(out_arc, dim=2).masked_fill_(
                mask_stop.view(batch, 1, 1) + minus_mask_hyp, 0)
            # [batch, num_hyp, length]
            hypothesis_scores = hypothesis_scores.unsqueeze(2) + hyp_scores

            #UNLABELLED PARSING
            # [batch, num_hyp, length]

            mask_leaf = curr_heads.unsqueeze(2).eq(
                children[:, :num_hyp]) * mask_sent

            # POINTER "SORTER" REMOVED: the current focus word is not excluded from the candidate positions to point to
            mask_non_leaf = (~mask_leaf) * mask_sent

            constraints = constraints * (~mask_stop).unsqueeze(1).unsqueeze(2)

            mask_non_leaf = (mask_non_leaf + mask_leaf) * (~constraints)

            hypothesis_scores.masked_fill_(~(mask_non_leaf), float('-inf'))
            # [batch, num_hyp * length]
            hypothesis_scores, hyp_index = torch.sort(hypothesis_scores.view(
                batch, -1),
                                                      dim=1,
                                                      descending=True)

            # [batch]
            prev_num_hyp = num_hyp

            num_hyps = (mask_non_leaf).long().view(batch, -1).sum(dim=1)
            num_hyp = num_hyps.max().clamp(max=beam).item()

            if debug: print(num_hyp, num_hyps, beam)
            # [batch, num_hyp]
            hyps = torch.arange(num_hyp, device=device,
                                dtype=torch.int64).view(1, num_hyp)
            mask_hyp = hyps.lt(num_hyps.unsqueeze(1)).float()

            # [batch, num_hyp]
            hypothesis_scores = hypothesis_scores[:, :num_hyp]
            hyp_index = hyp_index[:, :num_hyp]
            base_index = torch.floor_divide(hyp_index, max_len)

            child_index = hyp_index % max_len

            # [batch, num_hyp]
            hyp_heads = curr_heads.gather(dim=1, index=base_index)
            hyp_gpars = curr_gpars.gather(dim=1, index=base_index)

            # [batch, num_hyp, length]
            base_index_expand = base_index.unsqueeze(2).expand(
                batch, num_hyp, max_len)

            constraints = constraints.gather(dim=1, index=base_index_expand)
            constraints.scatter_(2, child_index.unsqueeze(2), True)

            heads = heads.gather(dim=1, index=base_index_expand)

            heads.scatter_(2, hyp_heads.unsqueeze(2), child_index.unsqueeze(
                2))  # equivalent to heads[head] = child_index

            types = types.gather(dim=1, index=base_index_expand)
            # [batch, num_hyp]
            org_types = types.gather(dim=2,
                                     index=child_index.unsqueeze(2)).squeeze(2)

            # [batch, num_hyp, num_steps]
            base_index_expand = base_index.unsqueeze(2).expand(
                batch, num_hyp, num_steps + 1)

            stacked_heads = stacked_heads.gather(dim=1,
                                                 index=base_index_expand)

            stacked_heads[:, :, t + 1] = stacked_heads[:, :, t] + 1
            """
            if self.sibling:
                siblings = siblings.gather(dim=1, index=base_index_expand)
                siblings[:, :, t + 1] = torch.where(mask_leaf, child_index, torch.zeros_like(child_index))
            """

            #LABELLER
            # [batch, num_hyp, type_space]
            base_index_expand = base_index.unsqueeze(2).expand(
                batch, num_hyp, type_space)
            child_index_expand = child_index.unsqueeze(2).expand(
                batch, num_hyp, type_space)
            # [batch, num_hyp, num_labels]
            out_type = self.bilinear(
                type_h.gather(dim=1, index=base_index_expand),
                type_c.gather(dim=1, index=child_index_expand))
            hyp_type_scores = F.log_softmax(out_type, dim=2)
            # compute the prediction of types [batch, num_hyp]
            hyp_type_scores, hyp_types = hyp_type_scores.max(dim=2)
            hypothesis_scores = hypothesis_scores + hyp_type_scores.masked_fill_(
                mask_stop.view(batch, 1), 0)

            types.scatter_(2, hyp_heads.unsqueeze(2), hyp_types.unsqueeze(2))

            # hx [decoder_layer, batch * num_hyp, dec_dim]
            # hack to handle LSTM
            hx_index = (base_index + batch_index * prev_num_hyp).view(batch *
                                                                      num_hyp)
            if isinstance(hx, tuple):
                hx, cx = hx
                hx = hx[:, hx_index]
                cx = cx[:, hx_index]
                hx = (hx, cx)
            else:
                hx = hx[:, hx_index]

        heads = heads[:, 0].cpu().numpy()
        types = types[:, 0].cpu().numpy()
        """
        #REMOVE CYCLES
        if self.remove_cycles:
            for head in heads:
                for elto in reversed(range(len(head))):
                    if creates_cycle(elto, head):
                        #print('CICLO', elto)
                        #for i,e in enumerate(head):
                        #    print(e,'->',i)
                        head[elto]=0
        """

        return heads, types