Example #1
    def run_epoch(self, epoch, training):
        self.model.train(training)

        if training:
            description = '[Train]'
            dataset = self.trainData
            shuffle = True
        else:
            description = '[Valid]'
            dataset = self.validData
            shuffle = False

        # dataloader for train and valid
        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=8,
            collate_fn=dataset.collate_fn,
        )

        trange = tqdm(enumerate(dataloader),
                      total=len(dataloader),
                      desc=description)
        loss = 0
        accuracy = Accuracy()
        for i, (x, y) in trange:  # (x, y) = 128*128
            o_labels, batch_loss = self.run_iter(x, y)
            if training:
                self.opt.zero_grad()  # reset gradient to 0
                batch_loss.backward()  # calculate gradient
                self.opt.step()  # update parameter by gradient

            loss += batch_loss.item()  # .item() to get python number in Tensor
            accuracy.update(o_labels.cpu(), y)

            trange.set_postfix(accuracy=accuracy.print_score(),
                               loss=loss / (i + 1))

        if training:
            self.history['train'].append({
                'accuracy': accuracy.get_score(),
                'loss': loss / len(trange)
            })
            self.scheduler.step()
        else:
            self.history['valid'].append({
                'accuracy': accuracy.get_score(),
                'loss': loss / len(trange)
            })
            if loss < self.min_loss:
                self.min_loss = loss
                self.save_best_model(epoch)

        self.save_hist()
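
The Accuracy helper these examples rely on is never shown. A minimal sketch consistent with the calls above (update, get_score, print_score, reset, plus the name, match, and n attributes read in Examples #4 and #5) could look like the following; the details are assumptions, not the original implementation:

import torch

class Accuracy:
    # Hypothetical metric helper matching the usage in these examples.
    name = 'accuracy'

    def __init__(self):
        self.reset()

    def reset(self):
        self.match = 0  # correctly classified samples
        self.n = 0      # total samples seen

    def update(self, logits, labels):
        # logits: (batch, n_classes) multi-class scores; labels: (batch,)
        preds = logits.argmax(dim=-1)
        self.match += (preds == labels.to(preds.device)).sum().item()
        self.n += labels.size(0)

    def get_score(self):
        return self.match / self.n if self.n else 0.0

    def print_score(self):
        return f'{self.get_score():.4f}'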
Example #2
    def run_epoch(self, epoch, source_dataloader, target_dataloader, lamb):

        trange = tqdm(zip(source_dataloader, target_dataloader),
                      total=len(source_dataloader),
                      desc=f'[epoch {epoch}]')

        total_D_loss, total_F_loss = 0.0, 0.0
        acc = Accuracy()
        for i, ((source_data, source_label),
                (target_data, _)) in enumerate(trange):
            source_data = source_data.to(self.device)
            source_label = source_label.to(self.device)
            target_data = target_data.to(self.device)

            # =========== Preprocess =================
            # Source and target batches have different means/variances, so we
            # concatenate them so that batch-norm statistics cover both domains.
            mixed_data = torch.cat([source_data, target_data], dim=0)
            domain_label = torch.zeros(
                [source_data.shape[0] + target_data.shape[0], 1]).to(self.device)
            domain_label[:source_data.shape[0]] = 1  # source label=1, target label=0
            feature = self.feature_extractor(mixed_data)

            # =========== Step 1 : Train Domain Classifier (fix feature extractor by feature.detach()) =================
            domain_logits = self.domain_classifier(feature.detach())
            loss = self.domain_criterion(domain_logits, domain_label)
            total_D_loss += loss.item()
            loss.backward()
            self.optimizer_D.step()

            # =========== Step 2: Train Feature Extractor and Label Predictor =================
            class_logits = self.label_predictor(feature[:source_data.shape[0]])
            domain_logits = self.domain_classifier(feature)
            loss = (self.class_criterion(class_logits, source_label) -
                    lamb * self.domain_criterion(domain_logits, domain_label))
            total_F_loss += loss.item()
            loss.backward()
            self.optimizer_F.step()
            self.optimizer_C.step()

            self.optimizer_D.zero_grad()
            self.optimizer_F.zero_grad()
            self.optimizer_C.zero_grad()

            acc.update(class_logits, source_label)

            trange.set_postfix(D_loss=total_D_loss / (i + 1),
                               F_loss=total_F_loss / (i + 1),
                               acc=acc.print_score())

        self.history['d_loss'].append(total_D_loss / len(trange))
        self.history['f_loss'].append(total_F_loss / len(trange))
        self.history['acc'].append(acc.get_score())
        self.save_hist()

        self.save_model()
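
Step 2 above subtracts lamb times the domain loss, so the feature extractor is pushed to fool the domain classifier (the domain-adversarial, DaNN-style objective). The same effect is often implemented with a gradient reversal layer instead of a negated loss term; a minimal sketch, not part of the original code, follows:

import torch

class GradReverse(torch.autograd.Function):
    # Identity in the forward pass; flips and scales the gradient on the way
    # back, so everything upstream is updated to *increase* the domain loss.
    @staticmethod
    def forward(ctx, x, lamb):
        ctx.lamb = lamb
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lamb * grad_output, None

# Usage: domain_logits = self.domain_classifier(GradReverse.apply(feature, lamb))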
Example #3
    def run_epoch(self, epoch, training):
        self.model.train(training)

        if training:
            description = 'Train'
            dataset = self.trainData
            shuffle = True
        else:
            description = 'Valid'
            dataset = self.validData
            shuffle = False
        dataloader = DataLoader(dataset=dataset,
                                batch_size=self.batch_size,
                                shuffle=shuffle,
                                collate_fn=dataset.collate_fn,
                                num_workers=4)

        trange = tqdm(enumerate(dataloader),
                      total=len(dataloader),
                      desc=description,
                      ascii=True)

        f_loss = 0
        l_loss = 0
        accuracy = Accuracy()

        for i, (x, missing, y) in trange:
            o_labels, batch_f_loss, batch_l_loss = self.run_iter(x, missing, y)
            batch_loss = batch_f_loss + batch_l_loss

            if training:
                self.opt.zero_grad()
                batch_loss.backward()
                self.opt.step()

            f_loss += batch_f_loss.item()
            l_loss += batch_l_loss.item()
            accuracy.update(o_labels.cpu(), y)

            trange.set_postfix(accuracy=accuracy.print_score(),
                               f_loss=f_loss / (i + 1),
                               l_loss=l_loss / (i + 1))

        if training:
            self.history['train'].append({
                'accuracy': accuracy.get_score(),
                'loss': f_loss / len(trange)
            })
            self.scheduler.step()
        else:
            self.history['valid'].append({
                'accuracy': accuracy.get_score(),
                'loss': f_loss / len(trange)
            })
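
The run_iter used here returns a label prediction plus two losses but is not shown in the source. A hypothetical sketch, assuming the model jointly reconstructs the missing feature columns (f_loss) and classifies the sample (l_loss); every name and loss choice below is an assumption:

    def run_iter(self, x, missing, y):
        # Hypothetical sketch; local import keeps it self-contained.
        import torch.nn.functional as F
        x, missing, y = x.to(self.device), missing.to(self.device), y.to(self.device)
        pred_missing, logits = self.model(x)
        batch_f_loss = F.mse_loss(pred_missing, missing)  # reconstruct missing columns
        batch_l_loss = F.cross_entropy(logits, y)         # classify the sample
        return logits, batch_f_loss, batch_l_loss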
Example #4
def training(args, train_loader, valid_loader, model, optimizer, device):
    train_metrics = Accuracy()
    best_valid_acc = 0
    total_iter = 0
    criterion = torch.nn.CrossEntropyLoss()
    for epoch in range(args.epochs):
        train_trange = tqdm(enumerate(train_loader),
                            total=len(train_loader),
                            desc='training')
        train_loss = 0
        train_metrics.reset()
        for i, batch in train_trange:
            model.train()  # re-enable train mode (testing() below switches to eval)
            prob = run_iter(batch, model, device, training=True)
            answer = batch['label'].to(device)
            loss = criterion(prob, answer)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_iter += 1
            train_loss += loss.item()
            train_metrics.update(prob, answer)
            train_trange.set_postfix(
                loss=train_loss / (i + 1),
                **{train_metrics.name: train_metrics.print_score()})

            if total_iter % args.eval_steps == 0:
                valid_acc = testing(valid_loader, model, device, valid=True)
                if valid_acc > best_valid_acc:
                    best_valid_acc = valid_acc
                    torch.save(
                        model,
                        os.path.join(
                            args.model_dir,
                            'fine-tuned_bert_{}.pkl'.format(args.seed)))

    # Final validation
    valid_acc = testing(valid_loader, model, device, valid=True)
    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
        torch.save(
            model,
            os.path.join(args.model_dir,
                         'fine-tuned_bert_{}.pkl'.format(args.seed)))
    print('Best Valid Accuracy:{}'.format(best_valid_acc))
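
torch.save(model, ...) above pickles the entire model object, which ties the checkpoint to the exact class definition. A common, more portable alternative (an aside, not the original code; checkpoint_path is a placeholder) is to save only the weights:

# Save just the parameters...
torch.save(model.state_dict(), checkpoint_path)
# ...and restore them into a freshly constructed model of the same class.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))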
Example #5
def testing(dataloader, model, device, valid):
    metrics = Accuracy()
    criterion = torch.nn.CrossEntropyLoss()
    trange = tqdm(enumerate(dataloader),
                  total=len(dataloader),
                  desc='validation' if valid else 'testing')
    model.eval()
    total_loss = 0
    metrics.reset()
    for k, batch in trange:
        prob = run_iter(batch, model, device, training=False)
        answer = batch['label'].to(device)
        loss = criterion(prob, answer)
        total_loss += loss.item()
        metrics.update(prob, answer)
        trange.set_postfix(loss=total_loss / (k + 1),
                           **{metrics.name: metrics.print_score()})
    acc = metrics.match / metrics.n
    return acc
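
run_iter here is a free function taking the whole batch; it is not shown in the source. A hypothetical sketch for a BERT-style classifier (every batch key besides 'label', and the model call signature, are assumptions):

import torch

def run_iter(batch, model, device, training):
    # Hypothetical helper: forward one batch and return class logits.
    with torch.set_grad_enabled(training):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        logits = model(input_ids, attention_mask=attention_mask)
    return logits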
Example #6
    def run_epoch(self, epoch, dataset, training, desc=''):
        self.model.train(training)
        shuffle = training

        dataloader = DataLoader(dataset, self.batch_size, shuffle=shuffle)
        trange = tqdm(enumerate(dataloader), total=len(dataloader), desc=desc)
        loss = 0
        acc = Accuracy()
        for i, (imgs, labels) in trange:  # (b, 3, 128, 128), (b, 1)
            labels = labels.view(-1)  # (b,)
            o_labels, batch_loss = self.run_iters(imgs, labels)

            if training:
                # Gradient accumulation: scale the loss so the summed gradients
                # match one large batch of size batch_size * accum_steps.
                batch_loss /= self.accum_steps
                batch_loss.backward()
                if (i + 1) % self.accum_steps == 0:
                    self.opt.step()
                    self.opt.zero_grad()
                batch_loss *= self.accum_steps  # undo the scaling for logging

            loss += batch_loss.item()
            acc.update(o_labels.cpu(), labels)

            trange.set_postfix(loss=loss / (i + 1), acc=acc.print_score())

        if training:
            self.history['train'].append({
                'loss': loss / len(trange),
                'acc': acc.get_score()
            })
            self.scheduler.step()
        else:
            self.history['valid'].append({
                'loss': loss / len(trange),
                'acc': acc.get_score()
            })
            if loss < self.best_score:
                self.best_score = loss
                self.save_best()
        self.save_hist()
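
One caveat with the accumulation scheme above (an observation, not in the original): when the number of batches is not divisible by accum_steps, the gradients from the final incomplete window are never applied. A small flush after the epoch loop handles it:

# After the epoch loop: apply leftovers from an incomplete accumulation window.
if training and len(dataloader) % self.accum_steps != 0:
    self.opt.step()
    self.opt.zero_grad()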
Example #7
    def run_epoch(self, epoch, dataloader):
        self.feature_extractor.train(True)
        self.label_predictor.train(True)

        trange = tqdm(dataloader,
                      total=len(dataloader),
                      desc=f'[epoch {epoch}]')

        total_loss = 0
        acc = Accuracy()
        for i, (target_data, target_label) in enumerate(trange):  # (b,1,32,32)
            target_data = target_data.to(self.device)
            target_label = target_label.view(-1).to(self.device)  # (b)

            feature = self.feature_extractor(target_data)  # (b, 512)
            class_logits = self.label_predictor(feature)  # (b, 10)

            loss = self.class_criterion(class_logits, target_label)
            total_loss += loss.item()
            loss.backward()
            self.optimizer_F.step()
            self.optimizer_C.step()

            self.optimizer_F.zero_grad()
            self.optimizer_C.zero_grad()

            acc.update(class_logits, target_label)

            trange.set_postfix(loss=total_loss / (i + 1),
                               acc=acc.print_score())

        self.history['loss'].append(total_loss / len(trange))
        self.history['acc'].append(acc.get_score())
        self.save_hist()

        self.save_model()
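
Examples #2 and #7 call zero_grad() after step() rather than before backward(). This works only because gradients start at zero and every backward() is eventually followed by a zero_grad() before the next one; the conventional ordering makes that invariant explicit:

# Conventional ordering of one update step:
self.optimizer_F.zero_grad()
self.optimizer_C.zero_grad()
loss.backward()
self.optimizer_F.step()
self.optimizer_C.step()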
Example #8
    def run_epoch(self, epoch, training, stage1):
        if stage1:
            self.model1.train(training)
        else:
            self.model1.train(False)
            self.model2.train(training)

        if training:
            description = '[Stage1 Train]' if stage1 else '[Stage2 Train]'
            dataset = self.trainData
            shuffle = True
        else:
            description = '[Stage1 Valid]' if stage1 else '[Stage2 Valid]'
            dataset = self.validData
            shuffle = False

        # dataloader for train and valid
        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=8,
            collate_fn=dataset.collate_fn,
        )

        trange = tqdm(enumerate(dataloader),
                      total=len(dataloader),
                      desc=description)
        loss = 0
        loss2 = 0
        accuracy = Accuracy()

        if stage1:
            for i, (x, y, miss) in trange:  # (x,y) = b*b
                o_f1, batch_loss = self.run_iter_stage1(x, miss)

                if training:
                    self.opt.zero_grad()  # reset gradient to 0
                    batch_loss.backward()  # calculate gradient
                    self.opt.step()  # update parameter by gradient

                loss += batch_loss.item()  # .item() to get python number in Tensor
                trange.set_postfix(loss=loss / (i + 1))

        else:
            for i, (x, y, miss) in trange:  # (x,y) = b*b
                o_labels, batch_loss, missing_loss = self.run_iter_stage2(
                    x, miss, y)  # x=(256, 8),  y=(256)
                loss2 += missing_loss.item()

                if training:
                    self.opt.zero_grad()  # reset gradient to 0
                    batch_loss.backward()  # calculate gradient
                    self.opt.step()  # update parameter by gradient

                loss += batch_loss.item()  # .item() to get python number in Tensor
                accuracy.update(o_labels.cpu(), y)

                trange.set_postfix(accuracy=accuracy.print_score(),
                                   loss=loss / (i + 1),
                                   missing_loss=loss2 / (i + 1))
Example #9
    def run_epoch(self, epoch, training):
        self.model.train(training)
        self.generator.train(training)
        self.discriminator.train(training)

        if training:
            description = 'Train'
            dataset = self.trainData
            shuffle = True
        else:
            description = 'Valid'
            dataset = self.validData
            shuffle = False
        dataloader = DataLoader(dataset=dataset,
                                batch_size=self.batch_size,
                                shuffle=shuffle,
                                collate_fn=dataset.collate_fn,
                                num_workers=4)

        trange = tqdm(enumerate(dataloader),
                      total=len(dataloader),
                      desc=description,
                      ascii=True)

        g_loss = 0
        d_loss = 0
        loss = 0
        accuracy = Accuracy()

        for i, (features, real_missing, labels) in trange:

            features = features.to(self.device)  # (batch, 11)
            real_missing = real_missing.to(self.device)  # (batch, 3)
            labels = labels.to(self.device)  # (batch, 1)
            batch_size = features.shape[0]

            if training:
                # Augmentation: add uniform noise in [-0.5, 0.5), scaled by
                # each sample's feature standard deviation.
                rand = torch.rand((batch_size, 11)).to(self.device) - 0.5
                std = features.std(dim=1)
                noise = rand * std.unsqueeze(1)
                features += noise

            # Adversarial ground truths
            valid = torch.ones(batch_size, 1, device=self.device)  # (batch, 1), real = 1
            fake = torch.zeros(batch_size, 1, device=self.device)  # (batch, 1), fake = 0

            # ---------------------
            #  Train Discriminator
            # ---------------------
            # Alternate: of every 10 batches, the first 5 update D and the last
            # 5 update G and the classifier; validation runs both branches.
            if i % 10 < 5 or not training:
                real_pred = self.discriminator(real_missing)
                d_real_loss = self.criterion(real_pred, valid)

                fake_missing = self.generator(features.detach())
                fake_pred = self.discriminator(fake_missing)
                d_fake_loss = self.criterion(fake_pred, fake)
                batch_d_loss = (d_real_loss + d_fake_loss)

                if training:
                    self.opt_D.zero_grad()
                    batch_d_loss.backward()
                    self.opt_D.step()
                d_loss += batch_d_loss.item()

            # -----------------
            #  Train Generator
            # -----------------

            if i % 10 >= 5 or not training:
                gen_missing = self.generator(features.detach())
                validity = self.discriminator(gen_missing)
                batch_g_loss = self.criterion(validity, valid)

                if training:
                    self.opt_G.zero_grad()
                    batch_g_loss.backward()
                    self.opt_G.step()
                g_loss += batch_g_loss.item()

                # ------------------
                #  Train Classifier
                # ------------------

                gen_missing = self.generator(features.detach())
                all_features = torch.cat((features, gen_missing), dim=1)
                o_labels = self.model(all_features)
                batch_loss = self.criterion(o_labels, labels)

                if training:
                    self.opt.zero_grad()
                    batch_loss.backward()
                    self.opt.step()
                loss += batch_loss.item()
                accuracy.update(o_labels, labels)

                trange.set_postfix(accuracy=accuracy.print_score(),
                                   g_loss=g_loss / (i + 1),
                                   d_loss=d_loss / (i + 1),
                                   loss=loss / (i + 1))

        if training:
            self.history['train'].append({
                'accuracy': accuracy.get_score(),
                'g_loss': g_loss / len(trange),
                'd_loss': d_loss / len(trange),
                'loss': loss / len(trange)
            })
            self.scheduler.step()
            self.scheduler_G.step()
            self.scheduler_D.step()
        else:
            self.history['valid'].append({
                'accuracy': accuracy.get_score(),
                'g_loss': g_loss / len(trange),
                'd_loss': d_loss / len(trange),
                'loss': loss / len(trange)
            })
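
The same self.criterion scores both the GAN real/fake predictions and the (batch, 1) classifier output, which suggests a binary cross-entropy objective shared across all three networks. A plausible setup (an assumption; the source never shows it):

import torch

# Assumption: the discriminator and classifier end in a sigmoid, so plain BCE
# works for the real/fake targets and the (batch, 1) labels alike.
criterion = torch.nn.BCELoss()
# If they output raw logits instead, BCEWithLogitsLoss() is the numerically safer choice.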