# Example #1
# 0
class Trainer():
    """Distillation-style trainer for a pair of SimpleNet classifiers.

    Classifier A sees only the 8 observed features; classifier B sees the
    full 9-feature vector (observed + missing).  A is trained on the labels
    plus an MSE loss that pushes its hidden activations toward B's, so A
    learns to mimic B's internal representation without the missing input.
    Metrics go to <model_dir>/history.json; the best checkpoint (by A's
    validation accuracy) goes to <model_dir>/model.pkl.<epoch>.
    """

    def __init__(self,
                 trainData,
                 validData,
                 hidden_size,
                 device,
                 model_dir="model"):
        # Per-epoch metric records, serialized to history.json by save_hist().
        self.history = {'train': [], 'valid': []}
        self.trainData = trainData
        self.validData = validData
        # Classifier A: observed features only (8 inputs, 12 classes).
        self.classficationA = SimpleNet(input_size=8,
                                        output_size=12,
                                        hidden_size=hidden_size).to(device)
        # Classifier B: full feature vector (9 inputs, 12 classes).
        self.classficationB = SimpleNet(input_size=9,
                                        output_size=12,
                                        hidden_size=hidden_size).to(device)
        self.criterion = nn.CrossEntropyLoss()
        # Used to match A's hidden activations against B's.
        self.mse_loss = nn.MSELoss()
        self.opt_C_A = torch.optim.Adam(self.classficationA.parameters(),
                                        lr=1e-4)
        self.opt_C_B = torch.optim.Adam(self.classficationB.parameters(),
                                        lr=1e-4)
        self.device = device
        self.model_dir = model_dir
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        # Best validation accuracy of classifier A seen so far.
        self.best_val = 0.0

    def run_epoch(self, epoch, training):
        """Run one full pass over the train (training=True) or valid split.

        Appends averaged metrics to self.history and, on a validation pass,
        checkpoints whenever classifier A's accuracy improves.
        """
        self.classficationA.train(training)
        self.classficationB.train(training)

        if training:
            description = 'Train'
            dataset = self.trainData
            shuffle = True
        else:
            description = 'Valid'
            dataset = self.validData
            shuffle = False
        dataloader = DataLoader(dataset=dataset,
                                batch_size=256,
                                shuffle=shuffle,
                                collate_fn=dataset.collate_fn,
                                num_workers=4)

        trange = tqdm(enumerate(dataloader),
                      total=len(dataloader),
                      desc=description)

        # Running sums over batches; divided by batch count when reported.
        mse_loss = 0
        lossA = 0
        lossB = 0
        accA = accuracy()  # accuracy() is a project-local metric accumulator
        accB = accuracy()

        # Each batch is (observed features, missing feature(s), labels) --
        # assumed from the collate_fn contract; TODO confirm against dataset.
        for i, (ft, missing_ft, labels) in trange:
            ft = ft.to(self.device)
            missing_ft = missing_ft.to(self.device)
            # B's input: observed and missing features concatenated.
            all_ft = torch.cat([ft, missing_ft], dim=1)
            labels = labels.to(self.device)

            # ------------------
            #  Train ClassifierA
            # ------------------

            # SimpleNet returns (logits, hidden activations) -- both
            # unpacked here; hidden_out is iterated per layer below.
            missing_out, missing_hidden_out = self.classficationA(ft)
            all_out, all_hidden_out = self.classficationB(all_ft)
            batch_loss = self.criterion(missing_out, labels)
            # Representation-matching loss: sum of per-layer MSEs between
            # A's and B's hidden activations.
            batch_mse_loss = 0
            for missing_hidden, all_hidden in zip(missing_hidden_out,
                                                  all_hidden_out):
                batch_mse_loss += self.mse_loss(missing_hidden, all_hidden)
            mse_loss += batch_mse_loss.item()

            if training:
                self.opt_C_A.zero_grad()
                # NOTE: this backward also deposits gradients in B's
                # parameters (the MSE graph runs through B), but they are
                # discarded by opt_C_B.zero_grad() below -- only A steps here.
                (batch_mse_loss + batch_loss).backward()
                self.opt_C_A.step()
            lossA += batch_loss.item()
            accA.update(missing_out, labels)

            # ------------------
            #  Train ClassifierB
            # ------------------

            # Fresh forward pass for B: the graph from the earlier pass was
            # consumed by A's backward, so B needs its own.
            all_out, _ = self.classficationB(all_ft)
            batch_loss = self.criterion(all_out, labels)
            if training:
                self.opt_C_B.zero_grad()
                batch_loss.backward()
                self.opt_C_B.step()
            lossB += batch_loss.item()
            accB.update(all_out, labels)

            # Progress bar shows running averages over batches seen so far.
            trange.set_postfix(accA=accA.print_score(),
                               accB=accB.print_score(),
                               lossA=lossA / (i + 1),
                               lossB=lossB / (i + 1),
                               mseLoss=mse_loss / (i + 1))
        # len(trange) == number of batches, so these are per-batch averages.
        if training:
            self.history['train'].append({
                'accA': accA.get_score(),
                'accB': accB.get_score(),
                'lossA': lossA / len(trange),
                'lossB': lossB / len(trange),
                'mseLoss': mse_loss / len(trange)
            })
            self.save_hist()

        else:
            self.history['valid'].append({
                'accA': accA.get_score(),
                'accB': accB.get_score(),
                'lossA': lossA / len(trange),
                'lossB': lossB / len(trange),
                'mseLoss': mse_loss / len(trange)
            })
            self.save_hist()
            # Model selection is driven by classifier A only.
            if self.best_val < accA.get_score():
                self.best_val = accA.get_score()
                self.save_best(epoch)

    def save_best(self, epoch):
        """Checkpoint both classifiers to <model_dir>/model.pkl.<epoch>."""
        torch.save(
            {
                'classficationA': self.classficationA.state_dict(),
                'classficationB': self.classficationB.state_dict(),
            }, self.model_dir + '/model.pkl.' + str(epoch))

    def save_hist(self):
        """Overwrite <model_dir>/history.json with the metric history."""
        with open(self.model_dir + '/history.json', 'w') as f:
            json.dump(self.history, f, indent=4)
