Code Example #1
class SAJEM:
    '''
    Self-Attention based Joint Embedding Model.
    Consists of two branches that encode images and text into a common embedding space.
    '''
    def __init__(self,
                 image_encoder,
                 text_encoder,
                 image_mha,
                 bert_model,
                 optimizer='adam',
                 lr=1e-3,
                 l2_regularization=1e-2,
                 margin_loss=1e-2,
                 max_violation=True,
                 cost_style='mean',
                 use_lr_scheduler=False,
                 grad_clip=0,
                 num_training_steps=30000,
                 device='cuda'):
        self.image_mha = image_mha
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.bert_model = bert_model
        self.device = device

        self.use_lr_scheduler = use_lr_scheduler
        self.params = list(self.image_mha.parameters())
        self.params += list(self.text_encoder.parameters())
        self.params += list(self.image_encoder.parameters())
        self.params += list(self.bert_model.parameters())
        self.grad_clip = grad_clip
        self.frozen = False
        # Note: the per-group learning rates below are hard-coded; the `lr` argument is not applied here.
        if optimizer == 'adamW':
            self.optimizer = AdamW([
                {'params': list(self.bert_model.parameters()), 'lr': 3e-5},
                {'params': list(self.image_encoder.parameters()) +
                           list(self.text_encoder.parameters()) +
                           list(self.image_mha.parameters()),
                 'lr': 1e-4},
            ])
        elif optimizer == 'adam':
            self.optimizer = torch.optim.Adam([
                {'params': list(self.bert_model.parameters()), 'lr': 3e-5},
                {'params': list(self.image_encoder.parameters()) +
                           list(self.text_encoder.parameters()) +
                           list(self.image_mha.parameters()),
                 'lr': 1e-4},
            ])
            # Alternative grouping without the image encoder:
            # self.optimizer = torch.optim.Adam([{'params': list(self.bert_model.parameters()), 'lr': 3e-5},
            #                                    {'params': list(self.text_encoder.parameters()) + list(self.image_mha.parameters()), 'lr': 1e-4}])
        else:
            raise ValueError(f'Unsupported optimizer: {optimizer}')

        if self.use_lr_scheduler:
            self.lr_scheduler = get_linear_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=100,
                num_training_steps=num_training_steps)
        self.lr_scheduler_0 = get_constant_schedule(self.optimizer)
        # loss
        self.mrl_loss = MarginRankingLoss(margin=margin_loss,
                                          max_violation=max_violation,
                                          cost_style=cost_style,
                                          direction='bidir')

    def forward(self, image_feature, image_attention_mask, input_ids,
                attention_mask, epoch):
        if epoch > 1 and self.frozen:
            self.frozen = False
            del self.lr_scheduler_0
            torch.cuda.empty_cache()

        image_feature = l2norm(image_feature).detach()
        final_image_features = l2norm(
            self.image_mha(image_feature, image_attention_mask))
        text_feature = self.bert_model(input_ids,
                                       attention_mask=attention_mask)
        text_feature = l2norm(text_feature)
        if epoch == 1:
            # During the first epoch the text branch is trained on detached (frozen) BERT features.
            text_feature = text_feature.detach()
            self.frozen = True
        image_to_common = self.image_encoder(final_image_features)
        # image_to_common = final_image_features
        text_to_common = self.text_encoder(text_feature)
        return image_to_common, text_to_common

    def save_network(self, folder):
        torch.save(self.image_mha.state_dict(),
                   os.path.join(folder, 'image_mha.pth'))
        torch.save(self.text_encoder.state_dict(),
                   os.path.join(folder, 'text_encoder.pth'))
        torch.save(self.image_encoder.state_dict(),
                   os.path.join(folder, 'image_encoder.pth'))
        torch.save(self.bert_model.state_dict(),
                   os.path.join(folder, 'bert_model.pth'))
        torch.save(self.optimizer.state_dict(),
                   os.path.join(folder, 'optimizer.pth'))
        if self.use_lr_scheduler:
            torch.save(self.lr_scheduler.state_dict(),
                       os.path.join(folder, 'scheduler.pth'))

    def switch_to_train(self):
        self.image_mha.train()
        self.text_encoder.train()
        self.image_encoder.train()
        self.bert_model.train()

    def switch_to_eval(self):
        self.image_mha.eval()
        self.text_encoder.eval()
        self.image_encoder.eval()
        self.bert_model.eval()

    def train(self, image_features, image_attention_mask, input_ids,
              attention_mask, epoch):
        self.switch_to_train()
        image_to_common, text_to_common = self.forward(image_features,
                                                       image_attention_mask,
                                                       input_ids,
                                                       attention_mask, epoch)
        self.optimizer.zero_grad()

        # Compute loss
        loss = self.mrl_loss(text_to_common, image_to_common)
        loss.backward()
        if self.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(self.params, self.grad_clip)

        self.optimizer.step()
        return loss.item()

    def step_scheduler(self):
        if self.use_lr_scheduler and not self.frozen:
            self.lr_scheduler.step()
        else:
            self.lr_scheduler_0.step()

    def evaluate(self, val_image_dataloader, val_text_dataloader, k):
        self.switch_to_eval()
        # Load image features
        with torch.no_grad():
            image_features = []
            image_ids = []
            for ids, features, image_attention_mask in val_image_dataloader:
                image_ids.append(torch.stack(ids))
                features = torch.stack(features).to(self.device)
                image_attention_mask = torch.stack(image_attention_mask).to(
                    self.device)
                features = l2norm(features).detach()
                mha_features = l2norm(
                    self.image_mha(features, image_attention_mask))
                image_features.append(self.image_encoder(mha_features))
                # image_features.append(mha_features)
            image_features = torch.cat(image_features, dim=0)
            image_ids = torch.cat(image_ids, dim=0).to(self.device)
            # Evaluate
            recall = 0
            total_query = 0
            pbar = tqdm(enumerate(val_text_dataloader),
                        total=len(val_text_dataloader),
                        leave=False,
                        position=0,
                        file=sys.stdout)
            for i, (image_files, input_ids, attention_mask) in pbar:
                input_ids = input_ids.to(self.device)
                attention_mask = attention_mask.to(self.device)
                text_features = self.bert_model(input_ids,
                                                attention_mask=attention_mask)
                text_features = l2norm(text_features)
                text_features = self.text_encoder(text_features)
                image_files = torch.tensor(
                    list(
                        map(lambda x: int(re.findall(r'\d{12}', x)[0]),
                            image_files))).to(self.device)
                top_k = get_top_k_eval(text_features, image_features, k)
                for idx, indices in enumerate(top_k):
                    total_query += 1
                    true_image_id = image_files[idx]
                    top_k_images = torch.gather(image_ids, 0, indices)
                    if (top_k_images == true_image_id).nonzero().numel() > 0:
                        recall += 1
            recall = recall / total_query
            return recall
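
The `MarginRankingLoss` used above is defined elsewhere in the project. As a point of reference, a minimal sketch of a bidirectional, max-violation triplet ranking loss in the VSE++ style (an assumption about what that class computes, based on its `margin`, `max_violation`, and `direction` arguments) could look like this:

import torch
import torch.nn as nn

class BiMarginRankingLoss(nn.Module):
    """Hypothetical stand-in for the MarginRankingLoss used by SAJEM."""
    def __init__(self, margin=0.2, max_violation=True):
        super().__init__()
        self.margin = margin
        self.max_violation = max_violation

    def forward(self, txt, img):
        # txt, img: (batch, dim) L2-normalized embeddings; the diagonal holds matching pairs.
        scores = txt @ img.t()
        diag = scores.diag().view(-1, 1)
        cost_txt = (self.margin + scores - diag).clamp(min=0)       # negatives: other images per caption
        cost_img = (self.margin + scores - diag.t()).clamp(min=0)   # negatives: other captions per image
        mask = torch.eye(scores.size(0), dtype=torch.bool, device=scores.device)
        cost_txt = cost_txt.masked_fill(mask, 0)
        cost_img = cost_img.masked_fill(mask, 0)
        if self.max_violation:
            cost_txt = cost_txt.max(1)[0]   # hardest negative image for each caption
            cost_img = cost_img.max(0)[0]   # hardest negative caption for each image
        return cost_txt.mean() + cost_img.mean()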
Code Example #2
class Model:
    def __init__(self, epochs=50, fc=FC_62):
        self.num_epochs = epochs
        self.epochs = 0
        self.loss = 0
        self.model = CNN(fc)
        self.model.to(device)
        self.optimizer = AdamW(params=self.model.parameters())
        self.loss_fn = nn.CrossEntropyLoss()
        self.transform2 = [
            # transforms.CenterCrop(256),
            # Crop(28),
            # transforms.Resize(256),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ]

    def load(self, path):
        checkpoint = torch.load(path)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(device)
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.epochs = checkpoint['epoch']
        self.loss = checkpoint['loss']
        print(f'\nmodel loaded from path : {path}')

    def save(self, epoch, model, optimizer, loss, path):
        save_path = os.path.join(root_dir, 'models')
        os.makedirs(save_path, exist_ok=True)
        path = os.path.join(save_path, path)
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
            }, path)
        print(f'\nsaved model to path : {path}')

    def test(self, testloader, progress, type='validation'):
        print(f'Starting testing on {type} dataset')
        print('-------------------------------')
        correct, total = 0, 0
        self.model.eval()  # switch to eval mode (affects dropout / batch norm)
        with torch.no_grad():
            for i, data in enumerate(testloader, 0):
                inputs, targets = data
                inputs, targets = inputs.to(device), targets.to(device)

                outputs = self.model(inputs)

                _, predicted = torch.max(outputs.data, 1)
                # print(predicted)
                # print(targets)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
                progress.update(self.batch_size)

            print(
                f'\nAccuracy on {type} dataset : {correct} / {total} = {100.0 * correct / total}'
            )
            print('--------------------------------')

            return 100.0 * correct / total

    def train(self, trainloader, epoch, progress):
        print(f'\nStarting epoch {epoch+1}')
        self.model.train()  # restore training mode after evaluation passes
        current_loss = 0.0

        for i, data in enumerate(trainloader, 0):
            inputs, targets = data
            inputs, targets = inputs.to(device), targets.to(device)
            self.optimizer.zero_grad()

            outputs = self.model(inputs)
            loss = self.loss_fn(outputs, targets)
            loss.backward()

            self.optimizer.step()
            current_loss += loss.item()
            progress.update(self.batch_size)

        print(f'\nloss at epoch {epoch + 1} : {current_loss}')
        return current_loss

    def train_validate(self,
                       name,
                       mnist=False,
                       batch_size=64,
                       validation_split=0,
                       save_name=None):
        self.batch_size = batch_size

        if save_name is None:
            save_name = name

        progress = None
        np.random.seed(42)

        epochs_plot = []
        accuracy_plot = []
        loss_plot = []

        for epoch in range(0, self.num_epochs):
            if mnist:
                self.transform1 = [
                    transforms.RandomRotation(degrees=10),
                ]
                train_data = torchvision.datasets.MNIST(
                    'mnist',
                    download=True,
                    transform=transforms.Compose(self.transform1 +
                                                 self.transform2))
                trainloader = torch.utils.data.DataLoader(
                    train_data, batch_size=self.batch_size, num_workers=2)
                dataset_size = len(trainloader.dataset)
            else:
                data = get_data_set(name)
                dataset_size = len(data)
                ids = list(range(dataset_size))
                split = int(np.floor(validation_split * dataset_size))
                np.random.shuffle(ids)
                train_ids, val_ids = ids[split:], ids[:split]

                train_subsampler = torch.utils.data.SubsetRandomSampler(
                    train_ids)
                test_subsampler = torch.utils.data.SubsetRandomSampler(val_ids)

                trainloader = torch.utils.data.DataLoader(
                    data,
                    batch_size=batch_size,
                    sampler=train_subsampler,
                    num_workers=2)
                testloader = torch.utils.data.DataLoader(
                    data,
                    batch_size=batch_size,
                    sampler=test_subsampler,
                    num_workers=2)
            if progress is None:
                progress = tqdm.tqdm(total=(2 + validation_split) *
                                     dataset_size * self.num_epochs,
                                     position=0,
                                     leave=True)
            current_loss = self.train(trainloader, epoch, progress)
            accuracy = self.test(trainloader, progress, 'train')
            if validation_split:
                self.test(testloader, progress, 'validation')
            epochs_plot.append(epoch)
            accuracy_plot.append(accuracy)
            loss_plot.append(current_loss)
            self.save(epoch, self.model, self.optimizer, current_loss,
                      f'{save_name}-{epoch}.pth')
        return epochs_plot, accuracy_plot, loss_plot

    def test_mnist(self):
        test_data = torchvision.datasets.MNIST('mnist',
                                               False,
                                               download=True,
                                               transform=transforms.Compose(
                                                   self.transform2))
        testloader = torch.utils.data.DataLoader(test_data,
                                                 batch_size=self.batch_size,
                                                 num_workers=2)
        progress = tqdm.tqdm(total=len(testloader.dataset),
                             position=0,
                             leave=True)
        self.test(testloader, progress, 'test')
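
A hypothetical usage sketch for the class above; `CNN`, `FC_62`, `device`, and `root_dir` are assumed to be defined elsewhere in the project:

# Hypothetical usage; CNN, FC_62, device and root_dir come from the surrounding project.
model = Model(epochs=10, fc=FC_62)
epochs_plot, accuracy_plot, loss_plot = model.train_validate('digits',
                                                             mnist=True,
                                                             batch_size=64)
model.test_mnist()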
Code Example #3
class Model:
    def __init__(self, local_rank=-1):
        self.flownet = IFNet()
        self.contextnet = ContextNet()
        self.fusionnet = FusionNet()
        self.device()
        self.optimG = AdamW(itertools.chain(self.flownet.parameters(),
                                            self.contextnet.parameters(),
                                            self.fusionnet.parameters()),
                            lr=1e-6,
                            weight_decay=1e-5)
        self.schedulerG = optim.lr_scheduler.CyclicLR(self.optimG,
                                                      base_lr=1e-6,
                                                      max_lr=1e-3,
                                                      step_size_up=8000,
                                                      cycle_momentum=False)
        self.epe = EPE()
        self.ter = Ternary()
        self.sobel = SOBEL()
        if local_rank != -1:
            self.flownet = DDP(self.flownet,
                               device_ids=[local_rank],
                               output_device=local_rank)
            self.contextnet = DDP(self.contextnet,
                                  device_ids=[local_rank],
                                  output_device=local_rank)
            self.fusionnet = DDP(self.fusionnet,
                                 device_ids=[local_rank],
                                 output_device=local_rank)

    def train(self):
        self.flownet.train()
        self.contextnet.train()
        self.fusionnet.train()

    def eval(self):
        self.flownet.eval()
        self.contextnet.eval()
        self.fusionnet.eval()

    def device(self):
        self.flownet.to(device)
        self.contextnet.to(device)
        self.fusionnet.to(device)

    def load_model(self, path, rank):
        def convert(param):
            if rank == -1:
                return {
                    k.replace("module.", ""): v
                    for k, v in param.items() if "module." in k
                }
            else:
                return param

        if rank <= 0:
            self.flownet.load_state_dict(
                convert(
                    torch.load('{}/flownet.pkl'.format(path),
                               map_location=device)))
            self.contextnet.load_state_dict(
                convert(
                    torch.load('{}/contextnet.pkl'.format(path),
                               map_location=device)))
            self.fusionnet.load_state_dict(
                convert(
                    torch.load('{}/unet.pkl'.format(path),
                               map_location=device)))

    def save_model(self, path, rank):
        if rank == 0:
            torch.save(self.flownet.state_dict(),
                       '{}/flownet.pkl'.format(path))
            torch.save(self.contextnet.state_dict(),
                       '{}/contextnet.pkl'.format(path))
            torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))

    def predict(self, imgs, flow, training=True, flow_gt=None):
        img0 = imgs[:, :3]
        img1 = imgs[:, 3:]
        c0 = self.contextnet(img0, flow)
        c1 = self.contextnet(img1, -flow)
        flow = F.interpolate(
            flow, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
        refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(
            img0, img1, flow, c0, c1, flow_gt)
        res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
        mask = torch.sigmoid(refine_output[:, 3:4])
        merged_img = warped_img0 * mask + warped_img1 * (1 - mask)
        pred = merged_img + res
        pred = torch.clamp(pred, 0, 1)
        if training:
            return pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
        else:
            return pred

    def inference(self, img0, img1, scale=1.0):
        imgs = torch.cat((img0, img1), 1)
        flow, _ = self.flownet(imgs, scale)
        return self.predict(imgs, flow, training=False)

    def update(self,
               imgs,
               gt,
               learning_rate=0,
               mul=1,
               training=True,
               flow_gt=None):
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        if training:
            self.train()
        else:
            self.eval()
        flow, flow_list = self.flownet(imgs)
        pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.predict(
            imgs, flow, flow_gt=flow_gt)
        loss_ter = self.ter(pred, gt).mean()
        if training:
            with torch.no_grad():
                loss_flow = torch.abs(warped_img0_gt - gt).mean()
                loss_mask = torch.abs(merged_img - gt).sum(
                    1, True).float().detach()
                loss_mask = F.interpolate(loss_mask,
                                          scale_factor=0.5,
                                          mode="bilinear",
                                          align_corners=False).detach()
                flow_gt = (F.interpolate(flow_gt,
                                         scale_factor=0.5,
                                         mode="bilinear",
                                         align_corners=False) * 0.5).detach()
            loss_cons = 0
            for i in range(3):
                loss_cons += self.epe(flow_list[i], flow_gt[:, :2], 1)
                loss_cons += self.epe(-flow_list[i], flow_gt[:, 2:4], 1)
            loss_cons = loss_cons.mean() * 0.01
        else:
            loss_cons = torch.tensor([0])
            loss_flow = torch.abs(warped_img0 - gt).mean()
            loss_mask = 1
        loss_l1 = (((pred - gt)**2 + 1e-6)**0.5).mean()
        if training:
            self.optimG.zero_grad()
            loss_G = loss_l1 + loss_cons + loss_ter
            loss_G.backward()
            self.optimG.step()
        return pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, loss_mask
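
A hypothetical inference sketch for the interpolation model above; the `device` global and the checkpoint directory name are assumptions, not taken from the original snippet:

# Hypothetical usage; assumes `device` and the network modules above are importable.
model = Model()
model.load_model('train_log', rank=-1)   # checkpoint directory name is an assumption
model.eval()
img0 = torch.rand(1, 3, 256, 256, device=device)
img1 = torch.rand(1, 3, 256, 256, device=device)
mid_frame = model.inference(img0, img1)  # interpolated frame, clamped to [0, 1]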
Code Example #4
class StudentTrainerUltimate:
    def __init__(self, train_loader, val_loader, val_loader_unit_batch,
                 noise_size, student_hidden_size, student_num_layers, epochs,
                 start_lr, teacher_generator, teacher_mimic_layer,
                 student_mimic_layer):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.val_loader_unit_batch = val_loader_unit_batch

        self.student_num_layers = student_num_layers

        self.teacher_mimic_layer = teacher_mimic_layer
        self.student_mimic_layer = student_mimic_layer

        self.teacher_generator = teacher_generator
        self.teacher_generator.dnn[teacher_mimic_layer].register_forward_hook(
            save_output_of_layer)

        self.student_generator = StudentGenerator(noise_size,
                                                  student_hidden_size,
                                                  student_num_layers,
                                                  128).to(device)

        self.student_generator.dnn[student_mimic_layer].register_forward_hook(
            save_output_of_layer)

        self.start_lr = start_lr

        print(self.teacher_generator)
        print(self.student_generator)

        self.mimic_epochs = 40
        self.train_epochs = epochs

    def train(self,
              metric_every,
              savename,
              lam=1.0,
              sigma=0.001,
              val_size=None):
        if not val_size:
            val_size = len(self.val_loader)

        if os.path.isfile(f"results/{savename}.txt"):
            os.remove(f"results/{savename}.txt")

        self.optimizer = AdamW(self.student_generator.parameters(),
                               lr=self.start_lr)
        lr_track = np.logspace(0, -0.01, num=self.mimic_epochs)
        lr_lambda = lambda x: lr_track[x]
        self.scheduler = LambdaLR(self.optimizer, lr_lambda)
        for epoch in range(1, self.mimic_epochs + 1):
            print(f"epoch {epoch} (mimic)")

            avg_mimic_loss = 0

            for (real, noised, w) in tqdm(self.train_loader):
                # These forward passes populate the hooked layers' `.saved` outputs.
                teacher_generated = self.teacher_generator(noised).detach()
                student_generated = self.student_generator(noised)

                mimic = self.student_generator.mimic(
                    self.student_generator.dnn[self.student_mimic_layer].saved)

                mimic_loss = (
                    (mimic -
                     self.teacher_generator.dnn[self.teacher_mimic_layer].saved
                     )**2).mean()

                avg_mimic_loss += mimic_loss.item()

                loss = mimic_loss

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            if epoch != self.mimic_epochs:
                self.scheduler.step()
            avg_mimic_loss = avg_mimic_loss / len(self.train_loader)
            print(f"mimic loss: {avg_mimic_loss:.6f}")

        self.optimizer = AdamW(self.student_generator.parameters(),
                               lr=self.start_lr)
        lr_track = np.logspace(0, -0.01, num=self.train_epochs)
        lr_lambda = lambda x: lr_track[x]
        self.scheduler = LambdaLR(self.optimizer, lr_lambda)
        for epoch in range(1, self.train_epochs + 1):
            print(f"epoch {epoch} (train)")

            avg_main_loss = 0
            avg_relation_loss = 0

            for (real, noised, w) in tqdm(self.train_loader):
                teacher_generated = self.teacher_generator(noised).detach()
                student_generated = self.student_generator(noised)

                batch_size = noised.size(0)
                half_batch = batch_size // 2

                disturbance = torch.tensor(np.random.normal(
                    loc=0.0, scale=sigma, size=teacher_generated.size()),
                    dtype=torch.float32, device=device)

                teacher_generated = (1 + disturbance) * teacher_generated

                main_loss = ((teacher_generated - student_generated)**2).mean()

                # Relational loss

                s_up = student_generated[:half_batch]
                s_down = student_generated[half_batch:half_batch * 2]

                t_up = teacher_generated[:half_batch]
                t_down = teacher_generated[half_batch:half_batch * 2]

                t_distances = (((t_up - t_down)**2).sum(axis=1))**(1 / 2)
                s_distances = (((s_up - s_down)**2).sum(axis=1))**(1 / 2)

                mu = t_distances.mean()

                t_potentials = t_distances / mu
                s_potentials = s_distances / mu

                relation_loss = ((t_potentials - s_potentials)**2).mean()
                #

                avg_main_loss += main_loss.item()
                avg_relation_loss += relation_loss.item()

                loss = main_loss + lam * relation_loss

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            if epoch != self.train_epochs:
                self.scheduler.step()
            avg_main_loss = avg_main_loss / len(self.train_loader)
            avg_relation_loss = avg_relation_loss / len(self.train_loader)
            print(
                f"main loss: {avg_main_loss:.6f} relation loss: {avg_relation_loss:.6f}"
            )

            if epoch % metric_every == 0:
                self.validate(val_size, epoch, savename, avg_main_loss,
                              avg_relation_loss)

    def validate(self, val_size, epoch, savename, avg_main_loss,
                 avg_relation_loss):
        with torch.no_grad():
            n = 0
            start = time.time()
            for (_, noised, _) in self.val_loader_unit_batch:
                _ = self.teacher_generator(noised)
                n += 1
                if n == 1000:
                    break
            teacher_time = (time.time() - start) / 1000

            n = 0
            start = time.time()
            for (_, noised, _) in self.val_loader_unit_batch:
                _ = self.student_generator(noised)
                n += 1
                if n == 1000:
                    break
            student_time = (time.time() - start) / 1000

            teacher_ms = teacher_time * 1000
            student_ms = student_time * 1000

            print(
                f"avg teacher: {teacher_ms:.3f}ms, avg student: {student_ms:.3f}ms"
            )

            real_batches = []
            teacher_gen_batches = []
            student_gen_batches = []
            w_batches = []
            step = 0
            for (real, noised, w) in self.val_loader:
                step += 1
                if step > val_size:
                    break
                teacher_generated = self.teacher_generator(noised).detach()
                student_generated = self.student_generator(noised)

                real_batches.append(real.detach().cpu())
                teacher_gen_batches.append(teacher_generated.detach().cpu())
                student_gen_batches.append(student_generated.detach().cpu())
                w_batches.append(w.detach().cpu())

        info = {
            "teacher time, ms": teacher_ms,
            "student time, ms": student_ms,
            "main loss": avg_main_loss,
            "relation loss": avg_relation_loss,
            "epoch": epoch
        }

        plot_metrics(torch.cat(real_batches, dim=0),
                     torch.cat(teacher_gen_batches, dim=0),
                     torch.cat(student_gen_batches, dim=0),
                     torch.cat(w_batches, dim=0), info, savename)
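
The pairwise "relation" term in the training loop above can be isolated into a small self-contained function; this is a sketch of the same computation on plain tensors:

import torch

def relation_loss(student_out: torch.Tensor, teacher_out: torch.Tensor) -> torch.Tensor:
    """Match normalized pairwise distances between student and teacher outputs."""
    half = student_out.size(0) // 2
    s_dist = ((student_out[:half] - student_out[half:2 * half]) ** 2).sum(dim=1).sqrt()
    t_dist = ((teacher_out[:half] - teacher_out[half:2 * half]) ** 2).sum(dim=1).sqrt()
    mu = t_dist.mean()  # both distance sets are normalized by the teacher's mean distance
    return ((t_dist / mu - s_dist / mu) ** 2).mean()

# e.g. relation_loss(torch.randn(8, 128), torch.randn(8, 128))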
Code Example #5
def train(config, args, device, logger):
    # Dataset
    if config['model']['share_architecture'] == 'ocnli':
        train_dataset = OCNLIDataset(
            config['dataset'],
            split='trainval' if args.no_validate else 'train',
            overfit=args.overfit,
            tensor_type='np'
        )
        if not args.no_validate:
            val_dataset = OCNLIDataset(
                config['dataset'],
                split='val',
                overfit=args.overfit,
                tensor_type='np'
            )
        task_name2int = {'ocnli': TASK_NAME2INT['ocnli']}
    elif config['model']['share_architecture'] == 'ocemotion':
        train_dataset = OCEMOTIONDataset(
            config['dataset'],
            split='trainval' if args.no_validate else 'train',
            overfit=args.overfit,
            tensor_type='np'
        )
        if not args.no_validate:
            val_dataset = OCEMOTIONDataset(
                config['dataset'],
                split='val',
                overfit=args.overfit,
                tensor_type='np'
            )
        task_name2int = {'ocemotion': TASK_NAME2INT['ocemotion']}
    elif config['model']['share_architecture'] == 'tnews':
        train_dataset = TNEWSDataset(
            config['dataset'],
            split='trainval' if args.no_validate else 'train',
            overfit=args.overfit,
            tensor_type='np'
        )
        if not args.no_validate:
            val_dataset = TNEWSDataset(
                config['dataset'],
                split='val',
                overfit=args.overfit,
                tensor_type='np'
            )
        task_name2int = {'tnews': TASK_NAME2INT['tnews']}
    else:
        train_dataset = NLPCJointDataset(
            config['dataset'],
            split='trainval' if args.no_validate else 'train',
            overfit=args.overfit,
            tensor_type='np'
        )
        if not args.no_validate:
            val_dataset = NLPCJointDataset(
                config['dataset'],
                split='val',
                overfit=args.overfit,
                tensor_type='np'
            )
        task_name2int = {
            'ocnli': TASK_NAME2INT['ocnli'],
            'ocemotion': TASK_NAME2INT['ocemotion'],
            'tnews': TASK_NAME2INT['tnews']
        }

    logger.info(
        'Training set number of samples: {}'.format(len(train_dataset))
    )
    if not args.no_validate:
        logger.info(
            'Validation set number of samples: {}'.format(len(val_dataset))
        )

    assert(
        config['solver']['batch_size']
        % config['solver']['accumulation_steps'] == 0
    )
    actual_batch_size = (
        config['solver']['batch_size']
        // config['solver']['accumulation_steps']
    )
    logger.info('Actual batch size: {}'.format(actual_batch_size))
    logger.info(
        'Gradient accumulation steps: {}'
        .format(config['solver']['accumulation_steps'])
    )
    logger.info(
        'Effective batch size: {}'.format(config['solver']['batch_size'])
    )

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=actual_batch_size,
        shuffle=True,
        num_workers=args.cpu_workers,
        collate_fn=collate_fn_with_padding
    )
    if not args.no_validate:
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=actual_batch_size * 4,
            shuffle=False,
            num_workers=args.cpu_workers,
            collate_fn=collate_fn_with_padding
        )

    # Model
    model = NLPCModel(config['model']).to(device)
    if -1 not in args.gpu_ids:
        model = nn.DataParallel(model, args.gpu_ids)

    if args.load_pthpath != "":
        model_state_dict, _ = load_checkpoint(args.load_pthpath)
        if isinstance(model, nn.DataParallel):
            model.module.load_state_dict(model_state_dict)
        else:
            model.load_state_dict(model_state_dict)
        logger.info(
            'Loaded model checkpoint from {}.'.format(args.load_pthpath)
        )

    # loss
    criterion = NLPCLoss(config['model'], task_name2int, 'train', device)
    if not args.no_validate:
        val_criterion = NLPCLoss(config['model'], task_name2int, 'val', device)

    # Weight decay
    no_decay = config['solver'].get('no_decay', [])

    transformer_params = [
        item for item in list(model.named_parameters())
        if 'transformer' in item[0]
    ]
    not_transformer_params = [
        item for item in list(model.named_parameters())
        if 'transformer' not in item[0]
    ]

    grouped_parameters = [
        # non-transformer and need decay
        {
            'params': [
                p for n, p in not_transformer_params
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay': config['solver']['weight_decay'],
            "lr": config['solver']['initial_lr']
        },
        # transformer and need decay
        {
            'params': [
                p for n, p in transformer_params
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay': config['solver']['transformer_weight_decay'],
            'lr': (
                config['solver']['transformer_initial_lr']
                if 'transformer_initial_lr' in config['solver']
                else config['solver']['initial_lr']
            )
        },
        # non-transformer and need not decay
        {
            'params': [
                p for n, p in not_transformer_params
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.0,
            'lr': config['solver']['initial_lr']
        },
        # transformer and need not decay
        {
            'params': [
                p for n, p in transformer_params
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.0,
            'lr': (
                config['solver']['transformer_initial_lr']
                if 'transformer_initial_lr' in config['solver']
                else config['solver']['initial_lr']
            )
        }
    ]
    if 'task_weights' in config['model'] \
       and config['model']['task_weights'] == 'uct':
        grouped_parameters.append(
            {
                'params': criterion.parameters(),
                'weight_decay': 0.0,
                'lr': (
                    config['solver']['uct_initial_lr']
                    if 'uct_initial_lr' in config['solver']
                    else config['solver']['initial_lr']
                )
            }
        )

    # Optimizer
    if config['solver']['optimizer'] == 'AdamW':
        optimizer = AdamW(
            grouped_parameters,
            lr=config["solver"]["initial_lr"],
            weight_decay=config['solver']['weight_decay']
        )
    else:
        raise ValueError(
            'Optimizer {} is not supported.'
            .format(config['solver']['optimizer'])
        )

    # Learning rate schedule
    total_steps = (
        math.ceil(
            len(train_dataloader) / config['solver']['accumulation_steps']
        ) * config['solver']['num_epochs']
        if 'num_epochs' in config['solver']
        else config['solver']['total_steps']
    )
    warmup_steps = (
        math.ceil(total_steps * config['solver']['warmup_fraction'])
        if 'warmup_fraction' in config['solver']
        else config['solver']['warmup_steps']
    )
    validation_steps = (
        config['solver']['validation_steps']
        if 'validation_steps' in config['solver']
        else math.ceil(
            len(train_dataloader) / config['solver']['accumulation_steps']
        )
    ) if not args.no_validate else total_steps

    logger.info('Total steps: {}'.format(total_steps))
    logger.info('Warmup_steps: {}'.format(warmup_steps))
    if not args.no_validate:
        logger.info('Validation steps: {}'.format(validation_steps))

    if config['solver']['lr_schedule'] == 'warmup_linear':
        scheduler = get_linear_schedule_with_warmup(
            optimizer, warmup_steps, total_steps
        )
    elif config['solver']['lr_schedule'] == 'warmup_cosine':
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, warmup_steps, total_steps
        )
    elif config['solver']['lr_schedule'] == 'warmup_cosine_with_hard_restarts':
        num_cycles = config['solver']['num_cycles']
        scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer, warmup_steps, total_steps, num_cycles=num_cycles
        )
    else:
        raise ValueError(
            'Learning rate schedule {} is not supported.'
            .format(config['solver']['lr_schedule'])
        )

    # Setup before training
    summary_writer = SummaryWriter(logdir=args.save_dirpath)
    checkpoint_manager = CheckpointManager(
        model, optimizer, args.save_dirpath, overwrite=True, config=config
    )
    accumulation_steps = config['solver']['accumulation_steps']
    forward_steps = 0
    optimizer_steps = 0
    loss = []
    if not args.no_validate:
        best_score = float('-inf')

    # Evaluate before training if loaded pretrained model
    if not args.no_validate and args.load_pthpath != "":
        model.eval()
        val_losses, val_report = evaluate(
            model, val_dataloader, val_criterion, device, task_name2int
        )
        val_score = val_report['competition_score']
        logger.info('Step {} evaluate result:'.format(optimizer_steps))
        for k, v in val_losses.items():
            logger.info('    {} = {:.6f}'.format(k, v))
            if k == 'val_loss':
                summary_writer.add_scalar(
                    "val/loss", v, global_step=optimizer_steps
                )
            else:
                summary_writer.add_scalar(
                    "val/" + k, v, global_step=optimizer_steps
                )
        for k, v in val_report.items():
            logger.info('    {} = {:.6f}'.format(k, v))
            summary_writer.add_scalar(
                "val/" + k, v, global_step=optimizer_steps
            )

    # Training loop
    model.train()
    train_iterator = iter(train_dataloader)
    for _ in range(int(math.ceil(total_steps / validation_steps))):
        for _ in tqdm(range(validation_steps * accumulation_steps)):
            try:
                batch = next(train_iterator)
            except StopIteration:
                train_iterator = iter(train_dataloader)
                if args.overfit:
                    break
                else:
                    batch = next(train_iterator)
            for key in batch:
                batch[key] = batch[key].to(device)

            batch_output = model(batch)
            batch_loss_output = criterion(
                batch_output, batch['target'], batch['task_type_id']
            )

            if isinstance(batch_loss_output, torch.Tensor):
                batch_loss = batch_loss_output / accumulation_steps
                batch_loss.backward()
                loss.append(batch_loss.detach().cpu().numpy())
            elif (
                isinstance(batch_loss_output, dict)
                and 'task_weights' in config['model']
                and config['model']['task_weights'] in ['uct', 'dtp']
            ):
                batch_loss = batch_loss_output['loss'] / accumulation_steps
                batch_loss.backward()
                loss.append(batch_loss.detach().cpu().numpy())
            else:
                raise ValueError('Unexpected loss output type: {}'.format(
                    type(batch_loss_output)))

            forward_steps += 1

            if forward_steps % accumulation_steps == 0:
                optimizer_steps += 1

                loss = np.sum(loss)
                summary_writer.add_scalar(
                    "train/loss", loss, global_step=optimizer_steps
                )
                loss = []

                if isinstance(batch_loss_output, dict) \
                   and 'task_weights' in config['model'] \
                   and config['model']['task_weights'] == 'uct':
                    for task_name in task_name2int:
                        summary_writer.add_scalar(
                            "train/weight_" + task_name,
                            batch_loss_output["weight_" + task_name],
                            global_step=optimizer_steps
                        )

                if isinstance(batch_loss_output, dict) \
                   and 'task_weights' in config['model'] \
                   and config['model']['task_weights'] == 'dtp':
                    for task_name in task_name2int:
                        summary_writer.add_scalar(
                            "train/running_kpi_" + task_name,
                            batch_loss_output["running_kpi_" + task_name],
                            global_step=optimizer_steps
                        )

                summary_writer.add_scalar(
                    "train/lr", optimizer.param_groups[0]["lr"],
                    global_step=optimizer_steps
                )
                summary_writer.add_scalar(
                    "train/transformer_lr", optimizer.param_groups[1]["lr"],
                    global_step=optimizer_steps
                )
                if 'task_weights' in config['model'] \
                   and config['model']['task_weights'] == 'uct':
                    summary_writer.add_scalar(
                        "train/uct_lr", optimizer.param_groups[-1]["lr"],
                        global_step=optimizer_steps
                    )

                if config['solver']['max_grad_norm'] > 0:
                    clip_grad_norm_(
                        model.parameters(),
                        config['solver']['max_grad_norm']
                    )
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
                torch.cuda.empty_cache()

                if optimizer_steps >= total_steps:
                    break

        # Evaluate on validation set.
        if not args.no_validate:
            model.eval()
            val_losses, val_report = evaluate(
                model, val_dataloader, val_criterion, device, task_name2int
            )
            val_score = val_report['competition_score']
            logger.info('Step {} evaluate result:'.format(optimizer_steps))
            for k, v in val_losses.items():
                logger.info('    {} = {:.6f}'.format(k, v))
                if k == 'val_loss':
                    summary_writer.add_scalar(
                        "val/loss", v, global_step=optimizer_steps
                    )
                else:
                    summary_writer.add_scalar(
                        "val/" + k, v, global_step=optimizer_steps
                    )
            for k, v in val_report.items():
                logger.info('    {} = {:.6f}'.format(k, v))
                summary_writer.add_scalar(
                    "val/" + k, v, global_step=optimizer_steps
                )

            if val_score > best_score:
                checkpoint_manager.step()
                logger.info(
                    '    Validation best score update from {:.6f} to {:.6f}. '
                    'Saved checkpoint to {}'.format(
                        best_score, val_score, args.save_dirpath + 'checkpoint.pth'
                    )
                )
                best_score = val_score
            else:
                logger.info(
                    '    Validation best score not updated since {:.6f}. '
                    'No checkpoint saved.'.format(best_score)
                )
            model.train()
            torch.cuda.empty_cache()
            summary_writer.flush()

    # Save the final model if no validate
    if args.no_validate:
        checkpoint_manager.step()
        logger.info(
            'Saved final checkpoint to {}'.format(
                args.save_dirpath + 'checkpoint.pth'
            )
        )

    summary_writer.close()
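
The parameter grouping above follows the common pattern of exempting a config-driven `no_decay` list (typically bias and LayerNorm parameters) from weight decay and giving the transformer its own learning rate. A minimal self-contained sketch of the same idea, with a toy model and illustrative hyperparameters:

import torch.nn as nn
from torch.optim import AdamW

toy_model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 4))
no_decay = ['bias']  # HuggingFace-style configs often also list 'LayerNorm.weight'
grouped = [
    {'params': [p for n, p in toy_model.named_parameters()
                if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in toy_model.named_parameters()
                if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},
]
optimizer = AdamW(grouped, lr=1e-4)  # per-group settings override the optimizer defaults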
Code Example #6
    attention_mask = batch[1].to(device)
    labels = batch[2].to(device)
    outputs, bert, fc3 = model(input_ids,
                               attention_mask=attention_mask,
                               labels=labels,
                               task='oc')
    oc_loss = loss_fct(outputs.view(-1, 3), labels.view(-1))
    oc_f1 = f1_score(labels.cpu().numpy(),
                     outputs.argmax(dim=1).cpu().numpy(),
                     average='macro')

    loss = oce_loss + news_loss + oc_loss

    train_loss += loss.item()
    loss.backward()
    optim.step()

    f1 = (oce_f1 + news_f1 + oc_f1) / 3

    train_f1 += f1

    pbar.update()
    pbar.set_description(
        f'oce_loss:{round(oce_loss.item(), 4)}, oce_f1:{round(oce_f1, 4)},'
        f'news_loss:{round(news_loss.item(), 4)}, news_f1:{round(news_f1, 4)},'
        f'oc_loss:{round(oc_loss.item(), 4)}, oc_f1:{round(oc_f1, 4)},'
        f'loss:{round(loss.item(), 4)}, f1:{round(f1, 4)}')

    if i != 0 and i % 50 == 0:
        oce_loss_v, oce_f1_v = valid_func(oce_valid_loader, 'oce')
        news_loss_v, news_f1_v = valid_func(news_valid_loader, 'news')
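
The fragment above sums the losses of three task heads and backpropagates once through the shared encoder. A minimal self-contained sketch of that pattern (the head names and label counts here are illustrative, not taken from the original project):

import torch
from torch import nn

# Toy shared representation and three task heads.
shared = torch.randn(8, 16, requires_grad=True)
heads = {name: nn.Linear(16, n) for name, n in [('oce', 7), ('news', 15), ('oc', 3)]}
targets = {'oce': torch.randint(0, 7, (8,)),
           'news': torch.randint(0, 15, (8,)),
           'oc': torch.randint(0, 3, (8,))}
loss_fct = nn.CrossEntropyLoss()

# Sum the per-task losses, then call backward once so every head and the shared
# representation receive gradients from a single combined objective.
loss = sum(loss_fct(head(shared), targets[name]) for name, head in heads.items())
loss.backward()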
Code Example #7
class TD3Agent(AgentBase):
    """
    Twin Delayed Deep Deterministic (TD3) Policy Gradient.

    In short, it's a slightly modified/improved version of the DDPG. Compared to the DDPG in this package,
    which uses Guassian noise, this TD3 uses Ornstein–Uhlenbeck process as the noise.
    """

    name = "TD3"

    def __init__(self,
                 state_size: int,
                 action_size: int,
                 actor_lr: float = 1e-3,
                 critic_lr: float = 1e-3,
                 noise_scale: float = 0.2,
                 noise_sigma: float = 0.1,
                 device=None,
                 **kwargs):
        super().__init__(**kwargs)
        self.device = device if device is not None else DEVICE

        # Reason sequence initiation.
        self.state_size = state_size
        self.action_size = action_size

        hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'hidden_layers', (128, 128)))
        self.actor = ActorBody(state_size,
                               action_size,
                               hidden_layers=hidden_layers).to(self.device)
        self.critic = DoubleCritic(state_size,
                                   action_size,
                                   CriticBody,
                                   hidden_layers=hidden_layers).to(self.device)
        self.target_actor = ActorBody(state_size,
                                      action_size,
                                      hidden_layers=hidden_layers).to(
                                          self.device)
        self.target_critic = DoubleCritic(state_size,
                                          action_size,
                                          CriticBody,
                                          hidden_layers=hidden_layers).to(
                                              self.device)

        # Noise sequence initiation
        # self.noise = GaussianNoise(shape=(action_size,), mu=1e-8, sigma=noise_sigma, scale=noise_scale, device=device)
        self.noise = OUProcess(shape=action_size,
                               scale=noise_scale,
                               sigma=noise_sigma,
                               device=self.device)

        # Target sequence initiation
        hard_update(self.target_actor, self.actor)
        hard_update(self.target_critic, self.critic)

        # Optimization sequence initiation.
        self.actor_optimizer = AdamW(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = AdamW(self.critic.parameters(), lr=critic_lr)
        self.max_grad_norm_actor: float = float(
            kwargs.get("max_grad_norm_actor", 10.0))
        self.max_grad_norm_critic: float = float(
            kwargs.get("max_grad_norm_critic", 10.0))
        self.action_min = float(self._register_param(kwargs, 'action_min',
                                                     -1.))
        self.action_max = float(self._register_param(kwargs, 'action_max', 1.))
        self.action_scale = float(
            self._register_param(kwargs, 'action_scale', 1.))

        self.gamma = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau = float(self._register_param(kwargs, 'tau', 0.02))
        self.batch_size = int(self._register_param(kwargs, 'batch_size', 64))
        self.buffer_size = int(
            self._register_param(kwargs, 'buffer_size', int(1e5)))
        self.buffer = ReplayBuffer(self.batch_size, self.buffer_size)

        self.warm_up = int(self._register_param(kwargs, 'warm_up', 0))
        self.update_freq = int(self._register_param(kwargs, 'update_freq', 1))
        self.update_policy_freq = int(
            self._register_param(kwargs, 'update_policy_freq', 1))
        self.number_updates = int(
            self._register_param(kwargs, 'number_updates', 1))
        self.noise_reset_freq = int(
            self._register_param(kwargs, 'noise_reset_freq', 10000))

        # Breath, my child.
        self.reset_agent()
        self.iteration = 0
        self._loss_actor = 0.
        self._loss_critic = 0.

    @property
    def loss(self) -> Dict[str, float]:
        return {'actor': self._loss_actor, 'critic': self._loss_critic}

    @loss.setter
    def loss(self, value):
        if isinstance(value, dict):
            self._loss_actor = value['actor']
            self._loss_critic = value['critic']
        else:
            self._loss_actor = value
            self._loss_critic = value

    def reset_agent(self) -> None:
        self.actor.reset_parameters()
        self.critic.reset_parameters()
        self.target_actor.reset_parameters()
        self.target_critic.reset_parameters()

    def act(self,
            state,
            epsilon: float = 0.0,
            training_mode=True) -> List[float]:
        """
        Agent acting on observations.

        When the training_mode is True (default) a noise is added to each action.
        """
        # Epsilon greedy
        if self._rng.random() < epsilon:
            rnd_actions = torch.rand(self.action_size) * (
                self.action_max - self.action_min) + self.action_min
            return rnd_actions.tolist()

        with torch.no_grad():
            state = to_tensor(state).float().to(self.device)
            action = self.actor(state)
            if training_mode:
                action += self.noise.sample()
            return (self.action_scale * torch.clamp(action, self.action_min,
                                                    self.action_max)).tolist()

    def target_act(self, staten, noise: float = 0.0):
        with torch.no_grad():
            staten = to_tensor(staten).float().to(self.device)
            action = self.target_actor(staten) + noise * self.noise.sample()
            return torch.clamp(action, self.action_min,
                               self.action_max).cpu().numpy().astype(
                                   np.float32)

    def step(self, state, action, reward, next_state, done):
        self.iteration += 1
        self.buffer.add(state=state,
                        action=action,
                        reward=reward,
                        next_state=next_state,
                        done=done)

        if (self.iteration % self.noise_reset_freq) == 0:
            self.noise.reset_states()

        if self.iteration < self.warm_up:
            return

        if len(self.buffer) <= self.batch_size:
            return

        if not (self.iteration % self.update_freq) or not (
                self.iteration % self.update_policy_freq):
            for _ in range(self.number_updates):
                # Note: Inside this there's a delayed policy update.
                #       Every `update_policy_freq` it will learn `number_updates` times.
                self.learn(self.buffer.sample())

    def learn(self, experiences):
        """Update critics and actors"""
        rewards = to_tensor(experiences['reward']).float().to(
            self.device).unsqueeze(1)
        dones = to_tensor(experiences['done']).type(torch.int).to(
            self.device).unsqueeze(1)
        states = to_tensor(experiences['state']).float().to(self.device)
        actions = to_tensor(experiences['action']).to(self.device)
        next_states = to_tensor(experiences['next_state']).float().to(
            self.device)

        if (self.iteration % self.update_freq) == 0:
            self._update_value_function(states, actions, rewards, next_states,
                                        dones)

        if (self.iteration % self.update_policy_freq) == 0:
            self._update_policy(states)

            soft_update(self.target_actor, self.actor, self.tau)
            soft_update(self.target_critic, self.critic, self.tau)

    def _update_value_function(self, states, actions, rewards, next_states,
                               dones):
        # critic loss
        next_actions = self.target_actor.act(next_states)
        Q_target_next = torch.min(
            *self.target_critic.act(next_states, next_actions))
        Q_target = rewards + (self.gamma * Q_target_next * (1 - dones))
        Q1_expected, Q2_expected = self.critic(states, actions)
        loss_critic = mse_loss(Q1_expected, Q_target) + mse_loss(
            Q2_expected, Q_target)

        # Minimize the loss
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(),
                                 self.max_grad_norm_critic)
        self.critic_optimizer.step()
        self._loss_critic = float(loss_critic.item())

    def _update_policy(self, states):
        # Compute actor loss
        pred_actions = self.actor(states)
        loss_actor = -self.critic(states, pred_actions)[0].mean()
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(),
                                 self.max_grad_norm_actor)
        self.actor_optimizer.step()
        self._loss_actor = loss_actor.item()

    def state_dict(self) -> Dict[str, dict]:
        """Describes agent's networks.

        Returns:
            state: (dict) Provides actors and critics states.

        """
        return {
            "actor": self.actor.state_dict(),
            "target_actor": self.target_actor.state_dict(),
            "critic": self.critic.state_dict(),
            "target_critic": self.target_critic.state_dict()
        }

    def log_metrics(self,
                    data_logger: DataLogger,
                    step: int,
                    full_log: bool = False):
        data_logger.log_value("loss/actor", self._loss_actor, step)
        data_logger.log_value("loss/critic", self._loss_critic, step)

    def get_state(self):
        return dict(
            actor=self.actor.state_dict(),
            target_actor=self.target_actor.state_dict(),
            critic=self.critic.state_dict(),
            target_critic=self.target_critic.state_dict(),
            config=self._config,
        )

    def save_state(self, path: str):
        agent_state = self.get_state()
        torch.save(agent_state, path)

    def load_state(self, path: str):
        agent_state = torch.load(path)
        self._config = agent_state.get('config', {})
        self.__dict__.update(**self._config)

        self.actor.load_state_dict(agent_state['actor'])
        self.critic.load_state_dict(agent_state['critic'])
        self.target_actor.load_state_dict(agent_state['target_actor'])
        self.target_critic.load_state_dict(agent_state['target_critic'])
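
A hypothetical interaction loop for the agent above; a Gym-style `env` with matching state and action sizes is assumed and is not part of the original snippet:

# Hypothetical usage; `env` is a Gym-style environment supplied by the caller.
agent = TD3Agent(state_size=8, action_size=2, warm_up=1000)
state = env.reset()
for step in range(100_000):
    action = agent.act(state, epsilon=0.05)
    next_state, reward, done, _ = env.step(action)
    agent.step(state, action, reward, next_state, done)
    state = env.reset() if done else next_state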
Code Example #8
def trainner(model, play_history, train_config: dict):
    model.train()
    train_history, valid_history, split_point = play_history.get_train_valid_data(rate=train_config['traindata_rate'])
    train_dataset = AlphaDataset(play_histry=train_history)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=train_config['batch_size'],
        shuffle=True,
        num_workers=train_config['num_workers'],
        collate_fn=AlphaDataset.collate_fn,
        pin_memory=True,
    )

    if valid_history is not None:
        valid_dataset = AlphaDataset(play_histry=valid_history)
        valid_loader = DataLoader(
            dataset=valid_dataset,
            batch_size=train_config['batch_size'] * 2,
            shuffle=False,
            num_workers=train_config['num_workers'],
            collate_fn=AlphaDataset.collate_fn,
            pin_memory=True,
        )
    else:
        valid_loader = None

    optimizer = AdamW(params=model.parameters(), lr=train_config['base_lr'])
    scheduler = lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=train_config['epochs'], eta_min=train_config['min_lr']
    )

    for epoch in range(train_config['epochs']):
        train_value_mean = Avg()
        train_policy_mean = Avg()
        for states, actions, winners in train_loader:
            optimizer.zero_grad()

            value, policy = model(states)
            value_loss = functional.mse_loss(input=value.view(-1), target=winners)
            policy_loss = functional.cross_entropy(input=policy, target=actions)

            loss = train_config['value_loss_weight'] * value_loss + train_config['policy_loss_weight'] * policy_loss

            loss.backward()
            optimizer.step()

            train_value_mean.update(value=value_loss.item())
            train_policy_mean.update(value=policy_loss.item())

        scheduler.step()

        if valid_loader is not None:
            valid_value_mean = Avg()
            valid_policy_mean = Avg()
            for states, actions, winners in valid_loader:
                with torch.no_grad():
                    value, policy = model(states)
                    value_loss = functional.mse_loss(input=value.view(-1), target=winners)
                    policy_loss = functional.cross_entropy(input=policy, target=actions)

                value_loss = value_loss.item()
                policy_loss = policy_loss.item()

                valid_value_mean.update(value=value_loss)
                valid_policy_mean.update(value=policy_loss)

        msg = f'epochs: [{epoch}/{train_config["epochs"]}]'
        msg += f' - train value loss: {train_value_mean():.6f} - train policy loss: {train_policy_mean():.6f}'
        if valid_loader is not None:
            msg += f' - valid value loss: {valid_value_mean():.6f} - valid policy loss: {valid_policy_mean():.6f}'
        logging.info(msg=msg)
    model.eval()
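
The training loop above relies on an Avg running-mean helper that is not defined in the snippet; a minimal sketch under that assumption, matching the update(value=...) and call-style usage:

class Avg:
    """Running mean of scalar values."""
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value: float) -> None:
        self.total += float(value)
        self.count += 1

    def __call__(self) -> float:
        return self.total / max(self.count, 1)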
コード例 #9
0
def main():
    # my dice shows 777 only. period.
    random.seed(EXPCONF.seed)
    np.random.seed(EXPCONF.seed)
    torch.manual_seed(EXPCONF.seed)
    torch.cuda.manual_seed_all(EXPCONF.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    trainloader, vocab, _trainds = get_loader(EXPCONF, getdev=False)
    devloader, _, _devds = get_loader(EXPCONF, getdev=True)

    assert len(trainloader) > 0, "trainloader is empty!"
    assert len(devloader) > 0, "devloader is empty!"

    # this is disgraceful.... but just specify things below
    albertconf = AlbertConfig.from_pretrained(
        f'albert-{EXPCONF.albert_scale}-v2')
    if EXPCONF.smaller:  #originally used 4H for FFN but for memory issue, use 1H for FFN
        albertconf.hidden_size = EXPCONF.hidden_size
        albertconf.num_hidden_layers = EXPCONF.num_hidden_layers
        albertconf.num_attention_heads = EXPCONF.num_attention_heads

        albertconf.intermediate_size = albertconf.hidden_size

    albertconf.vocab_size = len(vocab.itos)
    albertconf.bos_token_id = vocab.stoi['BOS']
    albertconf.eos_token_id = vocab.stoi['EOS']
    albertconf.pad_token_id = vocab.stoi['PAD']
    albertconf.max_position_embeddings = 40

    model = AlbertForPreTraining(albertconf).to(device)

    # huggingface example is doing this for language modeling...
    # https://github.com/huggingface/transformers/blob/v2.6.0/examples/run_language_modeling.py
    no_decay = ['bias', "LayerNorm.weight"]
    grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            EXPCONF.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]

    optimizer = AdamW(grouped_parameters,
                      lr=EXPCONF.lr)  # otherwise, use default
    getsch = get_cosine_schedule_with_warmup if EXPCONF.scheduler == 'cosine' else get_linear_schedule_with_warmup
    scheduler = getsch(optimizer, EXPCONF.warmups,
                       EXPCONF.numep * len(trainloader))

    global_step = 0
    L = len(trainloader)
    bsz = trainloader.batch_size  # batch size (assuming a torch DataLoader), used only for logging

    for ep in tqdm(range(1, EXPCONF.numep + 1), desc="epoch progress"):
        lossep_mlm = 0
        lossep_pp = 0
        accep_pp = 0
        model.train()
        for i, (b, l, datasetids) in enumerate(
                tqdm(trainloader, desc="iterations progress"), 1):
            '''
            b.input_ids/token_type_ids/attention_mask .shape ==  (bsz, seqmaxlen,)
            l.shape == (bsz,)

            ## bert families, when they do MLM with NSP (or other similar sentence based tasks,)
            ## they just uses masked input for their sentence representation encoding, not the unmasked ones
            ## it could be considered as some kind of dropout but at first it looked quite irregular to me.

            ## --> referred to transformers/examples/run_language_modeling.py (v2.1.0)
            ## --> modeling_albert.py ( class AlbertModel.forward() )
            '''

            outputs = model(**b, sentence_order_label=l, return_dict=True)
            global_step += 1

            vsz = outputs.prediction_logits.shape[-1]

            lossmlm = F.cross_entropy(
                outputs.prediction_logits.view(-1, vsz).contiguous(),
                b['labels'].view(-1))
            losspp = F.cross_entropy(outputs.sop_logits, l)
            lossppval = losspp.item()
            acc = accuracy(outputs.sop_logits.clone().detach(), l)

            if EXPCONF.alpha_pp == 1 and not EXPCONF.alpha_warmup:
                outputs.loss.backward()
            else:
                del outputs.loss
                torch.cuda.empty_cache()

                losspp *= EXPCONF.alpha_pp

                if EXPCONF.alpha_warmup:
                    grow = min(global_step / EXPCONF.warmups, 1.0)
                    losspp *= grow

                loss = lossmlm + losspp
                loss.backward()

            wandb.log({
                'step':
                (i + ep * L) * bsz if EXPCONF.see_bsz_effect else global_step,
                'train_step/learning_rate':
                get_lr_from_optim(optimizer),
                'train_step/alpha_pp':
                EXPCONF.alpha_pp * (grow if EXPCONF.alpha_warmup else 1),
                'train_step/mlm_loss':
                lossmlm.item(),
                'train_step/pp_loss':
                lossppval,
                'train_step/pp_acc':
                acc,
            })

            optimizer.step()
            scheduler.step()
            model.zero_grad()

            lossep_mlm += lossmlm.item()
            lossep_pp += lossppval
            accep_pp += acc

        lossep_mlm /= L
        lossep_pp /= L
        accep_pp /= L

        wandb.log({
            'step': ep,
            'train_ep/mlm_loss': lossep_mlm,
            'train_ep/pp_loss': lossep_pp,
            'train_ep/pp_acc': accep_pp,
        })
        print(f"ep:{ep}: losspp = {lossep_pp}, lossmlm={lossep_mlm}")
        devmlm_loss, devpp_loss, devpp_acc = evaldev(EXPCONF, model, devloader,
                                                     ep)
        if devpp_acc > EXPCONF.savethld:
            savemodel(EXPCONF,
                      model,
                      vocab,
                      ep,
                      mlm=devmlm_loss,
                      pp=devpp_loss,
                      acc=devpp_acc)
    return None
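
The snippet above uses accuracy and get_lr_from_optim helpers that are not defined here; plausible minimal sketches (assumptions, not the original implementations):

def get_lr_from_optim(optimizer):
    # Report the learning rate of the first parameter group for logging.
    return optimizer.param_groups[0]['lr']

def accuracy(logits, labels):
    # Fraction of correct argmax predictions for the sentence-order head.
    preds = logits.argmax(dim=-1)
    return (preds == labels).float().mean().item()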
コード例 #10
0
def main():
    parser = ArgumentParser()
    parser.add_argument("--epoch", type=int, required=True)
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument("--emb_file", type=str, required=True)
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--save_dir", type=str, required=True)
    parser.add_argument("--train_file", type=str, required=True)
    parser.add_argument("--log_file", type=str, required=False)
    parser.add_argument("--ratio", type=str, required=True)
    parser.add_argument("--vocab_size", type=int, required=True)
    parser.add_argument("--emb_size", type=int, required=True)
    parser.add_argument("--learning_rate", type=float, required=True)
    parser.add_argument("--batch_size", type=int, required=True)
    parser.add_argument("--max_length", type=int, required=True)
    parser.add_argument("--max_grad_norm", type=int, required=True)

    args = parser.parse_args()

    split_ratio = [float(val) for val in args.ratio.split(",")]

    has_cuda = torch.cuda.is_available()

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
    DATE_FORMAT = "%m/%d/%Y %H:%M:%S %p"
    logging.basicConfig(filename=args.log_file,
                        level=logging.INFO,
                        format=LOG_FORMAT,
                        datefmt=DATE_FORMAT)

    logging.info("start preparing data")
    data_preprocessor = DataPreprocess()
    emb, word_idx_map = data_preprocessor.build_emb_vocab(args.emb_file)
    data_preprocessor.load(args.train_file, use_mask=False, is_test=False)
    train_dataset, dev_dataset = data_preprocessor.generate_train_dev_dataset(
        ratio=split_ratio)
    train_dataset = CompDataSet(train_dataset,
                                word_idx_map,
                                max_len=args.max_length,
                                emb_size=args.emb_size)
    dev_dataset = CompDataSet(dev_dataset,
                              word_idx_map,
                              max_len=args.max_length,
                              emb_size=args.emb_size)

    train_dataset = DataLoader(train_dataset,
                               batch_size=args.batch_size,
                               shuffle=True)
    dev_dataset = DataLoader(dev_dataset,
                             batch_size=args.batch_size,
                             shuffle=True)

    logging.info("init model")
    start_epoch = 0
    if args.checkpoint:
        model = torch.load(args.checkpoint)
        start_epoch = re.findall(r"\d+(?=_\d+\.pt)", args.checkpoint)
        start_epoch = int(start_epoch[0]) + 1
    else:
        model = ESIM(args.vocab_size,
                     args.emb_size,
                     emb,
                     max_len=args.max_length)

    optimizer = AdamW(model.parameters(), lr=args.learning_rate)
    criterion = FocalLoss()

    if has_cuda:
        model = model.cuda()

    logging.info("start training")
    neg_auc, pos_auc = validate(model, dev_dataset)
    logging.info(f"pre-train neg_auc {str(neg_auc)} pos_auc {str(pos_auc)}")

    for epoch in range(start_epoch, args.epoch):
        running_loss = 0.0
        for step, data in enumerate(train_dataset):
            model.train()
            start_time = time.time()
            optimizer.zero_grad()

            outputs = model(data["premise"], data["premise_mask"],
                            data["hypothese"], data["hypothese_mask"])
            loss = criterion(outputs["probs"], data["label"])
            loss.backward()

            clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()

            end_time = time.time()
            running_loss += loss.item()
            if step % 100 == 99:
                logging.info(
                    f"epoch: {epoch}, step: {step}, time: {end_time - start_time} loss: {running_loss / 100}"
                )
                running_loss = 0
            if step % 500 == 499:
                neg_auc, pos_auc = validate(model, dev_dataset)
                logging.info(
                    f"epoch {epoch} step {step} neg_auc {str(neg_auc)} pos_auc {str(pos_auc)}")
                torch.save(model, Path(args.save_dir) / f"{epoch}_{step}.pt")
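
The FocalLoss criterion above is not shown. A minimal multi-class sketch, assuming outputs["probs"] are per-class probabilities, data["label"] holds integer class ids, and torch.nn is imported as nn (all assumptions):

class FocalLoss(nn.Module):
    def __init__(self, gamma: float = 2.0, eps: float = 1e-8):
        super().__init__()
        self.gamma = gamma
        self.eps = eps

    def forward(self, probs, targets):
        # Probability assigned to the true class, clamped for numerical stability.
        pt = probs.gather(1, targets.unsqueeze(1)).squeeze(1).clamp(self.eps, 1.0)
        # Down-weight easy examples by (1 - pt)^gamma.
        loss = -((1.0 - pt) ** self.gamma) * pt.log()
        return loss.mean()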
コード例 #11
0
class Trainer(object):
    '''Train network'''
    @classmethod
    def create(cls, args, train='train.pt', validation='validation.pt'):
        enl, dnl = AutoEncoder.get_non_linearity(args.nonlinearity)
        return Trainer(AutoEncoder(encoder_sizes=args.encoder,
                                   encoding_dimension=args.dimension,
                                   encoder_non_linearity=enl,
                                   decoder_non_linearity=dnl,
                                   decoder_sizes=args.decoder),
                       DataLoader(load(join(args.data, train)),
                                  batch_size=args.batch,
                                  shuffle=True,
                                  num_workers=cpu_count()),
                       DataLoader(load(join(args.data, validation)),
                                  batch_size=32,
                                  shuffle=False,
                                  num_workers=cpu_count()),
                       lr=args.lr,
                       weight_decay=args.weight_decay,
                       path=args.data)

    def __init__(self,
                 model,
                 loader,
                 validation_loader,
                 criterion=MSELoss(),
                 lr=0.001,
                 weight_decay=0.01,
                 path='./'):
        super().__init__()
        self.model = model
        self.loader = loader
        self.validation_loader = validation_loader
        self.Losses = [float('inf')]
        self.ValidationLosses = [float('inf')]
        self.criterion = criterion
        self.optimizer = AdamW(model.parameters(),
                               lr=lr,
                               weight_decay=weight_decay)
        self.path = path
        self.lr = lr
        self.weight_decay = weight_decay

    def train(self, N_EPOCHS=25, N_BURN=5, args_dict={}):
        '''
            Adjust weights until overtraining starts.

            The weights are saved each iteration, so the best set of weights will be preserved.
        '''
        for epoch in range(N_EPOCHS):
            self.train_step()
            self.validation_step()
            print(
                f'epoch : {epoch + 1}/{N_EPOCHS}, losses = {self.Losses[-1]:.6f}, {self.ValidationLosses[-1]:.6f}'
            )
            if epoch > N_BURN and self.ValidationLosses[
                    -1] > self.ValidationLosses[-2]:
                return self.ValidationLosses[-2]
            else:
                self.save_model(args_dict)

        return self.ValidationLosses[-1]

    def train_step(self):
        '''
            Compute gradients, adjust weights, and compute training loss
        '''
        loss = 0
        for batch_features, _ in self.loader:
            batch_features = batch_features.view(-1,
                                                 self.model.get_input_length())
            self.optimizer.zero_grad()
            outputs = self.model(
                batch_features.float())  # FIXME - quick hack for #36
            train_loss = self.criterion(
                outputs.float(),
                batch_features.float())  # FIXME - quick hack for #36
            train_loss.backward()
            self.optimizer.step()
            loss += train_loss.item()

        self.Losses.append(loss / len(self.loader))

    def validation_step(self):
        '''
            Compute validation loss
        '''
        loss = 0.0
        with no_grad():
            for i, (batch_features, _) in enumerate(self.validation_loader):
                batch_features = batch_features.view(
                    -1, self.model.get_input_length())
                outputs = self.model(
                    batch_features.float())  #FIXME - quick hack for #36
                validation_loss = self.criterion(
                    outputs,
                    batch_features.float())  #FIXME - quick hack for #36
                loss += validation_loss.item()

        self.ValidationLosses.append(loss / len(self.validation_loader))

    def save_model(self, args_dict):
        '''
            Save current state of model
        '''
        save(
            {
                'model_state_dict': self.model.state_dict(),
                'args_dict': args_dict
            }, join(self.path, self.get_file_name(
                args_dict['dimension'])))  #FIXME - quick hack for #36

    def get_file_name(self, dimension, name='saved', ext='pt'):
        '''
            Used to assign names to files, including hyperparameter values
        '''

        return f'{get_file_name(name,dimension,self.lr,weight_decay=self.weight_decay)}.{ext}'  #FIXME - quick hack for #36
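
The module-level get_file_name called in the method above is not shown; one plausible sketch that simply embeds the hyperparameters in the name (an assumption about its behavior):

def get_file_name(name, dimension, lr, weight_decay=0.0):
    # e.g. 'saved-dim8-lr0.001-wd0.01'
    return f'{name}-dim{dimension}-lr{lr}-wd{weight_decay}'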
コード例 #12
0
def main():
    # my dice shows 777 only. period.
    random.seed(EXPCONF.seed)
    np.random.seed(EXPCONF.seed)
    torch.manual_seed(EXPCONF.seed)
    torch.cuda.manual_seed_all(EXPCONF.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tempconf = EXPCONF.copy()
    tempconf.datamode = 'test'

    testloader, ___, _____ = get_loader(tempconf)
    trainloader, __, _trainds = get_loader(EXPCONF, getdev=False)
    devloader, _, _devds = get_loader(EXPCONF, getdev=True)

    assert len(trainloader) > 0, "trainloader is empty!"
    assert len(devloader) > 0, "devloader is empty!"

    # this is disgraceful.... but just specify things below
    model_weight, vocab, trained_condition = loadmodel_info(EXPCONF)

    albertconf = retrieve_conf(trained_condition, vocab)
    albert = AlbertForPreTraining(albertconf)
    albert.load_state_dict(model_weight)
    albert = albert.to(device)

    global_step = 0
    L = len(trainloader)
    bsz = trainloader.batch_size  # batch size (assuming a torch DataLoader)

    if not EXPCONF.infer_now:
        albert = albert.albert
        albert.eval()  # freeze

        cls = MLP(EXPCONF, albertconf.hidden_size, 2).to(device)
        cls.train()
        for p in cls.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        # huggingface example is doing this for language modeling...
        # https://github.com/huggingface/transformers/blob/v2.6.0/examples/run_language_modeling.py
        optimizer = AdamW(cls.parameters(),
                          lr=EXPCONF.cls_lr)  # otherwise, use default
        getsch = get_cosine_schedule_with_warmup if EXPCONF.cls_sch == 'cosine' else get_linear_schedule_with_warmup
        scheduler = getsch(optimizer, EXPCONF.cls_warmups,
                           EXPCONF.cls_numsteps)

        ## train cls only!
        while global_step < EXPCONF.cls_numsteps:
            lossep_pp = 0
            accep_pp = 0
            cls.train()
            for i, (b, l, datasetids) in enumerate(
                    tqdm(trainloader, desc="iterations progress"), 1):
                outputs = albert(**b, return_dict=True)
                global_step += 1

                logits = cls(outputs.pooler_output)
                losspp = F.cross_entropy(logits, l)

                lossppval = losspp.item()
                acc = accuracy(logits.clone().detach(), l)

                wandb.log({
                    'step':
                    global_step,
                    'cls.train_step/learning_rate':
                    get_lr_from_optim(optimizer),
                    'cls.train_step/pp_loss':
                    lossppval,
                    'cls.train_step/pp_acc':
                    acc,
                })

                losspp.backward()  # propagate the classification loss before stepping
                optimizer.step()
                scheduler.step()
                cls.zero_grad()

                lossep_pp += lossppval
                accep_pp += acc
                if global_step % EXPCONF.logevery == 0:
                    lossep_pp /= L
                    accep_pp /= L

                    wandb.log({
                        'cls.train_ep/pp_loss': lossep_pp,
                        'cls.train_ep/pp_acc': accep_pp,
                    })
                    devpp_loss, devpp_acc = evaldev(EXPCONF, albert, cls,
                                                    devloader, global_step)
                    if devpp_acc > EXPCONF.savethld:
                        savemodel(EXPCONF,
                                  albert,
                                  cls,
                                  vocab,
                                  global_step,
                                  acc=devpp_acc)
                        write_sub(EXPCONF,
                                  albert,
                                  cls,
                                  global_step,
                                  acc=devpp_acc,
                                  testloader=testloader)

    else:  # infer now
        cls = None
        devpp_loss, devpp_acc = evaldev(EXPCONF,
                                        albert,
                                        cls,
                                        devloader,
                                        global_step,
                                        infernow=EXPCONF.infer_now)
        write_sub(EXPCONF,
                  albert,
                  cls,
                  global_step,
                  acc=devpp_acc,
                  testloader=testloader,
                  infernow=EXPCONF.infer_now)

    return None
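
The MLP classification head trained on top of the frozen ALBERT pooler output is not defined in this snippet; a minimal sketch under assumed config attribute names (cls_hidden, cls_dropout are placeholders):

class MLP(nn.Module):
    def __init__(self, conf, in_dim: int, n_classes: int):
        super().__init__()
        hidden = getattr(conf, 'cls_hidden', in_dim)
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(getattr(conf, 'cls_dropout', 0.1)),
            nn.Linear(hidden, n_classes),
        )

    def forward(self, x):
        return self.net(x)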
コード例 #13
0
def main():
    # Use the GPU if one is available, otherwise fall back to the CPU
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print("Use " + str(device))

    # Create the output directory
    if not os.path.exists(config.output_path):
        os.mkdir(config.output_path)

    # Create the dataset
    file_list = None
    for path, dirs, files in os.walk(config.img_path, topdown=False):
        file_list = list(files)

    train_dataset = image_dataset(file_list, config.img_path, transform=get_transforms(config.img_size))
    train_loader = DataLoader(dataset=train_dataset, batch_size=config.batchSize, shuffle=True)

    # Build the discriminator D and the generator G from the model module
    G_model = get_G_model(config.from_old_model, device, config.G_model_path)
    D_model = get_D_model(config.from_old_model, device, config.D_model_path)

    # Define the optimizers for G and D; AdamW is used here
    G_optimizer = AdamW(G_model.parameters(), lr=3e-4, weight_decay=1e-6)
    D_optimizer = AdamW(D_model.parameters(), lr=3e-4, weight_decay=1e-6)

    # Loss function
    criterion = config.criterion

    # Mixed-precision acceleration (Apex)
    if config.use_apex:
        G_model, G_optimizer = amp.initialize(G_model, G_optimizer, opt_level="O1")
        D_model, D_optimizer = amp.initialize(D_model, D_optimizer, opt_level="O1")

    # Record the training start time
    train_start = time.time()

    # Train epoch by epoch
    for epoch in range(config.epochs):
        print("start epoch "+str(epoch+1)+":")
        # Variables for tracking progress and losses
        batch_num = len(train_loader)
        D_loss_sum = 0
        G_loss_sum = 0
        count = 0

        # Fetch batches from the dataloader
        for index, images in enumerate(train_loader):
            count += 1
            # Move the images to the compute device
            images = images.to(device)

            # Real labels with label smoothing: random values in [0.9, 1.0)
            # real_labels = (1 - torch.rand(config.batchSize, 1)/10).to(device)
            # Real labels, all ones
            # real_labels = Variable(torch.ones(config.batchSize, 1)).to(device)
            # Real labels, all 0.9 (one-sided label smoothing)
            real_labels = (Variable(torch.ones(config.batchSize, 1))-0.1).to(device)

            # Fake labels, all zeros; one-sided smoothing, so generator labels are not smoothed
            fake_labels = Variable(torch.zeros(config.batchSize, 1)).to(device)

            # Feed random seeds into the generator to produce fake images
            img_seeds = torch.randn(config.batchSize, config.img_seed_dim).to(device)
            fake_images = G_model(img_seeds)

            # Track whether the real/fake labels have been swapped
            exchange_labels = False

            # With some probability, swap the labels while training the discriminator
            if random.uniform(0, 1) < config.D_train_label_exchange:
                real_labels, fake_labels = fake_labels, real_labels
                exchange_labels = True

            # Train the discriminator D
            D_optimizer.zero_grad()
            # Feed real samples to the discriminator
            real_output = D_model(images)

            # For the last (short) batch, trim the real-label tensor to the batch length
            if len(real_labels) > len(real_output):
                D_loss_real = criterion(real_output, real_labels[:len(real_output)])
            else:
                D_loss_real = criterion(real_output, real_labels)
            # Feed fake samples to the discriminator (detached so the generator gets no gradient here)
            fake_output = D_model(fake_images.detach())
            D_loss_fake = criterion(fake_output, fake_labels)
            # Sum the real and fake losses to get the discriminator loss
            D_loss = D_loss_real + D_loss_fake
            D_loss_sum += D_loss.item()

            # Reset the optimizer gradients
            D_optimizer.zero_grad()
            # Update the discriminator D with the loss
            if config.use_apex:
                with amp.scale_loss(D_loss, D_optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                D_loss.backward()
            D_optimizer.step()

            # If the labels were swapped earlier, swap them back now
            if exchange_labels:
                real_labels, fake_labels = fake_labels, real_labels

            # Train the generator G
            # Feed random seeds into G to generate fake data
            img_seeds = torch.randn(config.batchSize, config.img_seed_dim).to(device)
            fake_images = G_model(img_seeds)
            # Feed the fake data to the discriminator
            fake_output = D_model(fake_images)
            # Compare the discriminator's output on fakes against the real labels to get the loss
            G_loss = criterion(fake_output, real_labels)
            G_loss_sum += G_loss.item()

            # Reset the optimizer gradients
            G_optimizer.zero_grad()
            # Update the generator G with the loss
            if config.use_apex:
                with amp.scale_loss(G_loss, G_optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                G_loss.backward()
            G_optimizer.step()

            # Print training progress
            if (index + 1) % 200 == 0:
                print("Epoch: %2d, Batch: %4d / %4d" % (epoch + 1, index + 1, batch_num))

        if (epoch+1) % 10 == 0:
            # Every N epochs, save the model parameters to disk
            torch.save(G_model.state_dict(), config.G_model_path)
            torch.save(D_model.state_dict(), config.D_model_path)
            # Every N epochs, write a batch of generated images to the output directory
            img_seeds = torch.randn(config.batchSize, config.img_seed_dim).to(device)
            fake_images = G_model(img_seeds).detach()
            # Rescale the fake images to the [0, 1] range
            fake_images = 0.5 * (fake_images + 1)
            fake_images = fake_images.clamp(0, 1)
            # Stack the generated images and write them to disk with save_image()
            fake_images = fake_images.view(-1, 3, config.img_size, config.img_size)
            save_image(fake_images, config.output_path+str(epoch+1)+'.png')

        # Print this epoch's losses and elapsed time for reference
        print("D_loss:", round(D_loss_sum / count, 3))
        print("G_loss:", round(G_loss_sum / count, 3))
        current_time = time.time()
        pass_time = int(current_time - train_start)
        time_string = str(pass_time // 3600) + " hours, " + str((pass_time % 3600) // 60) + " minutes, " + str(
            pass_time % 60) + " seconds."
        print("Time pass:", time_string)

    print("Done.")
コード例 #14
0
def main():
    logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout), logging.FileHandler('logs/log'+date_now()+'.log')],
    )
    logger.setLevel(logging.INFO)

    epoch = 10
    batch_size = 64
    data = pd.read_pickle(os.path.join(COMMENT_DIR, 'comment_continue_train_balance.pkl'))
    val_data = pd.read_csv('/ai/223/person/lichunyu/datasets/kaggle/jigsaw/rate/validation_data.csv')
    tokenizer = BertTokenizer.from_pretrained('/ai/223/person/lichunyu/pretrain-models/bert-base-uncased')

    model = BertRegress()
    dataset = JigsawDataset(data, tokenizer)
    less_val_dataset = JigsawValDataset(val_data, 'less_toxic', tokenizer)
    more_val_dataset = JigsawValDataset(val_data, 'more_toxic', tokenizer)
    train_dataloader = DataLoader(dataset, batch_size=batch_size)
    less_val_dataloader = DataLoader(less_val_dataset, batch_size=batch_size)
    more_val_dataloader = DataLoader(more_val_dataset, batch_size=batch_size)
    # optimizer = SGD(model.parameters(), lr=4e-4, weight_decay=2)
    optimizer = AdamW(
        [
            {'params': model.bert.parameters()},
            {'params': model.regress.parameters(), 'lr':5e-4}
        ],
        lr=5e-5,
    )

    model.cuda()

    for e in range(epoch):

        model.train()
        train_total_loss = 0
        step = 0
        for n, batch in enumerate(tqdm(train_dataloader)):
            model.zero_grad()
            step += 1
            input_ids = batch[0].cuda()
            attention_mask = batch[1].cuda()
            y = batch[2].cuda()
            model_output = model(input_ids, attention_mask, y)
            loss = model_output['loss']
            train_total_loss += loss.item()
            if (n % 50) == 0:
                logger.info(f'the loss of batch {n} is {loss.item()}')
            loss.backward()
            optimizer.step()

        logger.info('train step loss is {}'.format(train_total_loss/step))


        model.eval()
        less_toxic_scores = np.array([])
        more_toxic_scores = np.array([])
        for batch in tqdm(less_val_dataloader):
            input_ids = batch[0].cuda()
            attention_mask = batch[1].cuda()
            with torch.no_grad():
                model_output = model(input_ids, attention_mask)
                score = model_output['output']
                score = score.detach().clone().cpu().numpy().flatten()
                less_toxic_scores = np.append(less_toxic_scores, score)

        for batch in tqdm(more_val_dataloader):
            input_ids = batch[0].cuda()
            attention_mask = batch[1].cuda()
            with torch.no_grad():
                model_output = model(input_ids, attention_mask)
                score = model_output['output']
                score = score.detach().clone().cpu().numpy().flatten()
                more_toxic_scores = np.append(more_toxic_scores, score)

        acc_item = (less_toxic_scores < more_toxic_scores).sum()
        logger.info(f'~~~~~~ Acc item is {acc_item}  ~~~~~~~')
        acc = acc_item / len(less_toxic_scores)
        logger.info(f'~~~~~~ Acc score is {acc}  ~~~~~~~')

        current_ckpt = os.path.join(COMMENT_MODEL_DIR, f'bert-epoch-{e}-acc-{acc}.pth')
        torch.save(model.state_dict(), current_ckpt)
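
BertRegress is not defined in this snippet; a plausible minimal sketch, assuming a BERT encoder with a single linear regression head, an MSE loss when targets are passed, and that torch.nn is imported as nn (the default checkpoint name below is a placeholder):

from transformers import BertModel

class BertRegress(nn.Module):
    def __init__(self, name='bert-base-uncased'):
        super().__init__()
        self.bert = BertModel.from_pretrained(name)
        self.regress = nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask, y=None):
        # Pooled [CLS] representation -> scalar toxicity score.
        h = self.bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        out = self.regress(h)
        result = {'output': out}
        if y is not None:
            result['loss'] = nn.functional.mse_loss(out.view(-1), y.float().view(-1))
        return result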
コード例 #15
0
ファイル: trainer.py プロジェクト: Jimmy-INL/vietocr
class Trainer():
    def __init__(self, config, pretrained=True):

        self.config = config
        self.model, self.vocab = build_model(config)
        
        self.device = config['device']
        self.num_iters = config['trainer']['iters']
        self.beamsearch = config['predictor']['beamsearch']

        self.data_root = config['dataset']['data_root']
        self.train_annotation = config['dataset']['train_annotation']
        self.valid_annotation = config['dataset']['valid_annotation']
        self.dataset_name = config['dataset']['name']

        self.batch_size = config['trainer']['batch_size']
        self.print_every = config['trainer']['print_every']
        self.valid_every = config['trainer']['valid_every']

        self.checkpoint = config['trainer']['checkpoint']
        self.export_weights = config['trainer']['export']
        self.metrics = config['trainer']['metrics']
        logger = config['trainer']['log']
    
        if logger:
            self.logger = Logger(logger) 

        if pretrained:
            weight_file = download_weights(**config['pretrain'], quiet=config['quiet'])
            self.load_weights(weight_file)

        self.iter = 0
        
        self.optimizer = AdamW(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09)
        self.scheduler = OneCycleLR(self.optimizer, total_steps=self.num_iters, **config['optimizer'])
#        self.optimizer = ScheduledOptim(
#            Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09),
#            #config['transformer']['d_model'], 
#            512,
#            **config['optimizer'])

        self.criterion = LabelSmoothingLoss(len(self.vocab), padding_idx=self.vocab.pad, smoothing=0.1)
        
        transforms = ImgAugTransform()

        self.train_gen = self.data_gen('train_{}'.format(self.dataset_name), 
                self.data_root, self.train_annotation, transform=transforms)
        if self.valid_annotation:
            self.valid_gen = self.data_gen('valid_{}'.format(self.dataset_name), 
                    self.data_root, self.valid_annotation)

        self.train_losses = []
        
    def train(self):
        total_loss = 0
        
        total_loader_time = 0
        total_gpu_time = 0
        best_acc = 0

        data_iter = iter(self.train_gen)
        for i in range(self.num_iters):
            self.iter += 1

            start = time.time()

            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(self.train_gen)
                batch = next(data_iter)

            total_loader_time += time.time() - start

            start = time.time()
            loss = self.step(batch)
            total_gpu_time += time.time() - start

            total_loss += loss
            self.train_losses.append((self.iter, loss))

            if self.iter % self.print_every == 0:
                info = 'iter: {:06d} - train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(self.iter, 
                        total_loss/self.print_every, self.optimizer.param_groups[0]['lr'], 
                        total_loader_time, total_gpu_time)

                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0
                print(info) 
                self.logger.log(info)

            if self.valid_annotation and self.iter % self.valid_every == 0:
                val_loss = self.validate()
                acc_full_seq, acc_per_char = self.precision(self.metrics)

                info = 'iter: {:06d} - valid loss: {:.3f} - acc full seq: {:.4f} - acc per char: {:.4f}'.format(self.iter, val_loss, acc_full_seq, acc_per_char)
                print(info)
                self.logger.log(info)

                if acc_full_seq > best_acc:
                    self.save_weights(self.export_weights)
                    best_acc = acc_full_seq

            
    def validate(self):
        self.model.eval()

        total_loss = []
        
        with torch.no_grad():
            for step, batch in enumerate(self.valid_gen):
                batch = self.batch_to_device(batch)
                img, tgt_input, tgt_output, tgt_padding_mask = batch['img'], batch['tgt_input'], batch['tgt_output'], batch['tgt_padding_mask']

                outputs = self.model(img, tgt_input, tgt_padding_mask)
#                loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
               
                outputs = outputs.flatten(0,1)
                tgt_output = tgt_output.flatten()
                loss = self.criterion(outputs, tgt_output)

                total_loss.append(loss.item())
                
                del outputs
                del loss

        total_loss = np.mean(total_loss)
        self.model.train()
        
        return total_loss
    
    def predict(self, sample=None):
        pred_sents = []
        actual_sents = []
        img_files = []

        for batch in  self.valid_gen:
            batch = self.batch_to_device(batch)

            if self.beamsearch:
                translated_sentence = batch_translate_beam_search(batch['img'], self.model)
            else:
                translated_sentence = translate(batch['img'], self.model)

            pred_sent = self.vocab.batch_decode(translated_sentence.tolist())
            actual_sent = self.vocab.batch_decode(batch['tgt_output'].tolist())

            img_files.extend(batch['filenames'])

            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)
            
            if sample is not None and len(pred_sents) > sample:
                break

        return pred_sents, actual_sents, img_files

    def precision(self, sample=None):

        pred_sents, actual_sents, _ = self.predict(sample=sample)

        acc_full_seq = compute_accuracy(actual_sents, pred_sents, mode='full_sequence')
        acc_per_char = compute_accuracy(actual_sents, pred_sents, mode='per_char')
    
        return acc_full_seq, acc_per_char
    
    def visualize_prediction(self, sample=16, errorcase=False, fontname='serif', fontsize=16):
        
        pred_sents, actual_sents, img_files = self.predict(sample)

        if errorcase:
            wrongs = []
            for i in range(len(img_files)):
                if pred_sents[i]!= actual_sents[i]:
                    wrongs.append(i)

            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]


        img_files = img_files[:sample]

        fontdict = {
                'family':fontname,
                'size':fontsize
                } 

        for vis_idx in range(0, len(img_files)):
            img_path = img_files[vis_idx]
            pred_sent = pred_sents[vis_idx]
            actual_sent = actual_sents[vis_idx]

            img = Image.open(open(img_path, 'rb'))
            plt.figure()
            plt.imshow(img)
            plt.title('pred: {} - actual: {}'.format(pred_sent, actual_sent), loc='left', fontdict=fontdict)
            plt.axis('off')

        plt.show()
    
    def visualize_dataset(self, sample=16, fontname='serif'):
        n = 0
        for batch in self.train_gen:
            for i in range(self.batch_size):
                img = batch['img'][i].numpy().transpose(1,2,0)
                sent = self.vocab.decode(batch['tgt_input'].T[i].tolist())
                
                plt.figure()
                plt.title('sent: {}'.format(sent), loc='center', fontname=fontname)
                plt.imshow(img)
                plt.axis('off')
                
                n += 1
                if n >= sample:
                    plt.show()
                    return


    def load_checkpoint(self, filename):
        checkpoint = torch.load(filename)
        
        optim = ScheduledOptim(
	       Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09),
            	self.config['transformer']['d_model'], **self.config['optimizer'])

        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.iter = checkpoint['iter']

        self.train_losses = checkpoint['train_losses']

    def save_checkpoint(self, filename):
        state = {'iter':self.iter, 'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(), 'train_losses': self.train_losses}
        
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(state, filename)

    def load_weights(self, filename):
        state_dict = torch.load(filename, map_location=torch.device(self.device))

        for name, param in self.model.named_parameters():
            if name not in state_dict:
                print('{} not found'.format(name))
            elif state_dict[name].shape != param.shape:
                print('{} mismatching shape'.format(name))
                del state_dict[name]

        self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)
       
        torch.save(self.model.state_dict(), filename)

    def batch_to_device(self, batch):
        img = batch['img'].to(self.device, non_blocking=True)
        tgt_input = batch['tgt_input'].to(self.device, non_blocking=True)
        tgt_output = batch['tgt_output'].to(self.device, non_blocking=True)
        tgt_padding_mask = batch['tgt_padding_mask'].to(self.device, non_blocking=True)

        batch = {
                'img': img, 'tgt_input':tgt_input, 
                'tgt_output':tgt_output, 'tgt_padding_mask':tgt_padding_mask, 
                'filenames': batch['filenames']
                }

        return batch

    def data_gen(self, lmdb_path, data_root, annotation, transform=None):
        dataset = OCRDataset(lmdb_path=lmdb_path, 
                root_dir=data_root, annotation_path=annotation, 
                vocab=self.vocab, transform=transform, 
                image_height=self.config['dataset']['image_height'], 
                image_min_width=self.config['dataset']['image_min_width'], 
                image_max_width=self.config['dataset']['image_max_width'])

        sampler = ClusterRandomSampler(dataset, self.batch_size, True)
        gen = DataLoader(
                dataset,
                batch_size=self.batch_size, 
                sampler=sampler,
                collate_fn = collate_fn,
                shuffle=False,
                drop_last=False,
                **self.config['dataloader'])
       
        return gen

    def data_gen_v1(self, lmdb_path, data_root, annotation):
        data_gen = DataGen(data_root, annotation, self.vocab, 'cpu', 
                image_height = self.config['dataset']['image_height'],        
                image_min_width = self.config['dataset']['image_min_width'],
                image_max_width = self.config['dataset']['image_max_width'])

        return data_gen

    def step(self, batch):
        self.model.train()

        batch = self.batch_to_device(batch)
        img, tgt_input, tgt_output, tgt_padding_mask = batch['img'], batch['tgt_input'], batch['tgt_output'], batch['tgt_padding_mask']    
        
        outputs = self.model(img, tgt_input, tgt_key_padding_mask=tgt_padding_mask)
#        loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
        outputs = outputs.view(-1, outputs.size(2))#flatten(0, 1)
        tgt_output = tgt_output.view(-1)#flatten()
        
        loss = self.criterion(outputs, tgt_output)

        self.optimizer.zero_grad()

        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1) 

        self.optimizer.step()
        self.scheduler.step()

        loss_item = loss.item()

        return loss_item
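
compute_accuracy comes from the surrounding project and is not shown here; a minimal sketch of the two modes used above (an assumption, not the original implementation):

def compute_accuracy(ground_truth, predictions, mode='full_sequence'):
    if mode == 'full_sequence':
        # Exact-match accuracy over whole decoded strings.
        correct = sum(1 for gt, pr in zip(ground_truth, predictions) if gt == pr)
        return correct / max(len(ground_truth), 1)
    # Per-character accuracy at aligned positions, averaged over samples.
    scores = []
    for gt, pr in zip(ground_truth, predictions):
        matched = sum(1 for a, b in zip(gt, pr) if a == b)
        scores.append(matched / max(len(gt), 1))
    return sum(scores) / max(len(scores), 1)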
コード例 #16
0
def train(model, dataset):
    train_epochs = 2

    tb_writer = SummaryWriter()

    sampler = RandomSampler(dataset)
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=16)
    # The optimizer
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.01,  # Default for AdamW in torch
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5, eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=1,
        num_training_steps=train_epochs * len(dataloader),
    )

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(dataset))
    logger.info("  Num Epochs = %d", train_epochs)
    logger.info("  Total optimization steps = %d",
                train_epochs * len(dataloader))

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = tqdm.trange(train_epochs, desc="Epoch")

    log_every = 50
    set_seed(42)  # No idea why here and outside as well
    for _ in train_iterator:
        epoch_iterator = tqdm.tqdm(dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(DEVICE) for t in batch)
            inputs = dict(
                zip(
                    [
                        "input_ids",
                        "attention_mask",
                        "token_type_ids",
                        "labels",
                    ],
                    batch,
                ))
            outputs = model(**inputs)
            loss = outputs[0]
            if torch.cuda.device_count() > 1:
                loss = loss.mean()
            loss.backward()
            tr_loss += loss.item()

            optimizer.step()
            scheduler.step()
            model.zero_grad()
            global_step += 1

            if global_step % log_every == 0:
                # I have plenty of doubts here about what exactly I should
                # log and what I should return at the end;
                # for now I'm just going to copy it as-is
                tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                tb_writer.add_scalar("loss",
                                     (tr_loss - logging_loss) / log_every,
                                     global_step)
                logging_loss = tr_loss

    tb_writer.close()

    return global_step, tr_loss / global_step
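
The set_seed helper called above is not shown; a minimal sketch, assuming random, numpy as np, and torch are already imported:

def set_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)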
コード例 #17
0
class Trainer():
    def __init__(self, train_dataloader, test_dataloader, lr, betas, weight_decay, log_freq, with_cuda, model=None):
        
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda" if cuda_condition else "cpu")
        print("Use:", "cuda:0" if cuda_condition else "cpu")
        
        self.model = cnn_audio().to(self.device)
        self.optim = AdamW(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
        self.scheduler = lr_scheduler.CosineAnnealingLR(self.optim, 5)
        self.criterion = nn.BCEWithLogitsLoss()
        
        if model is not None:
            checkpoint = torch.load(model)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optim.load_state_dict(checkpoint['optimizer_state_dict'])
            self.epoch = checkpoint['epoch']
            self.criterion = checkpoint['loss']


        if torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
        print("Using %d GPUS for Converter" % torch.cuda.device_count())
        
        self.train_data = train_dataloader
        self.test_data = test_dataloader
        
        self.log_freq = log_freq
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
        
        self.test_loss = []
        self.train_loss = []
        self.train_f1_score = []
        self.test_f1_score = []
    
    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        """
        :param epoch: current epoch
        :param data_loader: torch.utils.data.DataLoader
        :param train: True for training, False for testing
        """
        str_code = "train" if train else "test"

        data_iter = tqdm(enumerate(data_loader), desc="EP_%s:%d" % (str_code, epoch), total=len(data_loader), bar_format="{l_bar}{r_bar}")
        
        total_element = 0
        loss_store = 0.0
        f1_score_store = 0.0
        total_correct = 0

        for i, data in data_iter:
            specgram = data[0].to(self.device)
            label = data[2].to(self.device)
            one_hot_label = data[1].to(self.device)
            predict_label = self.model(specgram, train)

            # Micro-averaged F1 between the true labels and the converted model predictions
            predict_f1_score = get_F1_score(
                label.cpu().detach().numpy(),
                convert_label(predict_label.cpu().detach().numpy()),
                average='micro'
            )
            
            loss = self.criterion(predict_label, one_hot_label)

            # Backpropagate and update the weights only in training mode
            if train:
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
                self.scheduler.step()

            loss_store += loss.item()
            f1_score_store += predict_f1_score
            self.avg_loss = loss_store / (i + 1)
            self.avg_f1_score = f1_score_store / (i + 1)
        
            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": round(self.avg_loss, 5),
                "loss": round(loss.item(), 5),
                "avg_f1_score": round(self.avg_f1_score, 5)
            }

        data_iter.write(str(post_fix))
        self.train_loss.append(self.avg_loss) if train else self.test_loss.append(self.avg_loss)
        self.train_f1_score.append(self.avg_f1_score) if train else self.test_f1_score.append(self.avg_f1_score)
        
    
    def save(self, epoch, file_path="../models/2j-ramdomMask/"):
        """
        """
        output_path = file_path + f"crnn_ep{epoch}.model"
        torch.save(
            {
            'epoch': epoch,
            'model_state_dict': self.model.cpu().state_dict(),
            'optimizer_state_dict': self.optim.state_dict(),
            'criterion': self.criterion
            },
            output_path)
        self.model.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
    
    def export_log(self, epoch, file_path="../../logs/2j-ramdomMask/"):
        df = pd.DataFrame({
            "train_loss": self.train_loss, 
            "test_loss": self.test_loss, 
            "train_F1_score": self.train_f1_score,
            "test_F1_score": self.test_f1_score
        })
        output_path = file_path+f"loss_timestrech.log"
        print("EP:%d logs Saved on:" % epoch, output_path)
        df.to_csv(output_path)
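
convert_label and get_F1_score are project helpers that are not shown; plausible sketches, assuming the third dataset field holds integer class ids and the model emits one score per class (both assumptions):

import numpy as np
from sklearn.metrics import f1_score

def convert_label(scores):
    # Collapse per-class scores to predicted class indices.
    return np.argmax(scores, axis=-1)

def get_F1_score(y_true, y_pred, average='micro'):
    return f1_score(y_true, y_pred, average=average)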
コード例 #18
0
ファイル: sum_train.py プロジェクト: cschaefer26/NLG
def train(model: AutoModelWithLMHead,
          train_dataset: Dataset,
          val_dataset: Dataset,
          batch_size=32) -> None:

    summary_writer = SummaryWriter('logs/gpt2-training')
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = model.to(device)
    train_sampler = RandomSampler(train_dataset)
    train_loader = DataLoader(dataset=train_dataset,
                              sampler=train_sampler,
                              batch_size=batch_size,
                              num_workers=0,
                              collate_fn=collate_dataset)
    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=batch_size,
                            num_workers=0,
                            collate_fn=collate_dataset)
    loss_func = CrossEntropyLoss(ignore_index=0, reduction='sum')
    optimizer = AdamW(model.parameters(), lr=5e-5)

    scheduler = get_polynomial_decay_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=1000,
        num_training_steps=200000,
        lr_end=1e-07,
        power=1.0,
        last_epoch=-1)

    total_step = 0
    for epoch in range(10):
        epoch_iterator = tqdm(train_loader, desc="Training")
        for step, batch in enumerate(epoch_iterator):
            total_step += 1
            inputs, labels = batch['tokens'].to(device), batch['tokens'].to(
                device)
            model.train()
            optimizer.zero_grad()
            logits = model(inputs)[0]
            loss = 0
            norm = 0
            for b, idx in enumerate(batch['abstract_len']):
                shift_logits = logits[b:b + 1, idx:-1, :]
                shift_labels = labels[b:b + 1, idx + 1:]
                b_loss = loss_func(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1))
                loss += b_loss
                norm += shift_labels.size(1)
            loss = loss / norm
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
            optimizer.step()
            scheduler.step()
            summary_writer.add_scalar('Loss/train',
                                      loss,
                                      global_step=total_step)
            for param_group in optimizer.param_groups:
                summary_writer.add_scalar('Params/learning_rate',
                                          param_group['lr'],
                                          global_step=total_step)

            # EVALUATION
            if total_step % 1000 == 0:
                model.eval()
                val_loss = 0
                val_norm = 0
                for val_batch in val_loader:
                    inputs, labels = val_batch['tokens'].to(
                        device), val_batch['tokens'].to(device)
                    with torch.no_grad():
                        logits = model(inputs)[0]
                    for b, idx in enumerate(val_batch['abstract_len']):
                        shift_logits = logits[b:b + 1, idx:-1, :]
                        shift_labels = labels[b:b + 1, idx + 1:]
                        b_loss = loss_func(
                            shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
                        val_loss += b_loss
                        val_norm += shift_labels.size(1)
                val_loss = val_loss / val_norm
                summary_writer.add_scalar('Loss/val',
                                          val_loss,
                                          global_step=total_step)
                model.train()

            # GENERATION
            texts = [
                'Deutsche Bank sehr schwach nach Aussagen zum Konzernumbau',
                'Mann nach Sturz in Brunnen schwer verletzt',
                'Unwetterwarnung: Sturm zieht über Bayern',
                'Bayern verliert klar im Pokalfinale gegen Liverpool'
            ]
            if total_step % 1000 == 0:
                model.eval()
                for text in texts:
                    inp = tokenizer.encode(text) + tokenizer.encode('|')
                    gen = generate(model,
                                   context=inp,
                                   length=100,
                                   device=device)
                    gen = tokenizer.decode(gen[0])
                    summary_writer.add_text('Text/Prediction',
                                            '    ' + gen,
                                            global_step=total_step)
                    print(f'step {step}, gen: {gen}')
                model.train()

            if total_step % 50000 == 0:
                torch.save(model.state_dict(),
                           f'models/gpt2_step_{total_step}.pt')

    return None
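
The generate helper used for the sample texts is not defined in this snippet; a minimal greedy-decoding sketch (an assumption: no sampling, temperature, or EOS handling):

@torch.no_grad()
def generate(model, context, length, device):
    generated = torch.tensor([context], dtype=torch.long, device=device)
    for _ in range(length):
        logits = model(generated)[0]
        # Append the most probable next token at each step.
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=1)
    return generated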
コード例 #19
0
def train(dataset_path: str):
    device = torch.device(config["train"]["device"])

    print("Device: {}".format(device))

    # device = torch.device("cpu")  # gpu not enough memory :(

    model = OpenAIGPTDoubleHeadsModel.from_pretrained("openai-gpt")
    model.to(device)
    tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")

    orig_num_tokens = len(tokenizer.encoder)
    num_added_tokens = tokenizer.add_special_tokens(SPECIAL_TOKENS)
    model.resize_token_embeddings(new_num_tokens=orig_num_tokens +
                                  num_added_tokens)

    # dataloader = get_data_loader(dataset_path, tokenizer, batch_size=4, shuffle=False, num_workers=1)
    full_dataset = get_dataset(dataset_path, tokenizer)
    assert len(full_dataset) > 0
    train_size = int(
        len(full_dataset) * config["train"]["train_dataset_proportion"] + 1)
    test_size = len(full_dataset) - train_size
    print("Full dataset has {} dialogs. Splitting into train: {} and test: {}".
          format(len(full_dataset), train_size, test_size))
    train_dataset, test_dataset = random_split(
        full_dataset, [train_size, test_size],
        torch.Generator().manual_seed(42))
    print(len(train_dataset), len(test_dataset))

    train_loader = get_data_loader(train_dataset, tokenizer,
                                   config["train"]["batch_size"], True, 0)
    test_loader = get_data_loader(test_dataset, tokenizer, 1, False, 0)

    lr = config["train"]["learning_rate"]
    print("lr: {}".format(lr))
    optimizer = AdamW(model.parameters(), lr=lr)

    # init logging
    start_time = datetime.datetime.now()
    save_path = os.path.join(
        os.path.dirname(__file__),
        "log/log-{}.txt".format(start_time.strftime("%y-%m-%d-%H-%M-%S")))
    print(os.path.dirname(__file__), save_path)
    f = open(save_path, "w+")
    f.close()

    epochs = config["train"]["num_epochs"]
    eval_every = config["train"]["evaluate_interval_iters"]
    num_tests = config["train"]["num_tests"]
    last_model_save = datetime.datetime.now()
    iteration = 0

    for epoch in range(epochs):
        print("Starting epoch {}/{}".format(epoch, epochs))
        for batch in train_loader:

            if iteration % eval_every == 0:
                results = evaluate_model(model, test_loader, device, num_tests)
                add_log(
                    save_path,
                    "test,{0},{1},{2[mc_correct]},{2[num_tests]},{2[lm_correct]},{2[lm_tested]}\n"
                    .format(iteration, epoch, results))

            model.train()
            input_ids = batch["input_ids"].to(device)
            mc_token_ids = batch["mc_token_ids"].to(device)
            token_type_ids = batch["token_type_ids"].to(device)
            lm_labels = batch["lm_labels"].to(device)
            mc_labels = batch["correct"].to(device)

            try:
                model_output = model(input_ids,
                                     token_type_ids=token_type_ids,
                                     mc_token_ids=mc_token_ids,
                                     mc_labels=mc_labels,
                                     labels=lm_labels)
            except Exception as e:
                print(input_ids,
                      token_type_ids,
                      mc_token_ids,
                      lm_labels,
                      mc_labels,
                      sep="\n")
                raise e

            # print("input_ids: {}\ntoken_type_ids: {}\nmc_token_ids: {}\nlm_labels: {}\nmc_labels: {}"
            #       .format(input_ids, token_type_ids, mc_token_ids, lm_labels, mc_labels))

            # print(model_output.loss.item(), model_output.mc_loss.item())
            lm_loss = model_output.loss
            mc_loss = model_output.mc_loss

            loss = lm_loss * config["train"]["lm_coeff"] + mc_loss * config[
                "train"]["mc_coeff"]

            add_log(
                save_path,
                "train,{},{},{},{},{}\n".format(iteration, epoch, loss,
                                                lm_loss, mc_loss))

            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(),
                                     config["train"]["max_norm"])
            optimizer.step()
            optimizer.zero_grad()

            # TODO: evaluation

            if iteration % 50 == 0:
                print(
                    "Time: {} Epoch: {}/{} Iteration: {}/{} Loss: {} ({} {})".
                    format(
                        datetime.datetime.now() - start_time, epoch, epochs,
                        iteration,
                        epochs *
                        (len(train_dataset) // config["train"]["batch_size"]),
                        loss.item(), lm_loss.item(), mc_loss.item()))

            if datetime.datetime.now() - last_model_save > datetime.timedelta(
                    minutes=config["train"]["save_interval_mins"]):
                print("Saving model...")
                torch.save(
                    model.state_dict(),
                    os.path.join(os.path.dirname(__file__),
                                 "checkpoints/model-{}-iter{}.pt").format(
                                     start_time.strftime("%y-%m-%d-%H-%M-%S"),
                                     iteration))
                last_model_save = datetime.datetime.now()

            iteration += 1

    print("Saving model...")
    torch.save(
        model.state_dict(),
        os.path.join(os.path.dirname(__file__),
                     "checkpoints/model-{}-iter{}.pt").format(
                         start_time.strftime("%y-%m-%d-%H-%M-%S"), iteration))
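
train() reads every hyperparameter from a module-level `config` dict and writes CSV-style records through an `add_log` helper; neither is shown in this example. A minimal sketch of the assumed shapes (the values are illustrative placeholders, not the project's settings):

config = {
    "train": {
        "device": "cuda",
        "train_dataset_proportion": 0.9,
        "batch_size": 4,
        "learning_rate": 6.25e-5,
        "num_epochs": 3,
        "evaluate_interval_iters": 500,
        "num_tests": 100,
        "lm_coeff": 1.0,
        "mc_coeff": 1.0,
        "max_norm": 1.0,
        "save_interval_mins": 30,
    }
}

def add_log(path, line):
    # Append one record to the log file created in train().
    with open(path, "a") as f:
        f.write(line)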
Code example #20
0
class Distiller:
    def __init__(self, params: dict, dataset: CaptionTSVDataset,
                 student: nn.Module, teacher: nn.Module, val_dataset,
                 tokenizer):
        logger.info("Initializing Distiller")
        self.params = params
        self.dump_path = params.output_dir
        self.multi_gpu = params.multi_gpu
        self.fp16 = params.fp16

        self.student = student
        self.teacher = teacher

        self.student_config = student.config
        self.vocab_size = student.config.vocab_size

        if params.n_gpu <= 1:
            sampler = RandomSampler(dataset)
        else:
            sampler = DistributedSampler(dataset)

        # if params.group_by_size:
        #     groups = create_lengths_groups(lengths=dataset.lengths, k=params.max_model_input_size)
        #     sampler = GroupedBatchSampler(sampler=sampler, group_ids=groups, batch_size=params.batch_size)
        # else:
        #     sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False)

        sampler = BatchSampler(sampler=sampler,
                               batch_size=params.batch_size,
                               drop_last=False)

        self.dataloader = DataLoader(dataset=dataset, batch_sampler=sampler)
        self.val_dataset = val_dataset
        self.tokenizer = tokenizer

        self.eval_log = []

        self.temperature = params.temperature
        assert self.temperature > 0.0

        self.alpha_ce = params.alpha_ce
        self.alpha_mse = params.alpha_mse
        self.alpha_cos = params.alpha_cos

        # self.mlm = params.mlm
        # if self.mlm:
        #     logger.info("Using MLM loss for LM step.")
        #     self.mlm_mask_prop = params.mlm_mask_prop
        #     assert 0.0 <= self.mlm_mask_prop <= 1.0
        #     assert params.word_mask + params.word_keep + params.word_rand == 1.0
        #     self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
        #     self.pred_probs = self.pred_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else self.pred_probs
        #     self.token_probs = token_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else token_probs
        #     if self.fp16:
        #         self.pred_probs = self.pred_probs.half()
        #         self.token_probs = self.token_probs.half()
        # else:
        #     logger.info("Using CLM loss for LM step.")

        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        if self.alpha_mse > 0.0:
            self.last_loss_mse = 0
        if self.alpha_cos > 0.0:
            self.last_loss_cos = 0
        self.last_log = 0

        self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
        if self.alpha_mse > 0.0:
            self.mse_loss_fct = nn.MSELoss(reduction="sum")
        if self.alpha_cos > 0.0:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")

        logger.info("--- Initializing model optimizer")
        assert params.gradient_accumulation_steps >= 1
        self.num_steps_epoch = len(self.dataloader)
        num_train_optimization_steps = (
            int(self.num_steps_epoch / params.gradient_accumulation_steps *
                params.n_epoch) + 1)

        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in student.named_parameters()
                    if not any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay":
                params.weight_decay,
            },
            {
                "params": [
                    p for n, p in student.named_parameters()
                    if any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay":
                0.0,
            },
        ]
        logger.info(
            "------ Number of trainable parameters (student): %i" % sum([
                p.numel() for p in self.student.parameters() if p.requires_grad
            ]))
        logger.info("------ Number of parameters (student): %i" %
                    sum([p.numel() for p in self.student.parameters()]))
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=params.learning_rate,
                               eps=params.adam_epsilon,
                               betas=(0.9, 0.98))

        warmup_steps = math.ceil(num_train_optimization_steps *
                                 params.warmup_prop)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_train_optimization_steps)

        if self.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            logger.info(
                f"Using fp16 training: {self.params.fp16_opt_level} level")
            self.student, self.optimizer = amp.initialize(
                self.student,
                self.optimizer,
                opt_level=self.params.fp16_opt_level)
            self.teacher = self.teacher.half()

        if self.multi_gpu:
            if self.fp16:
                from apex.parallel import DistributedDataParallel

                logger.info(
                    "Using apex.parallel.DistributedDataParallel for distributed training."
                )
                self.student = DistributedDataParallel(self.student)
            else:
                from torch.nn.parallel import DistributedDataParallel

                logger.info(
                    "Using nn.parallel.DistributedDataParallel for distributed training."
                )
                self.student = DistributedDataParallel(
                    self.student,
                    device_ids=[params.local_rank],
                    output_device=params.local_rank,
                    find_unused_parameters=True,
                )

        # self.is_master = params.is_master
        # if self.is_master:
        logger.info("--- Initializing Tensorboard")
        self.tensorboard = SummaryWriter(
            log_dir=os.path.join(self.dump_path, "log", "train"))
        self.tensorboard.add_text(tag="config/training",
                                  text_string=str(self.params),
                                  global_step=0)
        self.tensorboard.add_text(tag="config/student",
                                  text_string=str(self.student_config),
                                  global_step=0)

    def train(self):
        """
        The real training loop.
        """
        logger.info("Starting training")
        self.last_log = time.time()
        self.student.train()
        self.teacher.eval()

        for _ in range(self.params.n_epoch):
            logger.info(
                f"--- Starting epoch {self.epoch}/{self.params.n_epoch-1}")
            if self.multi_gpu:
                torch.distributed.barrier()

            iter_bar = tqdm(self.dataloader,
                            desc="-Iter",
                            disable=self.params.local_rank not in [-1, 0])
            for batch in iter_bar:
                if self.params.n_gpu > 0:
                    img_key, example = batch
                    # img_key = img_key.to(f"cuda:{self.params.local_rank}")
                    example = tuple(
                        t.to(f"cuda:{self.params.local_rank}")
                        for t in example)
                '''CaptionTSVDataset:
                def __getitem__(self, idx):
                        img_idx = self.get_image_index(idx)
                        img_key = self.image_keys[img_idx]
                        features = self.get_image_features(img_idx)
                        caption = self.get_caption(idx)
                        od_labels = self.get_od_labels(img_idx)
                        example = self.tensorizer.tensorize_example(caption, features, text_b=od_labels)
                        return img_key, example
                '''

                # example: (input_ids, attention_mask, segment_ids, img_feat, masked_pos, masked_ids)

                inputs = {
                    'input_ids': example[0],
                    'attention_mask': example[1],
                    'token_type_ids': example[2],
                    'img_feats': example[3],
                    'masked_pos': example[4],
                    'masked_ids': example[5]
                }
                outputs = self.step(**inputs)

                iter_bar.update()
                iter_bar.set_postfix({
                    "Last_loss":
                    f"{self.last_loss:.2f}",
                    "Avg_cum_loss":
                    f"{self.total_loss_epoch/self.n_iter:.2f}"
                })
            iter_bar.close()

            logger.info(
                f"--- Ending epoch {self.epoch}/{self.params.n_epoch-1}")
            self.end_epoch()

        logger.info("Save very last checkpoint as `pytorch_model.bin`.")
        self.save_checkpoint(checkpoint_name="pytorch_model.bin")
        logger.info("Training is finished")

    def step(self, input_ids: torch.tensor, attention_mask: torch.tensor,
             token_type_ids: torch.tensor, img_feats: torch.tensor,
             masked_pos: torch.tensor, masked_ids: torch.tensor):
        """
        One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
        and possibly a parameter update (depending on the gradient accumulation).

        Input:
        ------
        input_ids: `torch.tensor(bs, seq_length)` - The token ids.
        attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
        token_type_ids: `torch.tensor(bs, seq_length)` - The segment ids.
        img_feats: `torch.tensor` - The image region features.
        masked_pos: `torch.tensor` - The positions of the masked tokens.
        masked_ids: `torch.tensor` - The target ids of the masked tokens.
        """

        s_logits, s_hidden_states = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask,
            img_feats=img_feats,
            masked_pos=masked_pos,
            masked_ids=masked_ids,
            token_type_ids=token_type_ids)  # (bs, seq_length, voc_size)
        with torch.no_grad():
            t_output = self.teacher(
                input_ids=input_ids,
                attention_mask=attention_mask,
                img_feats=img_feats,
                masked_pos=masked_pos,
                masked_ids=masked_ids,
                token_type_ids=token_type_ids)  # (bs, seq_length, voc_size)
            _, t_logits, t_hidden_states = t_output

        # output shape (num_blanks, voc_size)

        # mask = attention_mask.unsqueeze(-1).expand_as(s_logits)  # (bs, seq_length, voc_size)
        # s_logits_slct = torch.masked_select(s_logits, mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        # s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1))  # (bs * seq_length, voc_size) modulo the 1s in mask
        # t_logits_slct = torch.masked_select(t_logits, mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        # t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1))  # (bs * seq_length, voc_size) modulo the 1s in mask

        s_logits_slct = s_logits
        t_logits_slct = t_logits
        assert t_logits_slct.size() == s_logits_slct.size()

        loss_ce = (self.ce_loss_fct(
            F.log_softmax(s_logits_slct / self.temperature, dim=-1),
            F.softmax(t_logits_slct / self.temperature, dim=-1),
        ) * (self.temperature)**2)
        loss = self.alpha_ce * loss_ce

        if self.alpha_mse > 0.0:
            loss_mse = self.mse_loss_fct(
                s_logits_slct, t_logits_slct) / s_logits_slct.size(
                    0)  # Reproducing batchmean reduction
            loss += self.alpha_mse * loss_mse
        if self.alpha_cos > 0.0:
            s_hidden_states = s_hidden_states[-1]  # (bs, seq_length, dim)
            t_hidden_states = t_hidden_states[-1]  # (bs, seq_length, dim)
            # mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states)  # (bs, seq_length, dim)
            # assert s_hidden_states.size() == t_hidden_states.size()
            # dim = s_hidden_states.size(-1)

            # s_hidden_states_slct = torch.masked_select(s_hidden_states, mask)  # (bs * seq_length * dim)
            # s_hidden_states_slct = s_hidden_states_slct.view(-1, dim)  # (bs * seq_length, dim)
            # t_hidden_states_slct = torch.masked_select(t_hidden_states, mask)  # (bs * seq_length * dim)
            # t_hidden_states_slct = t_hidden_states_slct.view(-1, dim)  # (bs * seq_length, dim)

            s_hidden_states_slct = s_hidden_states.reshape(1, -1)
            t_hidden_states_slct = t_hidden_states.reshape(1, -1)

            # CosineEmbeddingLoss expects one target value per row of the (N, D) inputs.
            target = s_hidden_states_slct.new_ones(s_hidden_states_slct.size(0))
            loss_cos = self.cosine_loss_fct(s_hidden_states_slct,
                                            t_hidden_states_slct, target)
            loss += self.alpha_cos * loss_cos

        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
        self.last_loss_ce = loss_ce.item()
        if self.alpha_mse > 0.0:
            self.last_loss_mse = loss_mse.item()
        if self.alpha_cos > 0.0:
            self.last_loss_cos = loss_cos.item()

        self.optimize(loss)

        self.n_sequences_epoch += input_ids.size(0)

    def optimize(self, loss):
        """
        Normalization on the loss (gradient accumulation or distributed training), followed by
        backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
        Also update the metrics for tensorboard.
        """
        # Check for NaN
        if torch.isnan(loss).any():
            logger.error("NaN detected")
            exit()

        if self.multi_gpu:
            loss = loss.mean()
        if self.params.gradient_accumulation_steps > 1:
            loss = loss / self.params.gradient_accumulation_steps

        if self.fp16:
            from apex import amp

            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.iter()
        if self.n_iter % self.params.gradient_accumulation_steps == 0:
            if self.fp16:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(self.optimizer),
                    self.params.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(self.student.parameters(),
                                               self.params.max_grad_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()

    def iter(self):
        """
        Update global counts, write to tensorboard and save checkpoint.
        """
        self.n_iter += 1
        self.n_total_iter += 1

        if self.n_total_iter % self.params.log_interval == 0:
            self.log_tensorboard()
            self.last_log = time.time()
        if self.n_total_iter % self.params.checkpoint_interval == 0:
            self.save_checkpoint()
            logger.info("Perform evaluation at step: %d" % (self.n_total_iter))
            try:
                evaluate_file = evaluate(self.params, self.val_dataset,
                                         self.student, self.tokenizer,
                                         self.dump_path)
                with open(evaluate_file, 'r') as f:
                    res = json.load(f)
                best_score = max(
                    [log['CIDEr'] for log in self.eval_log] + [res['CIDEr']])
                res['epoch'] = self.epoch
                res['global_step'] = self.n_total_iter
                res['best_CIDEr'] = best_score
                self.eval_log.append(res)
                with open(self.dump_path + '/eval_logs.json', 'w') as f:
                    json.dump(self.eval_log, f)
            except Exception:
                logger.warning("An exception occurred during evaluation.")

    def log_tensorboard(self):
        """
        Log into tensorboard. Only by the master process.
        """
        # if not self.is_master:
        #     return

        for param_name, param in self.student.named_parameters():
            self.tensorboard.add_scalar(tag="parameter_mean/" + param_name,
                                        scalar_value=param.data.mean(),
                                        global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag="parameter_std/" + param_name,
                                        scalar_value=param.data.std(),
                                        global_step=self.n_total_iter)
            if param.grad is None:
                continue
            self.tensorboard.add_scalar(tag="grad_mean/" + param_name,
                                        scalar_value=param.grad.data.mean(),
                                        global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag="grad_std/" + param_name,
                                        scalar_value=param.grad.data.std(),
                                        global_step=self.n_total_iter)

        self.tensorboard.add_scalar(
            tag="losses/cum_avg_loss_epoch",
            scalar_value=self.total_loss_epoch / self.n_iter,
            global_step=self.n_total_iter,
        )
        self.tensorboard.add_scalar(tag="losses/loss",
                                    scalar_value=self.last_loss,
                                    global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="losses/loss_ce",
                                    scalar_value=self.last_loss_ce,
                                    global_step=self.n_total_iter)
        if self.alpha_mse > 0.0:
            self.tensorboard.add_scalar(tag="losses/loss_mse",
                                        scalar_value=self.last_loss_mse,
                                        global_step=self.n_total_iter)
        if self.alpha_cos > 0.0:
            self.tensorboard.add_scalar(tag="losses/loss_cos",
                                        scalar_value=self.last_loss_cos,
                                        global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="learning_rate/lr",
                                    scalar_value=self.scheduler.get_lr()[0],
                                    global_step=self.n_total_iter)

        self.tensorboard.add_scalar(
            tag="global/memory_usage",
            scalar_value=psutil.virtual_memory()._asdict()["used"] / 1_000_000,
            global_step=self.n_total_iter,
        )
        self.tensorboard.add_scalar(tag="global/speed",
                                    scalar_value=time.time() - self.last_log,
                                    global_step=self.n_total_iter)

    def end_epoch(self):
        """
        Finally arrived at the end of epoch (full pass on dataset).
        Do some tensorboard logging and checkpoint saving.
        """
        logger.info(
            f"{self.n_sequences_epoch} sequences have been trained during this epoch."
        )

        self.save_checkpoint(checkpoint_name=f"model_epoch_{self.epoch}.pth")
        self.tensorboard.add_scalar(tag="epoch/loss",
                                    scalar_value=self.total_loss_epoch /
                                    self.n_iter,
                                    global_step=self.epoch)

        self.epoch += 1
        self.n_sequences_epoch = 0
        self.n_iter = 0
        self.total_loss_epoch = 0

    def save_checkpoint(self, checkpoint_name: str = "checkpoint.pth"):
        """
        Save the current state. Only by the master process.
        """
        # if not self.is_master:
        #     return
        mdl_to_save = self.student.module if hasattr(
            self.student, "module") else self.student
        mdl_to_save.config.save_pretrained(self.dump_path)
        state_dict = mdl_to_save.state_dict()
        torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))
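
The core of Distiller.step() is the temperature-scaled soft-target loss: KL divergence between the student's log-softmax and the teacher's softmax, multiplied by T^2 so the gradient magnitude stays comparable across temperatures. Isolated as a standalone sketch (the shapes are illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

def soft_target_loss(s_logits, t_logits, temperature=2.0):
    # s_logits, t_logits: (batch, seq_len, vocab_size)
    kd = nn.KLDivLoss(reduction="batchmean")
    return kd(
        F.log_softmax(s_logits / temperature, dim=-1),
        F.softmax(t_logits / temperature, dim=-1),
    ) * temperature ** 2

s_logits = torch.randn(2, 16, 30522)  # student logits
t_logits = torch.randn(2, 16, 30522)  # teacher logits
print(soft_target_loss(s_logits, t_logits).item())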
Code example #21
0
class StudentTrainerSigma:
    def __init__(self, train_loader, val_loader, val_loader_unit_batch,
                 noise_size, student_hidden_size, student_num_layers, epochs,
                 start_lr, teacher_generator):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.val_loader_unit_batch = val_loader_unit_batch

        self.student_num_layers = student_num_layers

        self.teacher_generator = teacher_generator
        self.student_generator = StudentGenerator(noise_size,
                                                  student_hidden_size,
                                                  student_num_layers,
                                                  128).to(device)

        self.start_lr = start_lr

        print(self.teacher_generator)
        print(self.student_generator)

        self.epochs = epochs

    def train(self, metric_every, savename, sigma, val_size=None):
        if not val_size:
            val_size = len(self.val_loader)

        if os.path.isfile(f"results/{savename}.txt"):
            os.remove(f"results/{savename}.txt")

        self.optimizer = AdamW(self.student_generator.parameters(),
                               lr=self.start_lr)
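        # Per-epoch LR multipliers for LambdaLR below: np.logspace(0, -0.01, num=epochs)
        # runs from 10**0 = 1.0 down to 10**-0.01 ≈ 0.977, i.e. a very gentle
        # geometric decay of roughly 2% over the whole run.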
        lr_track = np.logspace(0, -0.01, num=self.epochs)
        lr_lambda = lambda x: lr_track[x]
        self.scheduler = LambdaLR(self.optimizer, lr_lambda)
        for epoch in range(1, self.epochs + 1):
            print(f"epoch {epoch} (idle)")

            avg_loss = 0

            for (real, noised, w) in tqdm(self.train_loader):
                teacher_generated = self.teacher_generator(noised).detach()
                student_generated = self.student_generator(noised)

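                # Multiplicative perturbation of the teacher target: each element is
                # scaled by (1 + eps) with eps ~ N(0, sigma), i.e. relative noise with
                # standard deviation `sigma`. Note that np.random.normal returns float64,
                # so the perturbed target (and hence the loss) is promoted to double precision.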
                disturbance = torch.tensor(
                    np.random.normal(loc=0.0,
                                     scale=sigma,
                                     size=teacher_generated.size())).to(device)

                teacher_generated = (1 + disturbance) * teacher_generated

                loss = ((teacher_generated - student_generated)**2).mean()

                avg_loss += loss.item()  # accumulate as a float so the computation graph is freed each step

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            if epoch != self.epochs:
                self.scheduler.step()
            avg_loss = avg_loss / len(self.train_loader)
            print(f"train loss: {avg_loss:.6f}")

            if epoch % metric_every == 0:
                self.validate(val_size, epoch, savename, avg_loss)

    def validate(self, val_size, epoch, savename, avg_loss):
        with torch.no_grad():
            n = 0
            start = time.time()
            for (_, noised, _) in self.val_loader_unit_batch:
                _ = self.teacher_generator(noised)
                n += 1
                if n == 1000:
                    break
            teacher_time = (time.time() - start) / n

            n = 0
            start = time.time()
            for (_, noised, _) in self.val_loader_unit_batch:
                _ = self.student_generator(noised)
                n += 1
                if n == 1000:
                    break
            student_time = (time.time() - start) / n

            teacher_ms = teacher_time * 1000
            student_ms = student_time * 1000

            print(
                f"avg teacher: {teacher_ms:.3f}ms, avg student: {student_ms:.3f}ms"
            )

            real_batches = []
            teacher_gen_batches = []
            student_gen_batches = []
            w_batches = []
            step = 0
            for (real, noised, w) in self.val_loader:
                step += 1
                if step > val_size:
                    break
                teacher_generated = self.teacher_generator(noised).detach()
                student_generated = self.student_generator(noised)

                real_batches.append(real.detach().cpu())
                teacher_gen_batches.append(teacher_generated.detach().cpu())
                student_gen_batches.append(student_generated.detach().cpu())
                w_batches.append(w.detach().cpu())

        info = {
            "teacher time, ms": teacher_ms,
            "student time, ms": student_ms,
            "loss": avg_loss,
            "epoch": epoch
        }

        plot_metrics(torch.cat(real_batches, dim=0),
                     torch.cat(teacher_gen_batches, dim=0),
                     torch.cat(student_gen_batches, dim=0),
                     torch.cat(w_batches, dim=0), info, savename)
Code example #22
0
File: morph_train.py Project: aseker00/alephbert
def process(model: MorphSequenceModel,
            data: DataLoader,
            criterion: nn.CrossEntropyLoss,
            epoch,
            phase,
            print_every,
            teacher_forcing_ratio=0.0,
            optimizer: optim.AdamW = None,
            max_grad_norm=None):
    print_form_loss, total_form_loss = 0, 0
    print_label_losses, total_label_losses = [
        0 for _ in range(len(label_names))
    ], [0 for _ in range(len(label_names))]
    print_target_forms, total_target_forms = [], []
    print_target_labels, total_target_labels = [], []
    print_decoded_forms, total_decoded_forms = [], []
    print_decoded_labels, total_decoded_labels = [], []
    print_decoded_lattice_rows, total_decoded_lattice_rows = [], []

    for i, batch in enumerate(data):
        batch = tuple(t.to(device) for t in batch)
        batch_form_scores, batch_label_scores, batch_form_targets, batch_label_targets = [], [], [], []
        batch_token_chars, batch_sent_ids, batch_num_tokens = [], [], []
        for sent_xtoken, sent_token_chars, sent_form_chars, sent_labels in zip(
                *batch):
            input_token_chars = sent_token_chars[:, :, -1]
            num_tokens = len(sent_token_chars[sent_token_chars[:, 0, 1] > 0])
            target_token_form_chars = sent_form_chars[:, :, -1]
            max_form_len = target_token_form_chars.shape[1]
            target_token_labels = sent_labels[:, :, 2:]
            max_num_labels = target_token_labels.shape[1]
            use_teacher_forcing = random.random() < teacher_forcing_ratio
            form_scores, _, label_scores = model(
                sent_xtoken, input_token_chars, char_special_symbols,
                num_tokens, max_form_len, max_num_labels,
                target_token_form_chars if use_teacher_forcing else None)
            batch_form_scores.append(form_scores)
            batch_label_scores.append(label_scores)
            batch_form_targets.append(target_token_form_chars[:num_tokens])
            batch_label_targets.append(target_token_labels[:num_tokens])
            batch_token_chars.append(input_token_chars[:num_tokens])
            batch_sent_ids.append(sent_form_chars[:, :, 0].unique().item())
            batch_num_tokens.append(num_tokens)

        # Decode
        batch_form_scores = nn.utils.rnn.pad_sequence(batch_form_scores,
                                                      batch_first=True)
        batch_label_scores = [
            nn.utils.rnn.pad_sequence(label_scores, batch_first=True)
            for label_scores in list(map(list, zip(*batch_label_scores)))
        ]
        with torch.no_grad():
            batch_decoded_chars, batch_decoded_labels = model.decode(
                batch_form_scores, batch_label_scores)

        # Form Loss
        batch_form_targets = nn.utils.rnn.pad_sequence(batch_form_targets,
                                                       batch_first=True)
        form_loss = model.form_loss(batch_form_scores, batch_form_targets,
                                    criterion)
        print_form_loss += form_loss.item()

        # Label Losses
        batch_label_targets = [[t[:, :, j] for j in range(t.shape[-1])]
                               for t in batch_label_targets]
        batch_label_targets = [
            nn.utils.rnn.pad_sequence(label_targets, batch_first=True)
            for label_targets in list(map(list, zip(*batch_label_targets)))
        ]
        label_losses = model.labels_losses(batch_label_scores,
                                           batch_label_targets, criterion)
        for j in range(len(label_losses)):
            print_label_losses[j] += label_losses[j].item()

        # Optimization Step
        if optimizer is not None:
            form_loss.backward(retain_graph=len(label_losses) > 0)
            for j in range(len(label_losses)):
                label_losses[j].backward(
                    retain_graph=(j < len(label_losses) - 1))
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()

        # To Lattice
        for j in range(len(batch_sent_ids)):
            sent_id = batch_sent_ids[j]
            input_chars = batch_token_chars[j]
            target_form_chars = batch_form_targets[j]
            target_labels = [
                label_targets[j] for label_targets in batch_label_targets
            ]
            decoded_form_chars = batch_decoded_chars[j]
            decoded_labels = [
                decoded_labels[j] for decoded_labels in batch_decoded_labels
            ]
            num_tokens = batch_num_tokens[j]
            input_chars = input_chars.to('cpu')
            target_form_chars = target_form_chars[:num_tokens].to('cpu')
            decoded_form_chars = decoded_form_chars[:num_tokens].to('cpu')
            target_labels = [
                labels[:num_tokens].to('cpu') for labels in target_labels
            ]
            decoded_labels = [
                labels[:num_tokens].to('cpu') for labels in decoded_labels
            ]
            input_tokens = utils.to_sent_tokens(input_chars,
                                                char_vocab['id2char'])
            target_morph_segments = utils.to_token_morph_segments(
                target_form_chars, char_vocab['id2char'], char_eos, char_sep)
            decoded_morph_segments = utils.to_token_morph_segments(
                decoded_form_chars, char_vocab['id2char'], char_eos, char_sep)
            target_morph_labels = utils.to_token_morph_labels(
                target_labels, label_names, label_vocab['id2labels'],
                label_pads)
            decoded_morph_labels = utils.to_token_morph_labels(
                decoded_labels, label_names, label_vocab['id2labels'],
                label_pads)

            decoded_token_lattice_rows = (sent_id, input_tokens,
                                          decoded_morph_segments,
                                          decoded_morph_labels)
            print_decoded_lattice_rows.append(decoded_token_lattice_rows)
            print_target_forms.append(target_morph_segments)
            print_target_labels.append(target_morph_labels)
            print_decoded_forms.append(decoded_morph_segments)
            print_decoded_labels.append(decoded_morph_labels)

        # Log Print Eval
        if (i + 1) % print_every == 0:
            sent_id, input_tokens, decoded_segments, decoded_labels = print_decoded_lattice_rows[
                -1]
            target_segments = print_target_forms[-1]
            target_labels = print_target_labels[-1]
            decoded_segments = print_decoded_forms[-1]
            decoded_labels = print_decoded_labels[-1]

            print(
                f'epoch {epoch} {phase}, batch {i + 1} form char loss: {print_form_loss / print_every}'
            )
            for j in range(len(label_names)):
                print(
                    f'epoch {epoch} {phase}, batch {i + 1} {label_names[j]} loss: {print_label_losses[j] / print_every}'
                )
            print(
                f'epoch {epoch} {phase}, batch {i + 1} sent #{sent_id} input tokens  : {input_tokens}'
            )
            print(
                f'epoch {epoch} {phase}, batch {i + 1} sent #{sent_id} target forms  : {list(reversed(target_segments))}'
            )
            print(
                f'epoch {epoch} {phase}, batch {i + 1} sent #{sent_id} decoded forms : {list(reversed(decoded_segments))}'
            )
            for j in range(len(label_names)):
                target_values = [labels[j] for labels in target_labels]
                print(
                    f'epoch {epoch} {phase}, batch {i + 1} sent #{sent_id} target {label_names[j]} labels  : {list(reversed(target_values))}'
                )
                decoded_values = [labels[j] for labels in decoded_labels]
                print(
                    f'epoch {epoch} {phase}, batch {i + 1} sent #{sent_id} decoded {label_names[j]} labels : {list(reversed(decoded_values))}'
                )
            total_form_loss += print_form_loss
            for j, label_loss in enumerate(print_label_losses):
                total_label_losses[j] += label_loss
            print_form_loss = 0
            print_label_losses = [0 for _ in range(len(label_names))]

            total_decoded_forms.extend(print_decoded_forms)
            total_decoded_labels.extend(print_decoded_labels)
            total_target_forms.extend(print_target_forms)
            total_target_labels.extend(print_target_labels)
            total_decoded_lattice_rows.extend(print_decoded_lattice_rows)

            aligned_scores, mset_scores = utils.morph_eval(
                print_decoded_forms, print_target_forms)
            # print(f'epoch {epoch} {phase}, batch {i + 1} form aligned scores: {aligned_scores}')
            print(
                f'epoch {epoch} {phase}, batch {i + 1} form mset scores: {mset_scores}'
            )

            for j in range(len(label_names)):
                if label_names[j][:3].lower() in [
                        'tag', 'bio', 'gen', 'num', 'per', 'ten'
                ]:
                    decoded_values = [
                        labels[j] for sent_labels in print_decoded_labels
                        for labels in sent_labels
                    ]
                    target_values = [
                        labels[j] for sent_labels in print_target_labels
                        for labels in sent_labels
                    ]
                    aligned_scores, mset_scores = utils.morph_eval(
                        decoded_values, target_values)
                    # print(f'epoch {epoch} {phase}, batch {i + 1} {label_names[j]} aligned scores: {aligned_scores}')
                    print(
                        f'epoch {epoch} {phase}, batch {i + 1} {label_names[j]} mset scores: {mset_scores}'
                    )

            print_target_forms = []
            print_target_labels = []
            print_decoded_forms = []
            print_decoded_labels = []
            print_decoded_lattice_rows = []

    # Log Total Eval
    if print_form_loss > 0:
        total_form_loss += print_form_loss
        for j, label_loss in enumerate(print_label_losses):
            total_label_losses[j] += label_loss
        total_decoded_forms.extend(print_decoded_forms)
        total_decoded_labels.extend(print_decoded_labels)
        total_target_forms.extend(print_target_forms)
        total_target_labels.extend(print_target_labels)
        total_decoded_lattice_rows.extend(print_decoded_lattice_rows)

    print(
        f'epoch {epoch} {phase}, total form char loss: {total_form_loss / len(data)}'
    )
    for j in range(len(label_names)):
        print(
            f'epoch {epoch} {phase}, total {label_names[j]} loss: {total_label_losses[j] / len(data)}'
        )

    for j in range(len(label_names)):
        if label_names[j][:3].lower() in [
                'tag', 'bio', 'gen', 'num', 'per', 'ten'
        ]:
            decoded_values = [
                labels[j] for sent_labels in total_decoded_labels
                for labels in sent_labels
            ]
            target_values = [
                labels[j] for sent_labels in total_target_labels
                for labels in sent_labels
            ]
            aligned_scores, mset_scores = utils.morph_eval(
                decoded_values, target_values)
            # print(f'epoch {epoch} {phase}, total {label_names[j]} aligned scores: {aligned_scores}')
            print(
                f'epoch {epoch} {phase}, total {label_names[j]} mset scores: {mset_scores}'
            )

    return utils.get_lattice_data(total_decoded_lattice_rows, label_names)
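
process() doubles as the training and evaluation loop: passing an optimizer enables the backward/step path, while passing None only decodes and scores. A hedged usage sketch (the loader names, ignore index and hyperparameters are illustrative, not taken from the project):

import torch
from torch import nn, optim

criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

for epoch in range(1, 11):
    model.train()
    process(model, train_dataloader, criterion, epoch, 'train', print_every=10,
            teacher_forcing_ratio=1.0, optimizer=optimizer, max_grad_norm=1.0)
    model.eval()
    with torch.no_grad():
        dev_lattice = process(model, dev_dataloader, criterion, epoch, 'dev',
                              print_every=10)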
Code example #23
0
class StudentTrainerExp:
    def __init__(self, train_loader, val_loader, val_loader_unit_batch,
                 noise_size, student_hidden_size, student_num_layers, epochs,
                 start_lr, teacher_generator):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.val_loader_unit_batch = val_loader_unit_batch

        self.student_num_layers = student_num_layers

        self.teacher_generator = teacher_generator
        self.student_generator = StudentGenerator(noise_size,
                                                  student_hidden_size,
                                                  student_num_layers,
                                                  128).to(device)

        self.start_lr = start_lr

        print(self.teacher_generator)
        print(self.student_generator)

        self.epochs = epochs

        self.rvs_history = []
        self.tvs_history = []
        self.mse_history = []
        self.epoch_history = []

    def train(self, metric_every, savename, val_size=None):
        if not val_size:
            val_size = len(self.val_loader)

        self.optimizer = AdamW(self.student_generator.parameters(),
                               lr=self.start_lr)
        lr_track = np.logspace(0, -0.01, num=self.epochs)
        lr_lambda = lambda x: lr_track[x]
        self.scheduler = LambdaLR(self.optimizer, lr_lambda)
        for epoch in range(1, self.epochs + 1):
            print(f"epoch {epoch} (idle)")

            avg_loss = 0

            for (real, noised, w) in tqdm(self.train_loader):
                teacher_generated = self.teacher_generator(noised).detach()
                student_generated = self.student_generator(noised)

                loss = ((teacher_generated - student_generated)**2).mean()

                avg_loss += loss.item()  # accumulate as a float (also needed for plotting mse_history later)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            if epoch != self.epochs:
                self.scheduler.step()
            avg_loss = avg_loss / len(self.train_loader)
            print(f"train loss: {avg_loss:.6f}")

            if epoch % metric_every == 0:
                teacher_vs_student = self.validate(val_size, epoch, savename)

                self.tvs_history.append(teacher_vs_student)
                self.mse_history.append(avg_loss)
                self.epoch_history.append(epoch)

                plt.plot(self.epoch_history, self.mse_history, marker='o')
                plt.tight_layout()
                plt.savefig(f"exp/mse_{epoch}e.png", dpi=300)
                plt.show()

                plt.plot(self.epoch_history, self.tvs_history, marker='o')
                plt.tight_layout()
                plt.savefig(f"exp/tvs_{epoch}e.png", dpi=300)
                plt.show()

    def validate(self, val_size, epoch, savename):
        with torch.no_grad():
            real_batches = []
            teacher_gen_batches = []
            student_gen_batches = []
            w_batches = []
            step = 0
            for (real, noised, w) in self.val_loader:
                step += 1
                if step > val_size:
                    break
                teacher_generated = self.teacher_generator(noised).detach()
                student_generated = self.student_generator(noised)

                real_batches.append(real.detach().cpu())
                teacher_gen_batches.append(teacher_generated.detach().cpu())
                student_gen_batches.append(student_generated.detach().cpu())
                w_batches.append(w.detach().cpu())

        info = {"epoch": epoch}

        return plot_metrics_exp(torch.cat(real_batches, dim=0),
                                torch.cat(teacher_gen_batches, dim=0),
                                torch.cat(student_gen_batches, dim=0),
                                torch.cat(w_batches, dim=0), info, savename)
Code example #24
0
class Detector(object):
    def __init__(self, cfg):
        self.device = cfg["device"]
        self.model = Models().get_model(cfg["network"]) # cfg.network
        self.model.to(self.device)
        params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = AdamW(params, lr=0.00001)
        self.lr_scheduler = OneCycleLR(self.optimizer,
                                       max_lr=1e-4,
                                       epochs=cfg["nepochs"],
                                       steps_per_epoch=169,  # len(dataloader)/accumulations
                                       div_factor=25,  # for initial lr, default: 25
                                       final_div_factor=1e3,  # for final lr, default: 1e4
                                       )

    def fit(self, data_loader, accumulation_steps=4, wandb=None):
        self.model.train()
        #     metric_logger = utils.MetricLogger(delimiter="  ")
        #     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
        avg_loss = MetricLogger('scalar')
        total_loss = MetricLogger('dict')
        lr_log = MetricLogger('list')

        self.optimizer.zero_grad()
        device = self.device

        for i, (images, targets) in enumerate(data_loader):
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = self.model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            loss_value = losses.detach().item()
            if not math.isfinite(loss_value):
                print("Loss is {}, stopping training".format(loss_value))
                sys.exit(1)

            losses.backward()
            if (i+1) % accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()
                    lr_log.update(self.lr_scheduler.get_last_lr())


            print(f"\rTrain iteration: [{i+1}/{len(data_loader)}]", end="")
            avg_loss.update(loss_value)
            total_loss.update(loss_dict)

            # metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
            # metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        print()
        #print(loss_dict)
        return {"train_avg_loss": avg_loss.avg}, total_loss.avg


    def mixup_fit(self, data_loader, accumulation_steps=4, wandb=None):
        self.model.train()
        torch.cuda.empty_cache()
        #     metric_logger = utils.MetricLogger(delimiter="  ")
        #     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
        avg_loss = MetricLogger('scalar')
        total_loss = MetricLogger('dict')
        #lr_log = MetricLogger('list')

        self.optimizer.zero_grad()
        device = self.device

        for i, (batch1, batch2) in enumerate(data_loader):
            images1, targets1 = batch1
            images2, targets2 = batch2
            images = mixup_images(images1, images2)
            targets = merge_targets(targets1, targets2)
            del images1, images2, targets1, targets2, batch1, batch2

            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = self.model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            loss_value = losses.detach().item()
            if not math.isfinite(loss_value):
                print("Loss is {}, stopping training".format(loss_value))
                sys.exit(1)

            losses.backward()
            if (i+1) % accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()
                    #lr_log.update(self.lr_scheduler.get_last_lr())


            print(f"Train iteration: [{i+1}/{674}]\r", end="")
            avg_loss.update(loss_value)
            total_loss.update(loss_dict)

            # metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
            # metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        print()
        #print(loss_dict)
        return {"train_avg_loss": avg_loss.avg}, total_loss.avg


    def evaluate(self, val_dataloader):
        device = self.device
        torch.cuda.empty_cache()
        # self.model.to(device)
        self.model.eval()
        mAp_logger = MetricLogger('list')
        with torch.no_grad():
            for (j, batch) in enumerate(val_dataloader):
                print(f"\rValidation: [{j+1}/{len(val_dataloader)}]", end="")
                images, targets = batch
                del batch
                images = [img.to(device) for img in images]
                # targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
                predictions = self.model(images)#, targets)
                for i, pred in enumerate(predictions):
                    probas = pred["scores"].detach().cpu().numpy()
                    mask = probas > 0.6
                    preds = pred["boxes"].detach().cpu().numpy()[mask]
                    gts = targets[i]["boxes"].detach().cpu().numpy()
                    score, scores = map_score(gts, preds, thresholds=[.5, .55, .6, .65, .7, .75])
                    mAp_logger.update(scores)
            print()
        return {"validation_mAP_score": mAp_logger.avg}

    def get_checkpoint(self):
        self.model.eval()
        model_state = self.model.state_dict()
        optimizer_state = self.optimizer.state_dict()
        checkpoint = {'model_state_dict': model_state,
                      'optimizer_state_dict': optimizer_state
                      }
        # if self.lr_scheduler:
        #     scheduler_state = self.lr_scheduler.state_dict()
        #     checkpoint['lr_scheduler_state_dict'] = scheduler_state

        return checkpoint

    def load_checkpoint(self, checkpoint):
        self.model.eval()
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
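
In the Detector above, OneCycleLR is stepped only when the optimizer steps, which is why steps_per_epoch is hard-coded to 169 (the in-code comment says len(dataloader)/accumulations, and 674 batches with 4 accumulation steps rounds up to 169). A small sketch of deriving that value instead of hard-coding it (the loader, optimizer and cfg names are placeholders for the objects used above):

import math
from torch.optim.lr_scheduler import OneCycleLR

accumulation_steps = 4
steps_per_epoch = math.ceil(len(train_loader) / accumulation_steps)  # e.g. 674 batches -> 169 updates
lr_scheduler = OneCycleLR(optimizer,
                          max_lr=1e-4,
                          epochs=cfg["nepochs"],
                          steps_per_epoch=steps_per_epoch,
                          div_factor=25,
                          final_div_factor=1e3)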
Code example #25
0
        # x_df = x_df.style.background_gradient(cmap='Greys', subset=slice(0,3,1))
        x_df = x_df.style.background_gradient(cmap='Greys', subset=slice(0, 2))

        placeholders_[0][0].write(x_df)

        y_df = pd.DataFrame(data=y.detach().numpy())
        y_df = y_df.style.background_gradient(cmap='Greys', axis=None)
        placeholders_[1][0].write(y_df)
        output = net(x.flatten()).reshape((3, 4))
        loss = criterion(output, y)

        out_df = pd.DataFrame(data=output.detach().numpy())
        out_df = out_df.style.background_gradient(cmap='Greys', axis=None)
        placeholders_[2][0].write(out_df)
        print(f'Loss: {loss.detach()}')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        params = net.parameters()
        # print(list(enumerate(params)))
        for i, param in enumerate(params):
            if i == 0:
                placeholders[i][0].write(param.reshape(-1, 3).detach().numpy())
            else:
                placeholders[i][0].write(param.detach().numpy())
        # print(params)
        # exit()
        # placeholders[0][0].write()
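
The example above is a fragment of a visualization loop; the objects it uses (net, criterion, optimizer, x, y, x_df and the placeholder grids) are created elsewhere. One possible setup that is consistent with the shapes in the fragment (a 3x4 input flattened into a small linear layer, displayed through Streamlit-style placeholders) is sketched below; every name and value here is an assumption, not the original code:

import pandas as pd
import streamlit as st
import torch
from torch import nn

net = nn.Linear(12, 12)                       # 3*4 inputs flattened, output reshaped to (3, 4)
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(net.parameters(), lr=1e-2)

x = torch.rand(3, 4)                          # displayed as x_df
y = torch.rand(3, 4)                          # displayed as y_df
x_df = pd.DataFrame(data=x.numpy())

placeholders_ = [st.columns(1) for _ in range(3)]                            # x, y, model output
placeholders = [st.columns(1) for _ in range(len(list(net.parameters())))]   # weight and bias tables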
Code example #26
0
class Model:
    def __init__(self, local_rank=-1):
        self.flownet = IFNet()
        self.device()
        self.optimG = AdamW(self.flownet.parameters(),
                            lr=1e-6,
                            weight_decay=1e-4)
        self.epe = EPE()
        # self.vgg = VGGPerceptualLoss().to(device)
        self.sobel = SOBEL()
        if local_rank != -1:
            self.flownet = DDP(self.flownet,
                               device_ids=[local_rank],
                               output_device=local_rank)

    def train(self):
        self.flownet.train()

    def eval(self):
        self.flownet.eval()

    def device(self):
        self.flownet.to(device)

    def load_model(self, path, rank=0):
        def convert(param):
            if rank == -1:
                return {
                    k.replace("module.", ""): v
                    for k, v in param.items() if "module." in k
                }
            else:
                return param

        if rank <= 0:
            if torch.cuda.is_available():
                self.flownet.load_state_dict(
                    convert(torch.load('{}/flownet.pkl'.format(path))))
            else:
                self.flownet.load_state_dict(
                    convert(
                        torch.load('{}/flownet.pkl'.format(path),
                                   map_location='cpu')))

    def save_model(self, path, rank=0):
        if rank == 0:
            torch.save(self.flownet.state_dict(),
                       '{}/flownet.pkl'.format(path))

    def inference(self, img0, img1, scale=1.0):
        imgs = torch.cat((img0, img1), 1)
        scale_list = [4, 2, 1]
        flow, mask, merged = self.flownet(imgs, scale_list, scale=scale)
        return merged[2]

    def update(self,
               imgs,
               gt,
               learning_rate=0,
               mul=1,
               training=True,
               flow_gt=None):
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        img0 = imgs[:, :3]
        img1 = imgs[:, 3:]
        if training:
            self.train()
        else:
            self.eval()
        scale = [4, 2, 1]
        flow, mask, merged = self.flownet(torch.cat((imgs, gt), 1),
                                          scale=scale,
                                          training=training)
        loss_l1 = (merged[2] - gt).abs().mean()
        loss_smooth = self.sobel(flow[2], flow[2] * 0).mean()
        # loss_vgg = self.vgg(merged[2], gt)
        if training:
            self.optimG.zero_grad()
            # assumption: total generator loss = reconstruction (L1) term plus the
            # smoothness term; a consistency loss is not computed by this flownet variant.
            loss_G = loss_l1 + loss_smooth * 0.1
            loss_G.backward()
            self.optimG.step()
        return merged[2], {
            'mask': mask,
            'flow': flow[2][:, :2],
            'loss_l1': loss_l1,
            'loss_smooth': loss_smooth,
        }
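
A hedged usage sketch for the Model wrapper above (the checkpoint directory, batch size and resolution are illustrative; `device` is the module-level device the class already assumes):

model = Model()
model.load_model('train_log')                 # assumes train_log/flownet.pkl exists

img0 = torch.rand(4, 3, 256, 448, device=device)
img1 = torch.rand(4, 3, 256, 448, device=device)
gt = torch.rand(4, 3, 256, 448, device=device)

imgs = torch.cat((img0, img1), 1)             # (N, 6, H, W), as expected by update()
pred, info = model.update(imgs, gt, learning_rate=1e-4, training=True)
print(info['loss_l1'].item(), info['loss_smooth'].item())

mid_frame = model.inference(img0, img1)       # interpolated middle frame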
Code example #27
0
class Trainer:
    """
    Handles model training and evaluation.
    
    Arguments:
    ----------
    config: A dictionary of training parameters, likely from a .yaml
    file
    
    model: A pytorch segmentation model (e.g. DeepLabV3)
    
    trn_data: A pytorch dataloader object that will return pairs of images and
    segmentation masks from a training dataset
    
    val_data: A pytorch dataloader object that will return pairs of images and
    segmentation masks from a validation dataset.
    
    """
    def __init__(self, config, model, trn_data, val_data=None):
        self.config = config
        self.model = model.cuda()
        self.trn_data = DataFetcher(trn_data)
        self.val_data = val_data

        #create the optimizer
        if config['optim'] == 'SGD':
            self.optimizer = SGD(model.parameters(),
                                 lr=config['lr'],
                                 momentum=config['momentum'],
                                 weight_decay=config['wd'])
        elif config['optim'] == 'AdamW':
            self.optimizer = AdamW(
                model.parameters(), lr=config['lr'],
                weight_decay=config['wd'])  #betas left at their defaults
        else:
            optim = config['optim']
            raise Exception(
                f'Optimizer {optim} is not supported! Must be SGD or AdamW')

        #create the learning rate scheduler
        schedule = config['lr_policy']
        if schedule == 'OneCycle':
            self.scheduler = OneCycleLR(self.optimizer,
                                        config['lr'],
                                        total_steps=config['iters'])
        elif schedule == 'MultiStep':
            self.scheduler = MultiStepLR(self.optimizer,
                                         milestones=config['lr_decay_epochs'])
        elif schedule == 'Poly':
            func = lambda iteration: (1 - (iteration / config['iters'])
                                      )**config['power']
            self.scheduler = LambdaLR(self.optimizer, func)
        else:
            lr_policy = config['lr_policy']
            raise Exception(
                f'Policy {lr_policy} is not supported! Must be OneCycle, MultiStep or Poly'
            )

        #create the loss criterion
        if config['num_classes'] > 1:
            #load class weights if they were given in the config file
            if 'class_weights' in config:
                weight = torch.Tensor(config['class_weights']).float().cuda()
            else:
                weight = None

            self.criterion = nn.CrossEntropyLoss(weight=weight).cuda()
        else:
            self.criterion = nn.BCEWithLogitsLoss().cuda()

        #define train and validation metrics and class names
        class_names = config['class_names']

        #make training metrics using the EMAMeter. this meter gives extra
        #weight to the most recent metric values calculated during training,
        #which gives a better reflection of how well the model is currently
        #performing when the metrics are printed
        trn_md = {
            name: metric_lookup[name](EMAMeter())
            for name in config['metrics']
        }
        self.trn_metrics = ComposeMetrics(trn_md, class_names)
        self.trn_loss_meter = EMAMeter()

        #the only difference between train and validation metrics
        #is that we use the AverageMeter. this is because there are
        #no weight updates during evaluation, so all batches should
        #count equally
        val_md = {
            name: metric_lookup[name](AverageMeter())
            for name in config['metrics']
        }
        self.val_metrics = ComposeMetrics(val_md, class_names)
        self.val_loss_meter = AverageMeter()

        self.logging = config['logging']

        #now, if we're resuming from a previous run we need to load
        #the state for the model, optimizer, and schedule and resume
        #the mlflow run (if there is one and we're using logging)
        if config['resume']:
            self.resume(config['resume'])
        elif self.logging:
            #if we're not resuming, but are logging, then we
            #need to set up mlflow with a new experiment.
            #every time that Trainer is instantiated we want to
            #end the current active run and let a new one begin
            mlflow.end_run()

            #extract the experiment name from config so that
            #we know where to save our files, if experiment name
            #already exists, we'll use it, otherwise we create a
            #new experiment
            mlflow.set_experiment(self.config['experiment_name'])

            #add the config file as an artifact
            mlflow.log_artifact(config['config_file'])

            #we don't want to add everything in the config
            #to mlflow parameters, we'll just add the most
            #likely to change parameters
            mlflow.log_param('lr_policy', config['lr_policy'])
            mlflow.log_param('optim', config['optim'])
            mlflow.log_param('lr', config['lr'])
            mlflow.log_param('wd', config['wd'])
            mlflow.log_param('bsz', config['bsz'])
            mlflow.log_param('momentum', config['momentum'])
            mlflow.log_param('iters', config['iters'])
            mlflow.log_param('epochs', config['epochs'])
            mlflow.log_param('encoder', config['encoder'])
            mlflow.log_param('finetune_layer', config['finetune_layer'])
            mlflow.log_param('pretraining', config['pretraining'])

    def resume(self, checkpoint_fpath):
        """
        Sets model parameters, scheduler and optimizer states to the
        last recorded values in the given checkpoint file.
        """
        checkpoint = torch.load(checkpoint_fpath, map_location='cpu')
        self.model.load_state_dict(checkpoint['state_dict'])

        if not self.config['restart_training']:
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        if self.logging and 'run_id' in checkpoint:
            mlflow.start_run(run_id=checkpoint['run_id'])

        print(f'Loaded state from {checkpoint_fpath}')
        print(f'Resuming from epoch {self.scheduler.last_epoch}...')

    def log_metrics(self, step, dataset):
        #get the corresponding losses and metrics dict for
        #either train or validation sets
        if dataset == 'train':
            losses = self.trn_loss_meter
            metric_dict = self.trn_metrics.metrics_dict
        elif dataset == 'valid':
            losses = self.val_loss_meter
            metric_dict = self.val_metrics.metrics_dict

        #log the last loss, using the dataset name as a prefix
        mlflow.log_metric(dataset + '_loss', losses.avg, step=step)

        #log all the metrics in our dict, using dataset as a prefix
        metrics = {}
        for k, v in metric_dict.items():
            values = v.meter.avg
            for class_name, val in zip(self.trn_metrics.class_names, values):
                metrics[dataset + '_' + class_name + '_' + k] = float(
                    val.item())

        mlflow.log_metrics(metrics, step=step)

    def train(self):
        """
        Defines a pytorch style training loop for the model withtqdm progress bar
        for each epoch and handles printing loss/metrics at the end of each epoch.
        
        epochs: Number of epochs to train model
        train_iters_per_epoch: Number of training iterations is each epoch. Reducing this 
        number will give more frequent updates but result in slower training time.
        
        Results:
        ----------
        
        After train_iters_per_epoch iterations are completed, it will evaluate the model
        on val_data if there is any, then prints loss and metrics for train and validation
        datasets.
        """

        #set the inner and outer training loop as either
        #iterations or epochs depending on our scheduler
        if self.config['lr_policy'] != 'MultiStep':
            last_epoch = self.scheduler.last_epoch + 1
            total_epochs = self.config['iters']
            iters_per_epoch = 1
            outer_loop = tqdm(range(last_epoch, total_epochs + 1),
                              file=sys.stdout,
                              initial=last_epoch,
                              total=total_epochs)
            inner_loop = range(iters_per_epoch)
        else:
            last_epoch = self.scheduler.last_epoch + 1
            total_epochs = self.config['epochs']
            iters_per_epoch = len(self.trn_data)
            outer_loop = range(last_epoch, total_epochs + 1)
            inner_loop = tqdm(range(iters_per_epoch), file=sys.stdout)

        #determine the epochs at which to print results
        eval_epochs = total_epochs // self.config['num_prints']
        save_epochs = total_epochs // self.config['num_save_checkpoints']

        #the cudnn.benchmark flag speeds up performance
        #when the model input size is constant. See:
        #https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        cudnn.benchmark = True

        #perform training over the outer and inner loops
        for epoch in outer_loop:
            for iteration in inner_loop:
                #load the next batch of training data
                images, masks = self.trn_data.load()

                #run the training iteration
                loss, output = self._train_1_iteration(images, masks)

                #record the loss and evaluate metrics
                self.trn_loss_meter.update(loss)
                self.trn_metrics.evaluate(output, masks)

            #when we're at an eval_epoch we want to print
            #the training results so far and then evaluate
            #the model on the validation data
            if epoch % eval_epochs == 0:
                #before printing results let's record everything in mlflow
                #(if we're using logging)
                if self.logging:
                    self.log_metrics(epoch, dataset='train')

                print('\n')  #print a new line to give space from progress bar
                print(f'train_loss: {self.trn_loss_meter.avg:.3f}')
                self.trn_loss_meter.reset()
                #prints and automatically resets the metric averages to 0
                self.trn_metrics.print()

                #run evaluation if we have validation data
                if self.val_data is not None:
                    #before evaluation we want to turn off cudnn
                    #benchmark because the input sizes of validation
                    #images are not necessarily constant
                    cudnn.benchmark = False
                    self.evaluate()

                    if self.logging:
                        self.log_metrics(epoch, dataset='valid')

                    print(
                        '\n')  #print a new line to give space from progress bar
                    print(f'valid_loss: {self.val_loss_meter.avg:.3f}')
                    self.val_loss_meter.reset()
                    #prints and automatically resets the metric averages to 0
                    self.val_metrics.print()

                    #turn cudnn.benchmark back on before returning to training
                    cudnn.benchmark = True

            #update the optimizer schedule
            self.scheduler.step()

            #the last step is to save the training state if
            #at a checkpoint
            if epoch % save_epochs == 0:
                self.save_state(epoch)

    def _train_1_iteration(self, images, masks):
        #run a training step
        self.model.train()
        self.optimizer.zero_grad()

        #forward pass
        output = self.model(images)
        loss = self.criterion(output, masks)

        #backward pass
        loss.backward()
        self.optimizer.step()

        #return the loss value and the output
        return loss.item(), output.detach()

    def evaluate(self):
        """
        Evaluation method used at the end of each epoch. Not intended to
        generate predictions for validation dataset, it only returns average loss
        and stores metrics for validaiton dataset.
        
        Use Validator class for generating masks on a dataset.
        """
        #set the model into eval mode
        self.model.eval()

        val_iter = DataFetcher(self.val_data)
        for _ in range(len(val_iter)):
            with torch.no_grad():
                #load batch of data
                images, masks = val_iter.load()
                output = self.model(images)
                loss = self.criterion(output, masks)
                self.val_loss_meter.update(loss.item())
                self.val_metrics.evaluate(output.detach(), masks)

        #loss and metrics are updated inplace, so there's nothing to return
        return None

    def save_state(self, epoch):
        """
        Saves the current training state (model, optimizer, and scheduler
        state dicts) to a checkpoint file in config['model_dir'].

        Arguments:
        ------------

        epoch: The current epoch number, used when constructing the
        checkpoint filename.

        Example:
        ----------

        trainer = Trainer(...)
        trainer.save_state(epoch=10)
        """

        #save the state together with the norms that we're using
        state = {
            'state_dict': self.model.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'norms': self.config['training_norms']
        }

        if self.logging:
            state['run_id'] = mlflow.active_run().info.run_id

        #the last step is to create the name of the file to save
        #the format is: name-of-experiment_pretraining_epoch.pth
        model_dir = self.config['model_dir']
        exp_name = self.config['experiment_name']
        pretraining = self.config['pretraining']
        ft_layer = self.config['finetune_layer']

        if self.config['lr_policy'] != 'MultiStep':
            total_epochs = self.config['iters']
        else:
            total_epochs = self.config['epochs']

        if os.path.isfile(pretraining):
            #this is slightly clunky, but it handles the case
            #of using custom pretrained weights from a file:
            #use the name of the directory containing the weight
            #file as the pretraining identifier
            pretraining = pretraining.split('/')[-2]  #.split('.')[0]

        save_path = os.path.join(
            model_dir,
            f'{exp_name}-{pretraining}_ft_{ft_layer}_epoch{epoch}_of_{total_epochs}.pth'
        )
        torch.save(state, save_path)
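
The EMAMeter / AverageMeter distinction described in the comments of Trainer.__init__ can be illustrated with a minimal sketch; these are assumed reference implementations (including the momentum value), not the project's own meter classes.

class SimpleEMAMeter:
    """Exponential moving average: recent values carry more weight."""
    def __init__(self, momentum=0.9):
        self.momentum = momentum
        self.avg = 0.0
        self.initialized = False

    def update(self, value):
        if not self.initialized:
            self.avg, self.initialized = value, True
        else:
            self.avg = self.momentum * self.avg + (1.0 - self.momentum) * value

    def reset(self):
        self.avg, self.initialized = 0.0, False


class SimpleAverageMeter:
    """Plain running mean: every value counts equally."""
    def __init__(self):
        self.sum, self.count = 0.0, 0

    def update(self, value):
        self.sum += value
        self.count += 1

    @property
    def avg(self):
        return self.sum / max(self.count, 1)

    def reset(self):
        self.sum, self.count = 0.0, 0


ema, mean = SimpleEMAMeter(momentum=0.9), SimpleAverageMeter()
for v in [0.2] * 50 + [0.8] * 10:      # metric improves late in training
    ema.update(v)
    mean.update(v)
print(round(ema.avg, 3), round(mean.avg, 3))  # EMA (~0.59) tracks recent values; the mean (~0.30) lags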
def train_model(train_dat, valid_dat, test_dat, model, device):

    model.train()
    # define the optimization
    criterion = CrossEntropyLoss()
    l2loss = MSELoss()
    # optimizer = SGD(model.parameters(), lr=0.05, momentum=0.9)
    # optimizer = torch.optim.SGD([{'params': model.feature.parameters()},
    #                              {'params':model.fc.parameters(),'lr':10*args.lr}],
    #                             lr= args.lr,momentum=args.momentum,weight_decay=args.weight_clay)

    # optimizer = RMSprop(model.parameters(), lr=0.0001)
    # optimizer1 = AdamW([{'params': model.ddm.parameters(), 'lr': 0.001}], lr=0.001, weight_decay=0.01)
    optimizer1 = AdamW(model.ddm.parameters(), lr=0.001, weight_decay=0.01)
    optimizer2 = AdamW([{
        'params': model.feature.parameters(),
        'lr': 0.000005
    }, {
        'params': model.central.parameters(),
        'lr': 0.000005
    }],
                       lr=0.000005,
                       weight_decay=0.01)
    optimizer3 = AdamW([{
        'params': model.fc.parameters(),
        'lr': 0.000005
    }],
                       lr=0.000005,
                       weight_decay=0.01)

    # optimizer2 = RMSprop([{'params': model.feature.parameters()}, {'params': model.fc.parameters(), 'lr': 0.001}], lr=0.0001)
    # optimizer = Adam([{'params': model.ddm.parameters(), 'lr': 0.001}, {'params': model.feature.parameters()}, {'params': model.fc.parameters(), 'lr': 0.000005}], lr=0.000005)

    es = EarlyStopping(patience=5)

    if torch.cuda.is_available():
        model = model.cuda()

    # enumerate epochs for DA
    j = 0
    for epoch in range(n_epochs):
        j += 1
        # enumerate mini batches of src domain and target domain
        train_steps = len(train_dat)
        print("DA train_steps:", train_steps)

        epoch_loss = 0
        epoch_loss_l2 = 0
        epoch_loss_classifier = 0
        epoch_loss_classifier_tgt = 0
        epoch_loss_coral = 0

        i = 0
        for it, (src_data, src_label, tgt_data,
                 tgt_label) in enumerate(train_dat):
            # clear the gradients
            optimizer1.zero_grad()
            optimizer2.zero_grad()
            optimizer3.zero_grad()

            # optimizer.zero_grad()
            if torch.cuda.is_available():
                tgt_data = tgt_data.to(device)
                src_data = src_data.to(device)
                src_label = src_label.to(device)
                tgt_label = tgt_label.to(device)

            # compute the model output
            # yhat = model(inputs)
            src_out, tgt_out, dm_out, centr1, centr2 = model(
                src_data, tgt_data)

            # calculate loss
            # loss = criterion(yhat, targets)
            # epoch_loss = loss
            loss_classifier = criterion(src_out, src_label)
            # print("src_label:")
            # print(src_label)
            loss_classifier_tgt = criterion(tgt_out, tgt_label) * 0.5
            # equalRate(src_label.cpu(), tgt_label.cpu())
            # print("tgt_label:")
            # print(tgt_label)
            # loss_coral = CORAL(src_out, tgt_out)
            loss_coral = CORAL(centr1, centr2)

            loss_l2 = lambda_l2 * l2loss(
                dm_out, src_data[:, 0:DDM_NUM + DIFFERECE_COL])
            # sum_loss = lambda_ * loss_coral + loss_classifier + lambda_l2 * loss_l2 + loss_classifier_tgt * 0.5
            sum_loss = lambda_ * loss_coral + loss_classifier
            epoch_loss += sum_loss.item()
            epoch_loss_l2 += loss_l2.item()
            epoch_loss_classifier += loss_classifier.item()
            epoch_loss_classifier_tgt += loss_classifier_tgt.item()
            epoch_loss_coral += loss_coral.item()

            # credit assignment
            # sum_loss.backward()
            sum_loss.backward(retain_graph=True)
            loss_classifier_tgt.backward(retain_graph=True)
            loss_l2.backward()
            # sum_loss.backward(retain_graph=True)
            # loss_l2.backward()
            # update model weights

            optimizer3.step()
            optimizer2.step()
            optimizer1.step()
            # optimizer.step()
            i = i + 1

        print(
            'DA Train Epoch: {:2d} [{:2d}/{:2d}]\t'
            'Lambda: {:.4f}, Class: {:.6f}, CORAL: {:.6f}, l2_loss: {:.6f}, Total_Loss: {:.6f}'
            .format(epoch, i, train_steps, lambda_, loss_classifier.item(),
                    loss_coral.item(), loss_l2.item(), sum_loss.item()))

        print('DA Train ith Epoch %d result:' % epoch)
        # calculate train src accuracy
        train_acc = evaluate_model_src(train_dat, model, device)
        aggre_train_acc.append(train_acc)
        print('DA train_acc: %.3f' % train_acc)

        # calculate train tgt accuracy
        train_tgt_acc = evaluate_model_tgt(train_dat, model, device)
        aggre_train_tgt_acc.append(train_tgt_acc)
        print('DA train_tgt_acc: %.3f' % train_tgt_acc)

        # # calculate valid accuracy
        valid_acc = evaluate_model_tgt(valid_dat, model, device)
        aggre_valid_acc.append(valid_acc)
        print('DA valid_tgt_acc: %.3f' % valid_acc)

        # # calculate test accuracy
        test_acc = evaluate_model_tgt(test_dat, model, device)
        aggre_test_acc.append(test_acc)
        print('DA test_acc: %.3f' % test_acc)

        epoch_loss = epoch_loss / train_steps
        aggre_losses.append(epoch_loss)
        print(f'DA epoch: {j:3} sum loss: {epoch_loss:6.4f}')

        epoch_loss_l2 = epoch_loss_l2 / train_steps
        aggre_losses_l2.append(epoch_loss_l2)
        print(f'DA epoch: {j:3} l2 loss: {epoch_loss_l2:6.4f}')

        epoch_loss_classifier = epoch_loss_classifier / train_steps
        aggre_losses_classifier.append(epoch_loss_classifier)
        print(
            f'DA epoch: {j:3} classifier src loss: {epoch_loss_classifier:6.4f}'
        )

        epoch_loss_classifier_tgt = epoch_loss_classifier_tgt / train_steps
        aggre_losses_classifier_tgt.append(epoch_loss_classifier_tgt)
        print(
            f'DA epoch: {j:3} classifier tgt loss: {epoch_loss_classifier_tgt:6.4f}'
        )

        epoch_loss_coral = epoch_loss_coral / train_steps
        aggre_losses_coral.append(epoch_loss_coral)
        print(f'DA epoch: {j:3} coral loss: {epoch_loss_coral:6.4f}')

        # # calculate validate accuracy
        epoch_loss_valid, epoch_loss_l2_valid, epoch_loss_classifier_valid, epoch_loss_coral_valid, epoch_loss_classifier_valid_tgt = evaluate_model_stop(
            valid_dat, model,
            device)  # evalution on dev set (i.e., holdout from training)
        aggre_losses_valid.append(epoch_loss_valid)
        aggre_losses_l2_valid.append(epoch_loss_l2_valid)
        aggre_losses_classifier_valid.append(epoch_loss_classifier_valid)
        aggre_losses_classifier_valid_tgt.append(
            epoch_loss_classifier_valid_tgt)
        aggre_losses_coral_valid.append(epoch_loss_coral_valid)

        print(f'DA epoch: {j:3} valid sum loss: {epoch_loss_valid:6.4f}')
        print(f'DA epoch: {j:3} valid l2 loss: {epoch_loss_l2_valid:6.4f}')
        print(
            f'DA epoch: {j:3} valid classifier loss: {epoch_loss_classifier_valid:6.4f}'
        )
        print(
            f'DA epoch: {j:3} valid tgt classifier loss: {epoch_loss_classifier_valid_tgt:6.4f}'
        )
        print(
            f'DA epoch: {j:3} valid coral loss: {epoch_loss_coral_valid:6.4f}')

        if es.step(np.array(epoch_loss_classifier_valid_tgt)):
            print('Early Stopping Criteria Met!')
            break  # early stop criterion is met, we can stop now
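
train_model above relies on an externally defined CORAL(...) helper. For reference, here is a minimal sketch of the standard Deep CORAL loss (Sun & Saenko, 2016); this is an assumption about what the helper computes, not the project's own implementation.

import torch

def coral_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Squared Frobenius distance between source/target feature covariances."""
    d = source.size(1)
    # center the features per dimension
    source = source - source.mean(dim=0, keepdim=True)
    target = target - target.mean(dim=0, keepdim=True)
    # (d x d) covariance matrices
    cov_s = source.t() @ source / (source.size(0) - 1)
    cov_t = target.t() @ target / (target.size(0) - 1)
    return ((cov_s - cov_t) ** 2).sum() / (4.0 * d * d)

# toy check: identical feature sets give zero loss, mismatched scales give a larger one
src = torch.randn(32, 16)
print(float(coral_loss(src, src)))                       # 0.0
print(float(coral_loss(src, torch.randn(32, 16) * 3)))   # noticeably larger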
Code Example #29
File: distiller.py  Project: raylin01/Deep-Learning
class Distiller:
    def __init__(self,
                 params: dict,
                 dataset: LmSeqsDataset,
                 token_probs: torch.tensor,
                 student: nn.Module,
                 teacher: nn.Module):
        logger.info('Initializing Distiller')
        self.params = params
        self.dump_path = params.dump_path
        self.multi_gpu = params.multi_gpu
        self.fp16 = params.fp16

        self.student = student
        self.teacher = teacher

        self.student_config = student.config
        self.vocab_size = student.config.vocab_size

        if params.n_gpu <= 1:
            sampler = RandomSampler(dataset)
        else:
            sampler = DistributedSampler(dataset)

        if params.group_by_size:
            groups = create_lengths_groups(lengths=dataset.lengths, k=params.max_model_input_size)
            sampler = GroupedBatchSampler(sampler=sampler, group_ids=groups, batch_size=params.batch_size)
        else:
            sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False)

        self.dataloader = DataLoader(dataset=dataset,
                                     batch_sampler=sampler,
                                     collate_fn=dataset.batch_sequences)

        self.temperature = params.temperature
        assert self.temperature > 0.

        self.alpha_ce = params.alpha_ce
        self.alpha_mlm = params.alpha_mlm
        self.alpha_clm = params.alpha_clm
        self.alpha_mse = params.alpha_mse
        self.alpha_cos = params.alpha_cos

        self.mlm = params.mlm
        if self.mlm:
            logger.info(f'Using MLM loss for LM step.')
            self.mlm_mask_prop = params.mlm_mask_prop
            assert 0.0 <= self.mlm_mask_prop <= 1.0
            assert params.word_mask + params.word_keep + params.word_rand == 1.0
            self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
            self.pred_probs = self.pred_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else self.pred_probs
            self.token_probs = token_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else token_probs
            if self.fp16:
                self.pred_probs = self.pred_probs.half()
                self.token_probs = self.token_probs.half()
        else:
            logger.info(f'Using CLM loss for LM step.')

        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        self.last_loss_mlm = 0
        self.last_loss_clm = 0
        if self.alpha_mse > 0.: self.last_loss_mse = 0
        if self.alpha_cos > 0.: self.last_loss_cos = 0
        self.last_log = 0

        self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
        self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        if self.alpha_mse > 0.:
            self.mse_loss_fct = nn.MSELoss(reduction='sum')
        if self.alpha_cos > 0.:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction='mean')

        logger.info('--- Initializing model optimizer')
        assert params.gradient_accumulation_steps >= 1
        self.num_steps_epoch = len(self.dataloader)
        num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1

        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad], 'weight_decay': params.weight_decay},
            {'params': [p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad], 'weight_decay': 0.0}
        ]
        logger.info("------ Number of trainable parameters (student): %i" % sum([p.numel() for p in self.student.parameters() if p.requires_grad]))
        logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()]))
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=params.learning_rate,
                               eps=params.adam_epsilon,
                               betas=(0.9, 0.98))

        warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer,
                                                num_warmup_steps=warmup_steps,
                                                num_training_steps=num_train_optimization_steps)

        if self.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level")
            self.student, self.optimizer = amp.initialize(self.student,
                                                          self.optimizer,
                                                          opt_level=self.params.fp16_opt_level)
            self.teacher = self.teacher.half()

        if self.multi_gpu:
            if self.fp16:
                from apex.parallel import DistributedDataParallel
                logger.info("Using apex.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(self.student)
            else:
                from torch.nn.parallel import DistributedDataParallel
                logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(self.student,
                                                       device_ids=[params.local_rank],
                                                       output_device=params.local_rank,
                                                       find_unused_parameters=True)

        self.is_master = params.is_master
        if self.is_master:
            logger.info('--- Initializing Tensorboard')
            self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, 'log', 'train'))
            self.tensorboard.add_text(tag='config/training', text_string=str(self.params), global_step=0)
            self.tensorboard.add_text(tag='config/student', text_string=str(self.student_config), global_step=0)

    def prepare_batch_mlm(self,
                          batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked labels for MLM.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
                lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
            mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a -1 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None])

        bs, max_seq_len = token_ids.size()
        mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)

        x_prob = self.token_probs[token_ids.flatten()]
        n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
        tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
        pred_mask = torch.zeros(bs * max_seq_len, dtype=torch.bool, device=token_ids.device) # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
        pred_mask[tgt_ids] = 1
        pred_mask = pred_mask.view(bs, max_seq_len)

        pred_mask[token_ids == self.params.special_tok_ids['pad_token']] = 0

        # make the number of masked words a multiple of 8 (faster with fp16)
        if self.fp16:
            n1 = pred_mask.sum().item()
            if n1 > 8:
                pred_mask = pred_mask.view(-1)
                n2 = max(n1 % 8, 8 * (n1 // 8))
                if n2 != n1:
                    pred_mask[torch.nonzero(pred_mask).view(-1)[:n1-n2]] = 0
                pred_mask = pred_mask.view(bs, max_seq_len)
                assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()

        _token_ids_real = token_ids[pred_mask]
        _token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
        _token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids['mask_token'])
        probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
        _token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
        token_ids = token_ids.masked_scatter(pred_mask, _token_ids)

        mlm_labels[~pred_mask] = -1 # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility

        # sanity checks
        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size

        return token_ids, attn_mask, mlm_labels

    def prepare_batch_clm(self,
                          batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
                lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for CLM.
            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
            clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a -1 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None])
        clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
        clm_labels[~attn_mask] = -1 # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility

        # sanity checks
        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size

        return token_ids, attn_mask, clm_labels

    def round_batch(self,
                    x: torch.tensor,
                    lengths: torch.tensor):
        """
        For float16 only.
        Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.

        Input:
        ------
            x: `torch.tensor(bs, seq_length)` - The token ids.
            lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            x:  `torch.tensor(new_bs, new_seq_length)` - The updated token ids.
            lengths: `torch.tensor(new_bs)` - The updated lengths.
        """
        if not self.fp16 or len(lengths) < 8:
            return x, lengths

        # make the number of sentences a multiple of 8
        bs1 = len(lengths)
        bs2 = 8 * (bs1 // 8)
        assert bs2 > 0 and bs2 % 8 == 0
        if bs1 != bs2:
            idx = torch.randperm(bs1)[:bs2]
            lengths = lengths[idx]
            slen = lengths.max().item()
            x = x[idx, :slen]
        else:
            idx = None

        # pad so the sequence length is a multiple of 8
        ml1 = x.size(1)
        if ml1 % 8 != 0:
            pad = 8 - (ml1 % 8)
            ml2 = ml1 + pad
            if self.mlm:
                pad_id = self.params.special_tok_ids['pad_token']
            else:
                pad_id = self.params.special_tok_ids['unk_token']
            padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
            x = torch.cat([x, padding_tensor], 1)
            assert x.size() == (bs2, ml2)

        assert x.size(0) % 8 == 0
        assert x.size(1) % 8 == 0
        return x, lengths

    def train(self):
        """
        The real training loop.
        """
        if self.is_master: logger.info('Starting training')
        self.last_log = time.time()
        self.student.train()
        self.teacher.eval()

        for _ in range(self.params.n_epoch):
            if self.is_master: logger.info(f'--- Starting epoch {self.epoch}/{self.params.n_epoch-1}')
            if self.multi_gpu:
                torch.distributed.barrier()

            iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
            for batch in iter_bar:
                if self.params.n_gpu > 0:
                    batch = tuple(t.to(f'cuda:{self.params.local_rank}') for t in batch)

                if self.mlm:
                    token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(batch=batch)
                else:
                    token_ids, attn_mask, lm_labels = self.prepare_batch_clm(batch=batch)
                self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels)

                iter_bar.update()
                iter_bar.set_postfix({'Last_loss': f'{self.last_loss:.2f}',
                                      'Avg_cum_loss': f'{self.total_loss_epoch/self.n_iter:.2f}'})
            iter_bar.close()

            if self.is_master: logger.info(f'--- Ending epoch {self.epoch}/{self.params.n_epoch-1}')
            self.end_epoch()

        if self.is_master:
            logger.info(f'Save very last checkpoint as `pytorch_model.bin`.')
            self.save_checkpoint(checkpoint_name=f'pytorch_model.bin')
            logger.info('Training is finished')

    def step(self,
             input_ids: torch.tensor,
             attention_mask: torch.tensor,
             lm_labels: torch.tensor):
        """
        One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
        and possibly a parameter update (depending on the gradient accumulation).

        Input:
        ------
        input_ids: `torch.tensor(bs, seq_length)` - The token ids.
        attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
        lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
        """
        if self.mlm:
            s_logits, s_hidden_states = self.student(input_ids=input_ids, attention_mask=attention_mask)     # (bs, seq_length, voc_size)
            with torch.no_grad():
                t_logits, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=attention_mask) # (bs, seq_length, voc_size)
        else:
            s_logits, _, s_hidden_states = self.student(input_ids=input_ids, attention_mask=None)            # (bs, seq_length, voc_size)
            with torch.no_grad():
                t_logits, _, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=None)           # (bs, seq_length, voc_size)
        assert s_logits.size() == t_logits.size()

        #https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
        #https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
        if self.params.restrict_ce_to_mask:
            mask = (lm_labels>-1).unsqueeze(-1).expand_as(s_logits)    # (bs, seq_length, voc_size)
        else:
            mask = attention_mask.unsqueeze(-1).expand_as(s_logits)    # (bs, seq_length, voc_size)
        s_logits_slct = torch.masked_select(s_logits, mask)            # (bs * seq_length * voc_size) modulo the 1s in mask
        s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1))      # (bs * seq_length, voc_size) modulo the 1s in mask
        t_logits_slct = torch.masked_select(t_logits, mask)            # (bs * seq_length * voc_size) modulo the 1s in mask
        t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1))      # (bs * seq_length, voc_size) modulo the 1s in mask
        assert t_logits_slct.size() == s_logits_slct.size()

        loss_ce = self.ce_loss_fct(F.log_softmax(s_logits_slct/self.temperature, dim=-1),
                                   F.softmax(t_logits_slct/self.temperature, dim=-1)) * (self.temperature)**2
        loss = self.alpha_ce*loss_ce

        if self.alpha_mlm > 0.:
            loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1))
            loss += self.alpha_mlm * loss_mlm
        if self.alpha_clm > 0.:
            shift_logits = s_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                                        shift_labels.view(-1))
            loss += self.alpha_clm * loss_clm

        if self.alpha_mse > 0.:
            loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct)/s_logits_slct.size(0) # Reproducing batchmean reduction
            loss += self.alpha_mse * loss_mse
        if self.alpha_cos > 0.:
            s_hidden_states = s_hidden_states[-1]                              # (bs, seq_length, dim)
            t_hidden_states = t_hidden_states[-1]                              # (bs, seq_length, dim)
            mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states)     # (bs, seq_length, dim)
            assert s_hidden_states.size() == t_hidden_states.size()
            dim = s_hidden_states.size(-1)
            
            s_hidden_states_slct = torch.masked_select(s_hidden_states, mask)        # (bs * seq_length * dim)
            s_hidden_states_slct = s_hidden_states_slct.view(-1, dim)                # (bs * seq_length, dim)
            t_hidden_states_slct = torch.masked_select(t_hidden_states, mask)        # (bs * seq_length * dim)
            t_hidden_states_slct = t_hidden_states_slct.view(-1, dim)                # (bs * seq_length, dim)
        
            target = s_hidden_states_slct.new(s_hidden_states_slct.size(0)).fill_(1) # (bs * seq_length,)
            loss_cos = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)
            loss += self.alpha_cos * loss_cos

        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
        self.last_loss_ce = loss_ce.item()
        if self.alpha_mlm > 0.:
            self.last_loss_mlm = loss_mlm.item()
        if self.alpha_clm > 0.:
            self.last_loss_clm = loss_clm.item()
        if self.alpha_mse > 0.:
            self.last_loss_mse = loss_mse.item()
        if self.alpha_cos > 0.:
            self.last_loss_cos = loss_cos.item()

        self.optimize(loss)

        self.n_sequences_epoch += input_ids.size(0)

    def optimize(self,
                 loss):
        """
        Normalization on the loss (gradient accumulation or distributed training), followed by
        backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
        Also update the metrics for tensorboard.
        """
        # Check for NaN
        if (loss != loss).data.any():
            logger.error('NaN detected')
            exit()

        if self.multi_gpu:
            loss = loss.mean()
        if self.params.gradient_accumulation_steps > 1:
            loss = loss / self.params.gradient_accumulation_steps

        if self.fp16:
            from apex import amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.iter()
        if self.n_iter % self.params.gradient_accumulation_steps == 0:
            if self.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()

    def iter(self):
        """
        Update global counts, write to tensorboard and save checkpoint.
        """
        self.n_iter += 1
        self.n_total_iter += 1

        if self.n_total_iter % self.params.log_interval == 0:
            self.log_tensorboard()
            self.last_log = time.time()
        if self.n_total_iter % self.params.checkpoint_interval == 0:
            self.save_checkpoint()

    def log_tensorboard(self):
        """
        Log into tensorboard. Only by the master process.
        """
        if not self.is_master:
            return

        for param_name, param in self.student.named_parameters():
            self.tensorboard.add_scalar(tag='parameter_mean/' + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag='parameter_std/' + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter)
            if param.grad is None:
                continue
            self.tensorboard.add_scalar(tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(),global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter)

        self.tensorboard.add_scalar(tag="losses/cum_avg_loss_epoch", scalar_value=self.total_loss_epoch/self.n_iter, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter)
        if self.alpha_mlm > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter)
        if self.alpha_clm > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter)
        if self.alpha_mse > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
        if self.alpha_cos > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter)
        
        self.tensorboard.add_scalar(tag="global/memory_usage", scalar_value=psutil.virtual_memory()._asdict()['used']/1_000_000, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="global/speed", scalar_value=time.time()-self.last_log, global_step=self.n_total_iter)

    def end_epoch(self):
        """
        Finally arrived at the end of epoch (full pass on dataset).
        Do some tensorboard logging and checkpoint saving.
        """
        logger.info(f'{self.n_sequences_epoch} sequences have been trained during this epoch.')

        if self.is_master:
            self.save_checkpoint(checkpoint_name=f'model_epoch_{self.epoch}.pth')
            self.tensorboard.add_scalar(tag='epoch/loss', scalar_value=self.total_loss_epoch/self.n_iter, global_step=self.epoch)

        self.epoch += 1
        self.n_sequences_epoch = 0
        self.n_iter = 0
        self.total_loss_epoch = 0

    def save_checkpoint(self,
                        checkpoint_name: str = 'checkpoint.pth'):
        """
        Save the current state. Only by the master process.
        """
        if not self.is_master:
            return
        mdl_to_save = self.student.module if hasattr(self.student, 'module') else self.student
        mdl_to_save.config.save_pretrained(self.dump_path)
        state_dict = mdl_to_save.state_dict()
        torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))
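
A self-contained sketch of the temperature-scaled distillation term computed in step() above; the random logits stand in for the masked-and-flattened student/teacher outputs, and the temperature value here is illustrative.

import torch
import torch.nn.functional as F
from torch import nn

temperature = 2.0
ce_loss_fct = nn.KLDivLoss(reduction='batchmean')

# stand-ins for the selected logits: (n_selected_tokens, vocab_size)
s_logits_slct = torch.randn(16, 100, requires_grad=True)
t_logits_slct = torch.randn(16, 100)

loss_ce = ce_loss_fct(
    F.log_softmax(s_logits_slct / temperature, dim=-1),
    F.softmax(t_logits_slct / temperature, dim=-1),
) * temperature ** 2  # rescale so gradient magnitudes stay comparable across temperatures

loss_ce.backward()
print(float(loss_ce), s_logits_slct.grad.shape)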
Code Example #30
            input_ids = input_ids.cuda(device=device)
            attention_mask = attention_mask.cuda(device=device)
            token_type_ids = token_type_ids.cuda(device=device)
            rank_truth = rank_truth.cuda(device=device)
            bwd_loss, _ = \
                ranker(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=rank_truth)

            loss = bwd_loss + fwd_loss

            loss.backward()

            ranker_step += 1
            print("ranker_step %d | Training...\r" % ranker_step, end='')

            # if ranker_step % update_stepsize == 0:
            optimizer_rank.step()
            optimizer_rank.zero_grad()

        print("ranker_step %d | Validating..." % ranker_step)
        val_ranker_em = validate_ranker(ranker, fwd_dataloader_d_fg,
                                        bwd_dataloader_d_fg)

        if best_ranker_em <= val_ranker_em:
            torch.save(ranker.state_dict(),
                       'bestRoBERTaRanker_full_dev_em.pth')
            best_ranker_em = val_ranker_em
            print("ranker_em %s" % best_ranker_em)

    pdb.set_trace()

    ranker.load_state_dict(torch.load('bestRoBERTaRanker_full_dev_em.pth'))