# Example #2
# 0
def main():
    """Command-line entry point: optionally train, predict, and/or plot.

    Flags:
        --do_train    train a Trainer on ../../data/train.csv
        --do_predict  run inference with checkpoint --ckpt and write a
                      submission file (requires --ckpt)
        --do_plot     plot <arch>/history.json
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--arch',
                        default="model",
                        help='architecture (model_dir)')
    parser.add_argument('--do_train', action='store_true')
    parser.add_argument('--do_predict', action='store_true')
    parser.add_argument('--do_plot', action='store_true')
    parser.add_argument('--hidden_size', default=256, type=int)
    parser.add_argument('--batch_size', default=256, type=int)
    parser.add_argument('--max_epoch', default=10000, type=int)
    parser.add_argument('--lr', default=1e-3, type=float)
    parser.add_argument('--step_lr', default=0.5, type=float)
    parser.add_argument('--cuda', default=0, type=int)
    parser.add_argument('--ckpt',
                        type=int,
                        help='load pre-trained model epoch')
    args = parser.parse_args()

    # One device selection for both phases (previously prediction ignored
    # --cuda; 'cuda:0' is equivalent to the old plain 'cuda' default).
    device = torch.device(
        'cuda:%d' % args.cuda if torch.cuda.is_available() else 'cpu')

    if args.do_train:

        dataset = pd.read_csv("../../data/train.csv")
        dataset.drop("Id", axis=1, inplace=True)
        train_set, valid_set = train_test_split(dataset,
                                                test_size=0.1,
                                                random_state=73)
        feature_for_training = ["F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9"]
        feature_for_prediction = ["F1"]

        train = preprocess_samples(train_set, feature_for_training,
                                   feature_for_prediction)
        valid = preprocess_samples(valid_set, feature_for_training,
                                   feature_for_prediction)

        trainData = FeatureDataset(train)
        validData = FeatureDataset(valid)

        trainer = Trainer(device, trainData, validData, args)

        for epoch in range(1, args.max_epoch + 1):
            print('Epoch: {}'.format(epoch))
            trainer.run_epoch(epoch, True)   # training pass
            trainer.run_epoch(epoch, False)  # validation pass

    if args.do_predict:

        # Fail fast with a clear message instead of a TypeError in the
        # '%d' format below when --ckpt was omitted.
        if args.ckpt is None:
            parser.error('--do_predict requires --ckpt <epoch>')

        dataset = pd.read_csv("../../data/test.csv")
        dataset.drop("Id", axis=1, inplace=True)
        feature_for_testing = ["F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9"]
        test = preprocess_samples(dataset, feature_for_testing)

        testData = FeatureDataset(test)

        model = SimpleNet(input_size=9,
                          output_size=12,
                          hidden_size=args.hidden_size)
        # map_location lets a GPU-trained checkpoint load on CPU-only hosts.
        model.load_state_dict(
            torch.load('%s/model.pkl.%d' % (args.arch, args.ckpt),
                       map_location=device))
        model.train(False)
        model.to(device)
        dataloader = DataLoader(dataset=testData,
                                batch_size=args.batch_size,
                                shuffle=False,
                                collate_fn=testData.collate_fn,
                                num_workers=4)
        trange = tqdm(enumerate(dataloader),
                      total=len(dataloader),
                      desc='Predict')
        prediction = []
        # Inference only: disable autograd bookkeeping.
        with torch.no_grad():
            for i, (ft, _, y) in trange:
                b = ft.shape[0]
                # The unknown feature (F1) is zero-filled and prepended,
                # matching the [missing, observed] layout used in training.
                missing_ft = torch.zeros(b, 1)
                all_ft = torch.cat([missing_ft, ft], dim=1)
                o_labels, _ = model(all_ft.to(device))
                o_labels = torch.argmax(o_labels, axis=1)
                # extend() replaces the old append-then-sum(prediction, [])
                # flattening, which was quadratic in the number of batches.
                prediction.extend(o_labels.to('cpu').numpy().tolist())

        SubmitGenerator(prediction, "../../data/sampleSubmission.csv")

    if args.do_plot:
        plot_history("{file}/history.json".format(file=args.arch))
class Trainer():
    """GAN-based missing-feature trainer.

    A Generator synthesizes the single missing feature from the 8 observed
    ones, a Discriminator scores missing-feature values as real vs
    generated, and a SimpleNet classifier is trained on the observed
    features concatenated with the generated one.  Metrics go to
    <model_dir>/history.json; the best checkpoint (by classifier validation
    accuracy) goes to <model_dir>/model.pkl.<epoch>.
    """

    def __init__(self,
                 trainData,
                 validData,
                 hidden_size,
                 device,
                 model_dir="model"):
        # Per-epoch metric records, serialized to history.json by save_hist().
        self.history = {'train': [], 'valid': []}
        self.trainData = trainData
        self.validData = validData
        # Maps 8 observed features to 1 synthesized missing feature.
        self.generator = Generator(input_size=8,
                                   output_size=1,
                                   hidden_size=hidden_size).to(device)
        # Scores a single missing-feature value: real vs generated.
        self.discriminator = Discriminator(input_size=1,
                                           output_size=1,
                                           hidden_size=hidden_size).to(device)
        # Final classifier over observed + generated features (9 inputs).
        self.classfication = SimpleNet(input_size=9,
                                       output_size=12,
                                       hidden_size=hidden_size).to(device)
        # With-logits loss: the discriminator is assumed to emit raw
        # logits, not sigmoid probabilities -- TODO confirm in Discriminator.
        self.adversarial_loss = nn.BCEWithLogitsLoss()
        self.criterion = nn.CrossEntropyLoss()
        self.opt_G = torch.optim.Adam(self.generator.parameters(), lr=1e-4)
        self.opt_D = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4)
        self.opt_C = torch.optim.Adam(self.classfication.parameters(), lr=1e-4)
        self.device = device
        self.model_dir = model_dir
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        # Best validation classifier accuracy seen so far.
        self.best_val = 0.0

    def run_epoch(self, epoch, training):
        """Run one pass over the train/valid split, updating the generator,
        discriminator and classifier in turn on every batch (parameter
        updates happen only when training=True)."""
        self.generator.train(training)
        self.discriminator.train(training)
        self.classfication.train(training)

        if training:
            description = 'Train'
            dataset = self.trainData
            shuffle = True
        else:
            description = 'Valid'
            dataset = self.validData
            shuffle = False
        dataloader = DataLoader(dataset=dataset,
                                batch_size=256,
                                shuffle=shuffle,
                                collate_fn=dataset.collate_fn,
                                num_workers=4)

        trange = tqdm(enumerate(dataloader),
                      total=len(dataloader),
                      desc=description)

        # Running sums over batches; divided by batch count when reported.
        g_loss = 0
        d_loss = 0
        loss = 0
        acc = accuracy()  # project-local metric accumulator

        for i, (ft, missing_ft, labels) in trange:
            ft = ft.to(self.device)
            missing_ft = missing_ft.to(self.device)
            labels = labels.to(self.device)
            batch_size = ft.shape[0]
            # Adversarial target labels.  NOTE(review): torch.autograd.Variable
            # is a legacy no-op wrapper on modern PyTorch; plain tensors would do.
            true = Variable(torch.FloatTensor(batch_size, 1).fill_(1.0),
                            requires_grad=False).to(self.device)  # (batch, 1)
            fake = Variable(torch.FloatTensor(batch_size, 1).fill_(0.0),
                            requires_grad=False).to(self.device)  # (batch, 1)

            # -----------------
            #  Train Generator
            # -----------------

            # G is rewarded when D scores its output as "true".
            gen_missing = self.generator(ft.detach())
            validity = self.discriminator(gen_missing)
            batch_g_loss = self.adversarial_loss(validity, true)

            if training:
                self.opt_G.zero_grad()
                batch_g_loss.backward()
                self.opt_G.step()
            g_loss += batch_g_loss.item()

            # ---------------------
            #  Train Discriminator
            # ---------------------
            # D loss: average of "real scored true" and "generated scored fake".
            real_pred = self.discriminator(missing_ft)
            d_real_loss = self.adversarial_loss(real_pred, true)

            # Fresh generator forward: the earlier graph was consumed by
            # the G backward above.
            fake_missing = self.generator(ft.detach())
            fake_pred = self.discriminator(fake_missing)
            d_fake_loss = self.adversarial_loss(fake_pred, fake)
            batch_d_loss = (d_real_loss + d_fake_loss) / 2

            if training:
                self.opt_D.zero_grad()
                batch_d_loss.backward()
                self.opt_D.step()
            d_loss += batch_d_loss.item()

            # ------------------
            #  Train Classifier
            # ------------------

            # Third G forward: the classifier trains on observed features
            # concatenated with the (freshly) generated missing one.  Its
            # backward also reaches G's parameters, but those gradients are
            # cleared by opt_G.zero_grad() on the next iteration.
            gen_missing = self.generator(ft.detach())
            all_features = torch.cat((ft, gen_missing), dim=1)
            o_labels = self.classfication(all_features)
            batch_loss = self.criterion(o_labels, labels)
            if training:
                self.opt_C.zero_grad()
                batch_loss.backward()
                self.opt_C.step()
            loss += batch_loss.item()

            acc.update(o_labels, labels)

            # Progress bar shows running averages over batches seen so far.
            trange.set_postfix(acc=acc.print_score(),
                               g_loss=g_loss / (i + 1),
                               d_loss=d_loss / (i + 1),
                               loss=loss / (i + 1))

        # len(trange) == number of batches, so these are per-batch averages.
        if training:
            self.history['train'].append({
                'acc': acc.get_score(),
                'g_loss': g_loss / len(trange),
                'd_loss': d_loss / len(trange),
                'loss': loss / len(trange)
            })
            self.save_hist()

        else:
            self.history['valid'].append({
                'acc': acc.get_score(),
                'g_loss': g_loss / len(trange),
                'd_loss': d_loss / len(trange),
                'loss': loss / len(trange)
            })
            self.save_hist()
            # Model selection uses classifier accuracy only.
            if self.best_val < acc.get_score():
                self.best_val = acc.get_score()
                self.save_best(epoch)

    def save_best(self, epoch):
        """Checkpoint all three networks to <model_dir>/model.pkl.<epoch>."""
        torch.save(
            {
                'cls': self.classfication.state_dict(),
                'generator': self.generator.state_dict(),
                'discriminator': self.discriminator.state_dict()
            }, self.model_dir + '/model.pkl.' + str(epoch))

    def save_hist(self):
        """Overwrite <model_dir>/history.json with the metric history."""
        with open(self.model_dir + '/history.json', 'w') as f:
            json.dump(self.history, f, indent=4)
class Trainer():
    """Self-distillation trainer with a progressively "fading" feature.

    Two copies of SimpleNet are loaded from the same checkpoint: a frozen
    teacher (fixed_model) that always sees the true missing feature, and a
    trainable student (fadding_model) whose missing-feature input is scaled
    toward zero as epochs grow.  The student is trained (via MSE on hidden
    activations and outputs) to match the teacher despite the fading input.
    """

    def __init__(self, device, trainData, validData, args):
        self.device = device
        # Per-epoch metric records, serialized to history.json by save_hist().
        self.history = {'train': [], 'valid': []}
        self.trainData = trainData
        self.validData = validData

        # Student: starts from a pre-trained checkpoint and keeps training.
        # NOTE(review): checkpoint path is hard-coded -- consider an arg.
        self.fadding_model = SimpleNet(input_size=9,
                                       output_size=12,
                                       hidden_size=args.hidden_size).to(device)
        self.fadding_model.load_state_dict(
            torch.load("model0.33/model.pkl.904"))
        # Teacher: same weights, never updated (kept in eval mode below).
        self.fixed_model = SimpleNet(input_size=9,
                                     output_size=12,
                                     hidden_size=args.hidden_size).to(device)
        self.fixed_model.load_state_dict(torch.load("model0.33/model.pkl.904"))

        # Distillation loss between student and teacher activations/outputs.
        self.criteria = torch.nn.MSELoss()
        self.opt = torch.optim.AdamW(self.fadding_model.parameters(),
                                     lr=8e-5,
                                     weight_decay=9e-3)
        # self.scheduler = scheduler = torch.optim.lr_scheduler.StepLR(self.opt, step_size=200, gamma=args.step_lr)
        self.batch_size = args.batch_size
        self.model_dir = args.arch
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        # Best validation accuracy (of the zero-input evaluation) so far.
        self.best_val = 0.0

    def run_epoch(self, epoch, training):
        """Run one pass over the train/valid split, distilling the student
        toward the frozen teacher (updates only when training=True)."""
        self.fadding_model.train(training)
        # Teacher stays in eval mode regardless of the phase.
        self.fixed_model.train(False)

        if training:
            description = 'Train'
            dataset = self.trainData
            shuffle = True
        else:
            description = 'Valid'
            dataset = self.validData
            shuffle = False
        dataloader = DataLoader(dataset=dataset,
                                batch_size=self.batch_size,
                                shuffle=shuffle,
                                collate_fn=dataset.collate_fn,
                                num_workers=4)

        trange = tqdm(enumerate(dataloader),
                      total=len(dataloader),
                      desc=description)

        loss = 0
        acc_fadding = accuracy()  # student on the faded input
        acc_fixed = accuracy()    # student on the fully zeroed input
        # NOTE(review): despite its name, acc_fixed measures the *student*
        # (fadding_model) on zero-filled input, not the fixed teacher.

        for i, (ft, missing_ft, labels) in trange:
            ft = ft.to(self.device)
            missing_ft = missing_ft.to(self.device)
            labels = labels.to(self.device)
            # Fading schedule: scale the missing feature by
            # 0.9 ** sqrt(100 * epoch), which decays toward 0 as epochs grow.
            missing_fadding_ft = missing_ft * (0.9**((epoch * 100)**(1 / 2)))
            missing_0_ft = missing_ft * 0

            # Inputs are laid out as [missing feature, observed features].
            fadding_ft = torch.cat([missing_fadding_ft, ft], dim=1)
            zero_ft = torch.cat([missing_0_ft, ft], dim=1)
            raw_ft = torch.cat([missing_ft, ft], dim=1)

            # SimpleNet returns (logits, hidden activations).
            fadding_out, fadding_hiddens = self.fadding_model(fadding_ft)
            zero_out, _ = self.fadding_model(zero_ft)
            raw_out, raw_hiddens = self.fixed_model(raw_ft)

            # Distillation loss: per-layer hidden MSE plus output MSE
            # between teacher (raw input) and student (faded input).
            batch_loss = 0
            for raw_hidden, fadding_hidden in zip(raw_hiddens,
                                                  fadding_hiddens):
                batch_loss += self.criteria(raw_hidden, fadding_hidden)

            batch_loss += self.criteria(raw_out, fadding_out)

            if training:
                self.opt.zero_grad()
                batch_loss.backward()
                self.opt.step()

            loss += batch_loss.item()
            acc_fadding.update(fadding_out, labels)
            acc_fixed.update(zero_out, labels)

            trange.set_postfix(loss=loss / (i + 1),
                               acc_fadding=acc_fadding.print_score(),
                               acc_fixed=acc_fixed.print_score())

        # self.scheduler.step()

        # len(trange) == number of batches, so loss is a per-batch average.
        if training:
            self.history['train'].append({
                'acc-fadding': acc_fadding.get_score(),
                'acc_fixed': acc_fixed.get_score(),
                'loss': loss / len(trange)
            })
            self.save_hist()
        else:
            self.history['valid'].append({
                'acc-fadding': acc_fadding.get_score(),
                'acc_fixed': acc_fixed.get_score(),
                'loss': loss / len(trange)
            })
            self.save_hist()
            # Model selection keys on the zero-input accuracy.
            if acc_fixed.get_score() > self.best_val:
                self.best_val = acc_fixed.get_score()
                self.save_best(epoch)

    def run_iter(self, x, y):
        # NOTE(review): dead/broken leftover -- self.model is never defined
        # in this class (would raise AttributeError if called), and
        # self.criteria here is MSELoss, not a classification loss.
        # Appears copied from another Trainer variant; run_epoch never calls it.
        features = x.to(self.device)
        labels = y.to(self.device)
        o_labels, hiddens = self.model(features)
        l_loss = self.criteria(o_labels, labels)
        return o_labels, l_loss

    def save_best(self, epoch):
        """Checkpoint the student to <model_dir>/model.pkl.<epoch>."""
        torch.save(self.fadding_model.state_dict(),
                   self.model_dir + '/model.pkl.' + str(epoch))

    def save_hist(self):
        """Overwrite <model_dir>/history.json with the metric history."""
        with open(self.model_dir + '/history.json', 'w') as f:
            json.dump(self.history, f, indent=4)
class Trainer():
    """Plain supervised trainer for a single SimpleNet classifier.

    Trains with AdamW + StepLR on (features, labels) batches, records
    per-epoch accuracy/loss into <arch>/history.json, and checkpoints the
    model to <arch>/model.pkl.<epoch> whenever validation accuracy improves.
    """

    def __init__(self, device, trainData, validData, args):
        self.device = device
        # Per-epoch metric records, written out by save_hist().
        self.history = {'train': [], 'valid': []}
        self.trainData = trainData
        self.validData = validData
        self.model = SimpleNet(input_size=9,
                               output_size=12,
                               hidden_size=args.hidden_size).to(device)
        self.criteria = torch.nn.CrossEntropyLoss()
        self.opt = torch.optim.AdamW(self.model.parameters(),
                                     lr=args.lr,
                                     weight_decay=3.3e-1)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.opt,
                                                         step_size=200,
                                                         gamma=args.step_lr)
        self.batch_size = args.batch_size
        self.model_dir = args.arch
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        # Best validation accuracy observed so far.
        self.best_val = 0.0

    def run_epoch(self, epoch, training):
        """Run one pass over the train (training=True) or valid split,
        updating the model only when training."""
        self.model.train(training)

        if training:
            split_name, dataset, shuffle = 'Train', self.trainData, True
        else:
            split_name, dataset, shuffle = 'Valid', self.validData, False
        loader = DataLoader(dataset=dataset,
                            batch_size=self.batch_size,
                            shuffle=shuffle,
                            collate_fn=dataset.collate_fn,
                            num_workers=4)

        progress = tqdm(enumerate(loader), total=len(loader), desc=split_name)

        running_loss = 0
        acc = accuracy()

        # Batches are (features, _, labels); the middle element is unused here.
        for step, (x, _, y) in progress:
            o_labels, batch_loss = self.run_iter(x, y)
            if training:
                self.opt.zero_grad()
                batch_loss.backward()
                self.opt.step()

            running_loss += batch_loss.item()
            acc.update(o_labels.cpu(), y)

            progress.set_postfix(
                loss=running_loss / (step + 1), acc=acc.print_score())

        # len(progress) == batch count, so loss is a per-batch average.
        record = {'acc': acc.get_score(), 'loss': running_loss / len(progress)}
        self.history['train' if training else 'valid'].append(record)
        self.save_hist()
        if not training and acc.get_score() > self.best_val:
            self.best_val = acc.get_score()
            self.save_best(epoch)

    def run_iter(self, x, y):
        """Forward one batch; return (logits, cross-entropy loss)."""
        features = x.to(self.device)
        labels = y.to(self.device)
        o_labels = self.model(features)
        return o_labels, self.criteria(o_labels, labels)

    def save_best(self, epoch):
        """Checkpoint the model weights to <model_dir>/model.pkl.<epoch>."""
        torch.save(self.model.state_dict(),
                   self.model_dir + '/model.pkl.' + str(epoch))

    def save_hist(self):
        """Overwrite <model_dir>/history.json with the metric history."""
        with open(self.model_dir + '/history.json', 'w') as f:
            json.dump(self.history, f, indent=4)