def start(self, train_loader, train_set, valid_set=None, valid_loader=None):
            self.train_num_batches = math.ceil(train_set.num_images / float(self.cf.train_batch_size))
            self.val_num_batches = 0 if valid_set is None else math.ceil(
                valid_set.num_images / float(self.cf.valid_batch_size))

            # Define early stopping control
            if self.cf.early_stopping:
                early_stopping = EarlyStopping(self.cf)
            else:
                early_stopping = None

            # Train process
            for epoch in tqdm(range(self.curr_epoch, self.cf.epochs + 1), desc='Training', file=sys.stdout):
                # Shuffle train data
                train_set.update_indexes()

                # Initialize logger
                self.logger_stats.write('\n\t ------ Epoch: ' + str(epoch) + ' ------ \n')

                # Initialize stats
                self.stats.epoch = epoch
                self.train_loss = AverageMeter()
                self.confm_list = np.zeros((self.cf.num_classes, self.cf.num_classes))

                # Train epoch
                self.training_loop(epoch, train_loader)

                # Save stats
                self.stats.train.conf_m = self.confm_list
                self.compute_stats(self.confm_list, self.train_loss)
                self.save_stats_epoch(epoch)
                self.logger_stats.write_stat(self.stats.train, epoch,
                                             os.path.join(self.cf.train_json_path,
                                                          'train_epoch_' + str(epoch) + '.json'))

                # Validate epoch
                self.validate_epoch(valid_set, valid_loader, early_stopping, epoch)

                # Update scheduler
                if self.model.scheduler is not None:
                    self.model.scheduler.step(self.stats.val.loss)

                # Saving model if score improvement
                new_best = self.model.save(self.stats)
                if new_best:
                    self.logger_stats.write_best_stats(self.stats, epoch, self.cf.best_json_file)

                if self.stop:
                    return

            # Save model without training
            if self.cf.epochs == 0:
                self.model.save_model()
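Both `start` methods in this collection (this first example and Example 5) track the running training loss with an `AverageMeter`, and Example 4 uses the same class for gradient statistics. The helper itself is never shown; below is a minimal sketch, assuming only the `update()`/`.avg` interface the snippets actually use, which may differ from each project's real implementation.

# Minimal AverageMeter sketch matching the update()/.avg usage above.
class AverageMeter:
    def __init__(self):
        self.sum = 0.0    # running total of values
        self.count = 0    # number of updates
        self.avg = 0.0    # running mean

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count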

Example 2

    def train(self, setting):
        """Training Function.

        Args:
            setting: Name used to save the model

        Returns:
            model: Trained model
        """

        # Load different datasets
        train_loader = self._get_data(flag='train')
        vali_loader = self._get_data(flag='val')
        test_loader = self._get_data(flag='test')

        path = os.path.join(self.args.checkpoints, setting)
        if not os.path.exists(path):
            os.makedirs(path)

        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience,
                                       verbose=True)

        # Setting optimizer and loss functions
        model_optim = self._select_optimizer()
        criterion = nn.MSELoss()

        all_training_loss = []
        all_validation_loss = []

        # Training Loop
        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []

            self.model.train()
            epoch_time = time.time()
            for i, (batch_x, batch_y) in enumerate(train_loader):
                iter_count += 1
                model_optim.zero_grad()
                if self.model_type == 'SDT':
                    (pred, penalty), true = self._process_one_batch(
                        batch_x, batch_y)
                    loss = criterion(pred, true) + penalty
                else:
                    pred, true = self._process_one_batch(batch_x, batch_y)
                    loss = criterion(pred, true)
                train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    print(
                        '\titers: {0}/{1}, epoch: {2} | loss: {3:.7f}'.format(
                            i + 1, train_steps, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * (
                        (self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(
                        speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                loss.backward()
                model_optim.step()

            print('Epoch: {} cost time: {}'.format(epoch + 1,
                                                   time.time() - epoch_time))
            train_loss = np.average(train_loss)
            all_training_loss.append(train_loss)
            vali_loss = self.vali(vali_loader, criterion)
            all_validation_loss.append(vali_loss)
            test_loss = self.vali(test_loader, criterion)

            print(
                'Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}'
                .format(epoch + 1, train_steps, train_loss, vali_loss,
                        test_loss))
            early_stopping(vali_loss, self.model, path)

            # Plotting train and validation loss
            if ((epoch + 1) % 5 == 0 and self.args.plot):
                check_folder = os.path.isdir(self.args.plot_dir)

                # If folder doesn't exist, then create it.
                if not check_folder:
                    os.makedirs(self.args.plot_dir)

                plt.figure()
                plt.plot(all_training_loss, label='train loss')
                plt.plot(all_validation_loss, label='Val loss')
                plt.legend()
                plt.savefig(os.path.join(self.args.plot_dir, setting + '.png'))
                plt.show()
                plt.close()

            # If ran out of patience stop training
            if early_stopping.early_stop:
                if self.args.plot:
                    plt.figure()
                    plt.plot(all_training_loss, label='train loss')
                    plt.plot(all_validation_loss, label='Val loss')
                    plt.legend()
                    plt.savefig(os.path.join(self.args.plot_dir, setting + '.png'))
                    plt.show()
                print('Early stopping')
                break
        best_model_path = path + '/' + 'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path))

        return self.model
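Every example constructs `EarlyStopping(patience=..., verbose=True)`, calls it once per epoch with the validation loss, reads its `early_stop` flag, and finally reloads `checkpoint.pth`. None of the snippets define the class, so here is a minimal sketch consistent with that usage (Informer-style repositories ship a very similar helper); details such as the improvement threshold `delta` are assumptions.

import os
import torch

# Sketch of the EarlyStopping helper assumed by these examples: saves a
# checkpoint whenever validation loss improves, and sets early_stop after
# `patience` consecutive epochs without improvement.
class EarlyStopping:
    def __init__(self, patience=7, verbose=False, delta=0.0):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss, model, path):
        if val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), os.path.join(path, 'checkpoint.pth'))
            if self.verbose:
                print(f'Validation loss decreased to {val_loss:.6f}. Saving model...')
        else:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True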

Example 3

    def train(self, setting):
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        test_data, test_loader = self._get_data(flag='test')

        path = os.path.join(self.args.checkpoints, setting)
        if not os.path.exists(path):
            os.makedirs(path)

        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience,
                                       verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        if self.args.use_amp:
            scaler = torch.cuda.amp.GradScaler()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []

            self.model.train()
            epoch_time = time.time()
            for i, (batch_x, batch_y, batch_x_mark,
                    batch_y_mark) in enumerate(train_loader):
                iter_count += 1

                model_optim.zero_grad()

                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float()

                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # decoder input
                dec_inp = torch.zeros_like(
                    batch_y[:, -self.args.pred_len:, :]).float()
                dec_inp = torch.cat(
                    [batch_y[:, :self.args.label_len, :], dec_inp],
                    dim=1).float().to(self.device)

                # encoder - decoder
                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        if self.args.output_attention:
                            outputs = self.model(batch_x, batch_x_mark,
                                                 dec_inp, batch_y_mark)[0]
                        else:
                            outputs = self.model(batch_x, batch_x_mark,
                                                 dec_inp, batch_y_mark)

                        f_dim = -1 if self.args.features == 'MS' else 0
                        batch_y = batch_y[:, -self.args.pred_len:,
                                          f_dim:].to(self.device)
                        loss = criterion(outputs, batch_y)
                        train_loss.append(loss.item())
                else:
                    if self.args.output_attention:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp,
                                             batch_y_mark)[0]
                    else:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp,
                                             batch_y_mark)

                    if self.args.inverse:
                        outputs = train_data.inverse_transform(outputs)
                    f_dim = -1 if self.args.features == 'MS' else 0
                    batch_y = batch_y[:, -self.args.pred_len:,
                                      f_dim:].to(self.device)
                    loss = criterion(outputs, batch_y)
                    train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(
                        i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * (
                        (self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(
                        speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                if self.args.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(model_optim)
                    scaler.update()
                else:
                    loss.backward()
                    model_optim.step()

            print("Epoch: {} cost time: {}".format(epoch + 1,
                                                   time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(vali_data, vali_loader, criterion)
            test_loss = self.vali(test_data, test_loader, criterion)

            print(
                "Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}"
                .format(epoch + 1, train_steps, train_loss, vali_loss,
                        test_loss))
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

            adjust_learning_rate(model_optim, epoch + 1, self.args)

        best_model_path = path + '/' + 'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path))

        return self.model
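This example (and most that follow) ends each epoch with `adjust_learning_rate(model_optim, epoch + 1, self.args)`. In the Informer-style codebases these loops come from, that helper applies a fixed decay schedule; the sketch below shows the common 'type1' rule (halve the learning rate each epoch). The exact schedule and the `args.lradj`/`args.learning_rate` fields are assumptions here.

# Sketch of adjust_learning_rate: exponential decay keyed off args.lradj.
def adjust_learning_rate(optimizer, epoch, args):
    if args.lradj == 'type1':
        lr = args.learning_rate * (0.5 ** (epoch - 1))  # halve every epoch
    else:
        return
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    print('Updating learning rate to {}'.format(lr))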

Example 4

    def train(self, ii, logger):
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        next_data, next_loader = self._get_data(flag='train')
        test_data, test_loader = self._get_data(flag='test')
        if self.args.rank == 1:
            train_data, train_loader = self._get_data(flag='train')

        path = os.path.join(self.args.path, str(ii))
        try:
            os.mkdir(path)
        except FileExistsError:
            pass
        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience,
                                       verbose=True,
                                       rank=self.args.rank)

        W_optim, A_optim = self._select_optimizer()
        criterion = self._select_criterion()

        if self.args.use_amp:
            scaler = torch.cuda.amp.GradScaler()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []
            rate_counter = AverageMeter()
            Ag_counter, A_counter, Wg_counter, W_counter = AverageMeter(
            ), AverageMeter(), AverageMeter(), AverageMeter()

            self.model.train()
            epoch_time = time.time()
            for i, (trn_data, val_data, next_data) in enumerate(
                    zip(train_loader, vali_loader, next_loader)):
                # use j so the batch counter i from enumerate is preserved
                for j in range(len(trn_data)):
                    trn_data[j] = trn_data[j].float().to(self.device)
                    val_data[j] = val_data[j].float().to(self.device)
                    next_data[j] = next_data[j].float().to(self.device)
                iter_count += 1
                A_optim.zero_grad()
                rate = self.arch.unrolled_backward(
                    self.args, trn_data, val_data, next_data,
                    W_optim.param_groups[0]['lr'], W_optim)
                rate_counter.update(rate)
                # for r in range(1, self.args.world_size):
                #     for n, h in self.model.named_H():
                #         if "proj.{}".format(r) in n:
                #             if self.args.rank <= r:
                #                 with torch.no_grad():
                #                     dist.all_reduce(h.grad)
                #                     h.grad *= self.args.world_size/r+1
                #             else:
                #                 z = torch.zeros(h.shape).to(self.device)
                #                 dist.all_reduce(z)
                for a in self.model.A():
                    with torch.no_grad():
                        dist.all_reduce(a.grad)
                a_g_norm = 0
                a_norm = 0
                n = 0
                for a in self.model.A():
                    a_g_norm += a.grad.mean()
                    a_norm += a.mean()
                    n += 1
                Ag_counter.update(a_g_norm / n)
                A_counter.update(a_norm / n)

                A_optim.step()

                W_optim.zero_grad()
                pred, true = self._process_one_batch(train_data, trn_data)
                loss = criterion(pred, true)
                train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    logger.info(
                        "\tR{0} iters: {1}, epoch: {2} | loss: {3:.7f}".format(
                            self.args.rank, i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * (
                        (self.args.train_epochs - epoch) * train_steps - i)
                    logger.info(
                        '\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(
                            speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                if self.args.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(W_optim)
                    scaler.update()
                else:
                    loss.backward()

                    w_g_norm = 0
                    w_norm = 0
                    n = 0
                    for w in self.model.W():
                        w_g_norm += w.grad.mean()
                        w_norm += w.mean()
                        n += 1
                    Wg_counter.update(w_g_norm / n)
                    W_counter.update(w_norm / n)

                    W_optim.step()

            logger.info("R{} Epoch: {} W:{} Wg:{} A:{} Ag:{} rate{}".format(
                self.args.rank, epoch + 1, W_counter.avg, Wg_counter.avg,
                A_counter.avg, Ag_counter.avg, rate_counter.avg))

            logger.info("R{} Epoch: {} cost time: {}".format(
                self.args.rank, epoch + 1,
                time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(vali_data, vali_loader, criterion)
            test_loss = self.vali(test_data, test_loader, criterion)

            logger.info(
                "R{0} Epoch: {1}, Steps: {2} | Train Loss: {3:.7f} Vali Loss: {4:.7f} Test Loss: {5:.7f}"
                .format(self.args.rank, epoch + 1, train_steps, train_loss,
                        vali_loss, test_loss))
            early_stopping(vali_loss, self.model, path)

            flag = torch.tensor(
                [1]) if early_stopping.early_stop else torch.tensor([0])
            flag = flag.to(self.device)
            flags = [
                torch.tensor([1]).to(self.device),
                torch.tensor([1]).to(self.device)
            ]
            dist.all_gather(flags, flag)
            if flags[0].item() == 1 and flags[1].item() == 1:
                logger.info("Early stopping")
                break

            adjust_learning_rate(W_optim, epoch + 1, self.args)

        best_model_path = path + '/' + '{}_checkpoint.pth'.format(
            self.args.rank)
        self.model.load_state_dict(torch.load(best_model_path))

        return self.model
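The early-stopping consensus in this example is hard-wired for exactly two ranks (the `flags` list has two entries). A world-size-agnostic sketch of the same idea, using an all-reduce with a MIN reduction so training stops only when every rank has triggered early stopping, could look like this (it assumes the process group is already initialized):

import torch
import torch.distributed as dist

def all_ranks_stopped(early_stop, device):
    # Each rank contributes 1 if it wants to stop; MIN is 1 only when
    # every rank agrees, so no rank exits the loop early.
    flag = torch.tensor([1 if early_stop else 0], device=device)
    dist.all_reduce(flag, op=dist.ReduceOp.MIN)
    return flag.item() == 1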

Example 5

        def start(self,
                  train_loader,
                  train_set,
                  valid_set=None,
                  valid_loader=None):
            self.train_num_batches = math.ceil(train_set.num_images /
                                               float(self.cf.train_batch_size))
            self.val_num_batches = 0 if valid_set is None else math.ceil(
                valid_set.num_images / float(self.cf.valid_batch_size))
            # Define early stopping control
            if self.cf.early_stopping:
                early_stopping = EarlyStopping(self.cf)
            else:
                early_stopping = None

            prev_msg = '\nTotal estimated training time...\n'
            self.global_bar = ProgressBar(
                (self.cf.epochs + 1 - self.curr_epoch) *
                (self.train_num_batches + self.val_num_batches),
                lenBar=20)
            self.global_bar.set_prev_msg(prev_msg)

            # Train process
            for epoch in range(self.curr_epoch, self.cf.epochs + 1):
                # Shuffle train data
                train_set.update_indexes()

                # Initialize logger
                epoch_time = time.time()
                self.logger_stats.write('\t ------ Epoch: ' + str(epoch) +
                                        ' ------ \n')

                # Initialize epoch progress bar
                self.msg.accum_str = '\n\nEpoch %d/%d estimated time...\n' % \
                                     (epoch, self.cf.epochs)
                epoch_bar = ProgressBar(self.train_num_batches, lenBar=20)
                epoch_bar.update(show=False)

                # Initialize stats
                self.stats.epoch = epoch
                self.train_loss = AverageMeter()
                self.confm_list = np.zeros(
                    (self.cf.num_classes, self.cf.num_classes))

                # Train epoch
                self.training_loop(epoch, train_loader, epoch_bar)

                # Save stats
                self.stats.train.conf_m = self.confm_list
                self.compute_stats(np.asarray(self.confm_list),
                                   self.train_loss)
                self.save_stats_epoch(epoch)
                self.logger_stats.write_stat(
                    self.stats.train, epoch,
                    os.path.join(self.cf.train_json_path,
                                 'train_epoch_' + str(epoch) + '.json'))

                # Validate epoch
                self.validate_epoch(valid_set, valid_loader, early_stopping,
                                    epoch, self.global_bar)

                # Update scheduler
                if self.model.scheduler is not None:
                    self.model.scheduler.step(self.stats.val.loss)

                # Saving model if score improvement
                new_best = self.model.save(self.stats)
                if new_best:
                    self.logger_stats.write_best_stats(self.stats, epoch,
                                                       self.cf.best_json_file)

                # Update display values
                self.update_messages(epoch, epoch_time, new_best)

                if self.stop:
                    return

            # Save model without training
            if self.cf.epochs == 0:
                self.model.save_model()
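Both `start` variants call `self.model.scheduler.step(self.stats.val.loss)`. Passing a metric to `step()` matches PyTorch's `ReduceLROnPlateau`, the built-in scheduler that lowers the learning rate when a monitored value stops improving. A brief wiring sketch follows; `model`, `num_epochs`, `train_one_epoch`, and `validate` are hypothetical placeholders.

import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3)

for epoch in range(num_epochs):          # num_epochs: placeholder
    train_one_epoch(model, optimizer)    # hypothetical training step
    val_loss = validate(model)           # hypothetical validation step
    scheduler.step(val_loss)             # cut LR when val loss plateaus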

Example 6

    def train(self, setting):
        # fetch the datasets and loaders via PyTorch's data utilities
        # data_set, data_loader
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')  # validation split
        test_data, test_loader = self._get_data(flag='test')

        path = os.path.join(self.args.checkpoints, setting)
        if not os.path.exists(path):
            os.makedirs(path)

        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(
            patience=self.args.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()  # how the loss is computed

        if self.args.use_amp:
            scaler = torch.cuda.amp.GradScaler()

        train_loss_avg_list = []  # per-epoch averages, kept across the whole run

        for epoch in range(self.args.train_epochs):  # train_epochs defaults to 6
            iter_count = 0
            train_loss = []

            self.model.train()  # 1. switch the model to training mode
            epoch_time = time.time()
            # iterating the DataLoader with for-in makes the batches easy to handle
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(tqdm(train_loader)):
                # print("Shape of batch_x")
                # print(batch_x.shape)
                iter_count += 1

                model_optim.zero_grad()  # reset the gradients
                # do not call model.eval() during training
                # the heart of the loop: what x and y actually are
                pred, true = self._process_one_batch(
                    train_data, batch_x, batch_y, batch_x_mark, batch_y_mark)  # current prediction and ground truth

                if self.args.interpret:
                    # prediction when the high-attention positions are masked
                    mask_attention_pred, true = self._process_one_batch(
                        train_data, batch_x, batch_y, batch_x_mark, batch_y_mark)  # current prediction and ground truth
                    # mask_attention_output
                    loss = criterion(pred, true, mask_attention_pred)
                else:
                    loss = criterion(pred, true)  # compute the loss

                train_loss.append(loss.item())

                if (i+1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(
                        i + 1, epoch + 1, loss.item()))
                    speed = (time.time()-time_now)/iter_count
                    left_time = speed * \
                        ((self.args.train_epochs - epoch)*train_steps - i)
                    print(
                        '\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                if self.args.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(model_optim)
                    scaler.update()
                else:
                    loss.backward()  # backpropagation
                    model_optim.step()  # parameter update

            # want to save the loss history here

            print("Epoch: {} cost time: {}".format(
                epoch+1, time.time()-epoch_time))
            train_loss_avg = np.average(train_loss)
            train_loss_avg_list.append(train_loss_avg)
            vali_loss = self.vali(vali_data, vali_loader, criterion)
            test_loss = self.vali(test_data, test_loader, criterion)

            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                epoch + 1, train_steps, train_loss_avg, vali_loss, test_loss))
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

            adjust_learning_rate(model_optim, epoch+1, self.args)

            # line notify
            if self.args.notify:
                send_line_notify(message="Epoch: {} cost time: {}".format(
                    epoch+1, time.time()-epoch_time)+"Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                    epoch + 1, train_steps, train_loss_avg, vali_loss, test_loss))

        best_model_path = path+'/'+'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path))  # when was this saved? (EarlyStopping wrote the checkpoint)

        folder_path = './results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        # save the loss history
        np.save(os.path.join(folder_path, 'train_loss_avg_list.npy'),
                train_loss_avg_list)

        return self.model
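The `send_line_notify` helper this example calls is not defined in the snippet. LINE Notify is a simple token-authenticated HTTP API, so a plausible sketch is a single POST; the token handling here is an assumption, not the project's actual helper.

import requests

def send_line_notify(message, token='YOUR_LINE_NOTIFY_TOKEN'):  # token: placeholder
    # POST the message to the LINE Notify endpoint with a bearer token.
    requests.post(
        'https://notify-api.line.me/api/notify',
        headers={'Authorization': 'Bearer ' + token},
        data={'message': message},
    )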

Example 7

    def train(self, setting):
        train_data, train_loader = self._get_data(flag = 'train')
        vali_data, vali_loader = self._get_data(flag = 'val')
        test_data, test_loader = self._get_data(flag = 'test')

        path = os.path.join(self.args.checkpoints, setting)
        if not os.path.exists(path):
            os.makedirs(path)

        time_now = time.time()
        
        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)
        
        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        if self.args.use_amp:
            scaler = torch.cuda.amp.GradScaler()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []
            
            self.model.train()
            epoch_time = time.time()
            for i, (batch_x,batch_y,batch_x_mark,batch_y_mark) in enumerate(train_loader):
                iter_count += 1
                
                model_optim.zero_grad()
                pred, true = self._process_one_batch(
                    train_data, batch_x, batch_y, batch_x_mark, batch_y_mark)
                loss = criterion(pred, true)
                train_loss.append(loss.item())
                
                if (i+1) % 100==0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    speed = (time.time()-time_now)/iter_count
                    left_time = speed*((self.args.train_epochs - epoch)*train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()
                
                if self.args.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(model_optim)
                    scaler.update()
                else:
                    loss.backward()
                    model_optim.step()

            print("Epoch: {} cost time: {}".format(epoch+1, time.time()-epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(vali_data, vali_loader, criterion)
            test_loss = self.vali(test_data, test_loader, criterion)

            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                epoch + 1, train_steps, train_loss, vali_loss, test_loss))
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

            adjust_learning_rate(model_optim, epoch+1, self.args)
            
        best_model_path = path+'/'+'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path))
        
        return self.model
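This example hides the forward pass inside `_process_one_batch`, while Examples 3, 8, 9, and 10 inline the equivalent logic. For reference, a sketch of what the method typically does in these Informer-style repositories: build a decoder input from `label_len` known steps plus `pred_len` zeros, run the encoder-decoder, and return the prediction aligned with the target window. This is one common variant, not necessarily the exact method behind every example.

import torch

def _process_one_batch(self, dataset_object, batch_x, batch_y, batch_x_mark, batch_y_mark):
    batch_x = batch_x.float().to(self.device)
    batch_y = batch_y.float()
    batch_x_mark = batch_x_mark.float().to(self.device)
    batch_y_mark = batch_y_mark.float().to(self.device)

    # decoder input: known label window followed by zeros for the horizon
    dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
    dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp],
                        dim=1).float().to(self.device)

    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
    f_dim = -1 if self.args.features == 'MS' else 0
    batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
    return outputs, batch_y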

Example 8

  def train(self, setting):
    train_data, train_loader = self._get_data(flag='train')
    valid_data, valid_loader = self._get_data(flag='val')

    print(f'number of batches in train data={len(train_loader)}')
    print(f'number of batches in valid data={len(valid_loader)}')

    path = './checkpoints/'+setting
    if not os.path.exists(path):
      os.makedirs(path)

    time_now = time.time()

    train_steps = len(train_loader)
    early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

    model_optim = self._select_optimizer()
    criterion =  self._select_criterion(self.args.data)

    #  print(self.model)
    best_utility = 0
    for epoch in range(self.args.train_epochs):
      iter_count = 0
      train_loss = []
      train_auc  = []

      self.model.train()

      # batch_x: (batch_size, seq_len, n_features)
      # batch_y: (batch_size, label_len + pred_len, n_features)
      # batch_x_mark: (batch_size, seq_len)
      # batch_y_mark: (batch_size, label_len + pred_len)


      for i, (s_begin, s_end, r_begin, r_end, batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
      #  for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):

        #  print(f'{i} s_begin: ', s_begin)
        #  print(f'{i} s_end: ', s_end)
        #  print(f'{i} r_begin: ', r_begin)
        #  print(f'{i} r_end: ',r_end)

        iter_count += 1

        #  print(f'x : {batch_x}')
        #  print(f'y : {batch_y}')
        #  print(f'x_mark : {batch_x_mark}')
        #  print(f'y_mark : {batch_y_mark}')

        model_optim.zero_grad()

        batch_x = batch_x.float().to(self.device)
        batch_y = batch_y.float()

        batch_x_mark = batch_x_mark.float().to(self.device)
        batch_y_mark = batch_y_mark.float().to(self.device)

        # decoder input
        dec_inp = torch.zeros_like(batch_y[:,-self.args.pred_len:,:]).float()
        dec_inp = torch.cat([batch_y[:,:self.args.label_len,:], dec_inp], dim=1).float().to(self.device)
        # encoder - decoder
        if self.args.output_attention:
          y_pred = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
        else:
          y_pred = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

        f_dim = -1 if self.args.features=='MS' else 0

        y_true = batch_y[:,-self.args.pred_len:,-self.args.c_out:].to(self.device)
        #  y_true = batch_y[:,-self.args.pred_len:,f_dim:].to(self.device)
        loss = criterion(y_pred, y_true)

        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        model_optim.step()

        #  print(y_pred)
        y_pred = np.where(y_pred.sigmoid().detach().cpu().numpy() >= 0.5, 1, 0).astype(int)
        y_true = np.where(y_true.sigmoid().detach().cpu().numpy() >= 0.5, 1, 0).astype(int)
        #  y_true = y_true.detach().cpu().numpy().astype(int)
        #  print('y_true: ', np.median(y_true, axis=1))
        #  print('y_pred: ', np.median(y_pred, axis=1))

        train_auc.append(roc_auc_score(np.median(y_true, axis=1), np.median(y_pred, axis=1)))

        loss = loss.item()
        train_loss.append(loss)


        if (i+1) % 100 == 0:
          print(f'\ttrain_iters={i+1} | epoch={epoch+1} | ' \
                f'batch_loss={loss:.4f} | running_loss={np.mean(train_loss):.4f} | running_auc={np.mean(train_auc):.4f}')
          speed = (time.time()-time_now)/iter_count
          left_time = speed*((self.args.train_epochs - epoch)*train_steps - i)
          print(f'\tspeed={speed:.4f}s/batch | left_time={left_time:.4f}s')
          #  print(torch.cuda.memory_summary(abbreviated=True))

          iter_count = 0
          time_now = time.time()

        del batch_x, batch_x_mark, batch_y, batch_y_mark, dec_inp, y_true, y_pred


      gc.collect()

      torch.cuda.empty_cache()
      torch.cuda.reset_max_memory_allocated(self.device)

      returns = self.evaluate(valid_data, valid_loader, criterion)
      valid_loss, valid_preds, valid_trues, v_start, v_end = returns

      #  print(valid_data.data_x[b_start:b_end, -self.args.c_out:].shape)
      #  print('before where: ', valid_preds)
      #  print(valid_data.data_x[b_start:b_end, -self.args.c_out:])
      #  valid_trues = valid_data.data_x[b_start:b_end, -self.args.c_out:]
      #  print('valid_preds.shape: ', valid_preds.shape)
      valid_preds = np.median(valid_preds, axis=1)
      print(pd.DataFrame(valid_preds).describe())
      valid_preds = np.where(valid_preds >= 0.5, 1, 0).astype(int)

      #  print('after where: ', valid_preds)
      #  valid_trues = 1/(1+np.exp(-valid_trues))
      valid_trues = np.median(valid_trues, axis=1)
      valid_trues = np.where(valid_trues >= 0.5, 1, 0).astype(int)
      #  print('valid_trues shape: ', valid_trues.shape)
      #  print('valid_trues: ', valid_trues)
      valid_auc = roc_auc_score(valid_trues, valid_preds)
      valid_u_score = utility_score_bincount(date=valid_data.data_stamp[v_start:v_end],
                                              weight=valid_data.weight[v_start:v_end],
                                              resp=valid_data.resp[v_start:v_end],
                                              action=valid_preds)
      max_u_score = utility_score_bincount(date=valid_data.data_stamp[v_start:v_end],
                                            weight=valid_data.weight[v_start:v_end],
                                            resp=valid_data.resp[v_start:v_end],
                                            action=valid_trues)

      best_utility = max(best_utility, valid_u_score)

      print(f'epoch={epoch+1} | ' \
            f'average_train_loss={np.mean(train_loss):.4f} | average_valid_loss={valid_loss:.4f} | '
            f'valid_utility={valid_u_score:.4f}/{max_u_score:.4f} | valid_auc={valid_auc:.4f}')


      early_stopping(valid_auc, self.model, path)
      if early_stopping.early_stop:
        print("Early stopping")
        print(f"Best utility score is {best_utility:.4f}")
        break

      adjust_learning_rate(model_optim, epoch+1, self.args)

    best_model_path = path+'/'+'checkpoint.pth'
    self.model.load_state_dict(torch.load(best_model_path))

    return self.model
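The `utility_score_bincount` metric above comes from the Jane Street Market Prediction task: per-day P&L values are aggregated by date, turned into an annualized Sharpe-like statistic t, and clipped into [0, 6]. A sketch following the published competition formula (the project's helper may differ in details; `date` is assumed to be integer-encoded days):

import numpy as np

def utility_score_bincount(date, weight, resp, action):
    count_i = len(np.unique(date))
    Pi = np.bincount(date, weight * resp * action)             # per-day P&L
    t = np.sum(Pi) / np.sqrt(np.sum(Pi ** 2)) * np.sqrt(250 / count_i)
    return np.clip(t, 0, 6) * np.sum(Pi)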

Example 9

    def train(self, setting):
        print(self.model)
        print(sum(p.numel() for p in self.model.parameters()))
        train_data, train_loader = self._get_data(
            flag='train',
            data_dir=
            "/mnt/ufs18/home-052/surunze/biostat_project/archive_1/transcheckkernels1200/dataset/"
        )
        print("train data loaded")
        vali_data, vali_loader = self._get_data(
            flag='val',
            data_dir=
            "/mnt/ufs18/home-052/surunze/biostat_project/archive_1/transcheckkernels1200/dataset/"
        )
        print("valid data loaded")
        test_data, test_loader = self._get_data(
            flag='test',
            data_dir=
            "/mnt/ufs18/home-052/surunze/biostat_project/archive_1/transcheckkernels1200/dataset/"
        )
        print("test data loaded")

        s = train_data
        print("train data train data", len(s[0]))
        print("train data train data", s[1][0].shape)
        print("train data train data", (s[0][0].shape), (s[1][0].shape),
              (s[2][0].shape), (s[3][0].shape))
        print("train data train data", (s[0][1].shape), (s[1][1].shape),
              (s[2][1].shape), (s[3][1].shape))
        print("train data train data", (s[0][2].shape), (s[1][2].shape),
              (s[2][2].shape), (s[3][2].shape))
        print("train data train data", (s[0][3].shape), (s[1][3].shape),
              (s[2][3].shape), (s[3][3].shape))

        path = os.path.join(self.args.checkpoints, setting)
        if not os.path.exists(path):
            os.makedirs(path)

        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience,
                                       verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        if self.args.use_amp:
            scaler = torch.cuda.amp.GradScaler()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []

            self.model.train()
            epoch_time = time.time()
            for i, (batch_x, batch_y, batch_x_mark,
                    batch_y_mark) in enumerate(train_loader):
                iter_count += 1

                model_optim.zero_grad()

                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)

                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # decoder input
                dec_inp = torch.zeros_like(
                    batch_y[:, -self.args.pred_len:, :]).float()
                dec_inp = torch.cat(
                    [batch_y[:, :self.args.label_len, :], dec_inp],
                    dim=1).float().to(self.device)

                # encoder - decoder
                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        if self.args.output_attention:
                            outputs = self.model(batch_x, batch_x_mark,
                                                 dec_inp, batch_y_mark)[0]
                        else:
                            outputs = self.model(batch_x, batch_x_mark,
                                                 dec_inp, batch_y_mark)

                        f_dim = -1 if self.args.features == 'MS' else 0
                        batch_y = batch_y[:, -self.args.pred_len:,
                                          f_dim:].to(self.device)
                        #print(outputs.shape, batch_y.shape)
                        loss = criterion(outputs, batch_y)
                        train_loss.append(loss.item())
                else:
                    if self.args.output_attention:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp,
                                             batch_y_mark)[0]
                    else:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp,
                                             batch_y_mark)

                    if self.args.inverse:
                        outputs = train_data.inverse_transform(outputs)
                    f_dim = -1 if self.args.features == 'MS' else 0
                    batch_y = batch_y[:, -self.args.pred_len:,
                                      f_dim:].to(self.device)
                    #print(outputs.shape, batch_y.shape)
                    loss = criterion(outputs[:, :, 0], batch_y[:, :, 0])
                    train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(
                        i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * (
                        (self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(
                        speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                if self.args.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(model_optim)
                    scaler.update()
                else:
                    loss.backward()
                    model_optim.step()

            print("Epoch: {} cost time: {}".format(epoch + 1,
                                                   time.time() - epoch_time))
            train_loss = np.average(train_loss)
            #vali_loss = train_loss
            #test_loss = train_loss
            vali_loss = self.vali(vali_data, vali_loader, criterion)
            test_loss = self.test(setting)

            print("Training Summary", epoch + 1, train_steps, train_loss,
                  vali_loss, test_loss)
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

            adjust_learning_rate(model_optim, epoch + 1, self.args)

        best_model_path = path + '/' + 'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path))

        return self.model
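Examples 3, 4, 6, 7, and this one all branch on `args.use_amp`. Condensed out of the surrounding boilerplate, the canonical `torch.cuda.amp` step they share looks like the sketch below; `model`, `optimizer`, `criterion`, and `loader` are placeholders.

import torch

scaler = torch.cuda.amp.GradScaler()
for batch_x, batch_y in loader:                    # loader: placeholder
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():                # run the forward pass in mixed precision
        loss = criterion(model(batch_x), batch_y)
    scaler.scale(loss).backward()                  # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)                         # unscales grads, skips step on inf/NaN
    scaler.update()                                # adapt the scale factor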

Example 10

    def train(self, setting):
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        test_data, test_loader = self._get_data(flag='test')

        total_para, trainable_para = self._get_number_parameters()
        print('Total number of parameters: {:d}'.format(total_para))
        print('Number of trainable parameters: {:d}'.format(trainable_para))
        path = './checkpoints/' + setting
        if not os.path.exists(path):
            os.makedirs(path)

        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience,
                                       verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []

            self.model.train()
            for i, (batch_x, batch_y, batch_x_mark,
                    batch_y_mark) in enumerate(train_loader):
                iter_count += 1

                model_optim.zero_grad()

                batch_x = batch_x.double().to(self.device)
                batch_y = batch_y.double()

                batch_x_mark = batch_x_mark.double().to(self.device)
                batch_y_mark = batch_y_mark.double().to(self.device)

                # decoder input
                dec_inp = torch.zeros_like(
                    batch_y[:, -self.args.pred_len:, :]).double()
                dec_inp = torch.cat(
                    [batch_y[:, :self.args.label_len, :], dec_inp],
                    dim=1).double().to(self.device)
                # encoder - decoder
                outputs = self.model(batch_x, batch_x_mark, dec_inp,
                                     batch_y_mark)

                batch_y = batch_y[:, -self.args.pred_len:, :].to(self.device)
                loss = criterion(outputs, batch_y)
                train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(
                        i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * (
                        (self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(
                        speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                loss.backward()
                model_optim.step()

            train_loss = np.average(train_loss)
            vali_loss = self.vali(vali_data, vali_loader, criterion)
            test_loss = self.vali(test_data, test_loader, criterion)

            print(
                "Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}"
                .format(epoch + 1, train_steps, train_loss, vali_loss,
                        test_loss))
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

            adjust_learning_rate(model_optim, epoch + 1, self.args)

        best_model_path = path + '/' + 'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path))

        return self.model
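A slicing idiom recurs throughout these examples: `f_dim = -1 if self.args.features == 'MS' else 0` keeps only the last channel when the task is multivariate-input, univariate-output ('MS') forecasting. A two-line illustration with a hypothetical tensor:

import torch

batch_y = torch.randn(32, 24, 7)   # (batch, pred_len, n_features), hypothetical
f_dim = -1                         # 'MS' mode: the target is the final feature
target = batch_y[:, :, f_dim:]     # shape (32, 24, 1); f_dim = 0 keeps all 7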

Example 11

  def train(self, setting):
    """Training Function.

    Args:
        setting: Name used to save the model

    Returns:
        model: Trained model
    """
    # Load different datasets
    train_loader, train_loader_shuffled = self._get_data(flag="train")
    vali_loader = self._get_data(flag="val")
    test_loader = self._get_data(flag="test")

    path = os.path.join(self.args.checkpoints, setting)
    if not os.path.exists(path):
      os.makedirs(path)

    train_steps = len(train_loader)
    early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

    # Setting optimizer and loss functions
    model_optim, gate_optim = self._select_optimizer()
    criterion, criterion_kl = self._select_criterion()

    loss_train = []
    acc_loss_train = []
    utilization_loss_train = []
    smoothness_loss_train = []
    diversity_loss_train = []
    loss_val = []
    acc_loss_val = []
    utilization_loss_val = []
    smoothness_loss_val = []
    diversity_loss_val = []

    mse_train_ = []
    mse_test_ = []
    mse_val_ = []
    upper_bound_train_ = []
    upper_bound_test_ = []
    upper_bound_val_ = []
    oracle_acc_test_ = []
    oracle_acc_train_ = []
    oracle_acc_val_ = []

    # Getting the initial mse, oracle accuracy and upper bound
    (mse_train, upper_bound_train,
     oracle_acc_train) = self.get_upper_bound_accuracy(
         train_loader, flag="train")
    (mse_test, upper_bound_test,
     oracle_acc_test) = self.get_upper_bound_accuracy(
         test_loader, flag="test")
    (mse_val, upper_bound_val, oracle_acc_val) = self.get_upper_bound_accuracy(
        vali_loader, flag="val")

    mse_train_.append(mse_train)
    mse_test_.append(mse_test)
    mse_val_.append(mse_val)

    upper_bound_train_.append(upper_bound_train)
    upper_bound_test_.append(upper_bound_test)
    upper_bound_val_.append(upper_bound_val)

    oracle_acc_train_.append(oracle_acc_train)
    oracle_acc_test_.append(oracle_acc_test)
    oracle_acc_val_.append(oracle_acc_val)

    # Training loop
    for epoch in range(self.args.train_epochs):

      self.model.train()
      loss_all = 0

      # Add noise to the weights of the experts; this promotes diversity
      if self.args.noise:
        with torch.no_grad():
          for param in self.model.experts.parameters():
            param.add_(torch.randn(param.size()).to(self.device) * 0.01)

      for i, (batch_x, index, batch_y) in enumerate(train_loader_shuffled):
        # get past error made by experts
        past_errors = self.get_past_errors(index, "train")

        model_optim.zero_grad()
        if self.args.expert_type == "SDT":
          pred, true, weights, (reg_out, penalty) = self._process_one_batch(
              batch_x, batch_y, past_errors=past_errors)
          accuracy_loss = self.accuracy_loss(pred, true, weights) + penalty
        else:
          pred, true, weights, reg_out = self._process_one_batch(
              batch_x, batch_y, past_errors=past_errors)
          accuracy_loss = self.accuracy_loss(pred, true, weights)

        batch_size = pred.shape[0]
        # Calculate gate loss
        gate_loss = criterion(reg_out, true)

        # Calculate utilization loss
        if self.args.utilization_hp != 0:
          batch_expert_utilization = torch.sum(
              weights.squeeze(-1), dim=0) / batch_size
          expert_utilization_loss = self.expert_utilization_loss(
              batch_expert_utilization)
        else:
          expert_utilization_loss = 0

        # Calculate smoothness loss
        if self.args.smoothness_hp != 0:
          previous_weight = self.get_gate_assignment_weights(index, "train")
          smoothness_loss = criterion_kl(
              torch.log(weights.squeeze(-1) + eps), previous_weight)
          self.set_gate_assignment_weights(
              weights.squeeze(-1).detach(), index, "train")
        else:
          smoothness_loss = 0

        # Calculate diversity loss
        if self.args.diversity_hp != 0:
          batch_x_noisy = add_gaussian_noise(batch_x)
          pred_noisy, _, _, _ = self._process_one_batch(
              batch_x_noisy, batch_y, past_errors=past_errors)
          diversity_loss = self.diversity_loss(pred, pred_noisy)
          # avg_diversity_loss += self.args.diversity_hp * diversity_loss.item()
        else:
          diversity_loss = 0

        # set expert error
        error = self.model_assignment_error(pred, true)
        self.set_past_errors(error.detach(), index, "train")

        loss = (
            self.args.accuracy_hp * accuracy_loss +
            self.args.gate_hp * gate_loss +
            self.args.utilization_hp * expert_utilization_loss +
            self.args.smoothness_hp * smoothness_loss +
            self.args.diversity_hp * diversity_loss)

        loss.backward()
        model_optim.step()
        loss_all += loss.item()

        if (i + 1) % 50 == 0:
          print("\tOne iters: {0}/{1}, epoch: {2} | loss: {3:.7f}".format(
              i + 1, len(train_loader), epoch + 1, loss_all / i))
      # update past error matrix
      self.error_scaler.fit(
          self.past_train_error.detach().cpu().numpy().flatten().reshape(-1, 1))
      self.past_train_error = torch.Tensor(
          self.error_scaler.transform(
              self.past_train_error.detach().cpu().numpy().flatten().reshape(
                  -1, 1))).reshape(-1, self.num_experts).to(self.device)

      # Getting different losses
      (train_loss, acc_train, utilization_train, smoothness_train,
       diversity_train) = self.vali(train_loader, "train")

      loss_train.append(train_loss)
      acc_loss_train.append(acc_train)
      utilization_loss_train.append(utilization_train)
      smoothness_loss_train.append(smoothness_train)
      diversity_loss_train.append(diversity_train)

      (val_loss, acc_val, utilization_val, smoothness_val,
       diversity_val) = self.vali(
           vali_loader, flag="val")
      loss_val.append(val_loss)
      acc_loss_val.append(acc_val)
      utilization_loss_val.append(utilization_val)
      smoothness_loss_val.append(smoothness_val)
      diversity_loss_val.append(diversity_val)

      # getting mse, oracle accuracy and upper bound
      (mse_train, upper_bound_train,
       oracle_acc_train) = self.get_upper_bound_accuracy(
           train_loader, flag="train")
      (mse_test, upper_bound_test,
       oracle_acc_test) = self.get_upper_bound_accuracy(
           test_loader, flag="test")
      (mse_val, upper_bound_val,
       oracle_acc_val) = self.get_upper_bound_accuracy(
           vali_loader, flag="val")

      mse_train_.append(mse_train)
      mse_test_.append(mse_test)
      mse_val_.append(mse_val)

      upper_bound_train_.append(upper_bound_train)
      upper_bound_test_.append(upper_bound_test)
      upper_bound_val_.append(upper_bound_val)

      oracle_acc_train_.append(oracle_acc_train)
      oracle_acc_test_.append(oracle_acc_test)
      oracle_acc_val_.append(oracle_acc_val)

      # early stopping depends on the validation accuracy loss
      early_stopping(acc_val, self.model, path)

      print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f}"
            .format(epoch + 1, train_steps, train_loss, val_loss))

      if ((epoch + 1) % 10 == 0 and self.args.plot):

        self.plot_all(setting, mse_test_, mse_train_, mse_val_,
                      upper_bound_test_, upper_bound_train_, upper_bound_val_,
                      oracle_acc_test_, oracle_acc_train_, oracle_acc_val_,
                      loss_train, loss_val, acc_loss_train, acc_loss_val,
                      utilization_loss_train, utilization_loss_val,
                      smoothness_loss_train, smoothness_loss_val,
                      diversity_loss_train, diversity_loss_val)

      # when training runs out of patience
      if early_stopping.early_stop:
        break

    # if freezing the experts and tuning the gate network
    if (self.args.freeze and oracle_acc_train != 1):

      # load the best model
      best_model_path = path + "/" + "checkpoint.pth"
      self.model.load_state_dict(torch.load(best_model_path))

      # set past errors to zero
      self.past_train_error[:, :] = 0
      self.past_test_error[:, :] = 0
      self.past_val_error[:, :] = 0

      # get validation accuracy on the best model
      (val_loss, acc_val, utilization_val, smoothness_val,
       diversity_val) = self.vali(
           vali_loader, flag="val")

      # resetting and adjusting early_stopping
      early_stopping.val_loss_min = acc_val
      early_stopping.counter = 0
      early_stopping.early_stop = False

      for e in range(self.args.train_epochs):
        self.model.train()
        loss_all = 0
        for i, (batch_x, index, batch_y) in enumerate(train_loader_shuffled):
          past_errors = self.get_past_errors(index, "train")
          gate_optim.zero_grad()

          if self.args.expert_type == "SDT":
            pred, true, weights, (reg_out, penalty) = self._process_one_batch(
                batch_x, batch_y, past_errors=past_errors)
            accuracy_loss = self.accuracy_loss(pred, true, weights)
          else:
            pred, true, weights, reg_out = self._process_one_batch(
                batch_x, batch_y, past_errors=past_errors)
            accuracy_loss = self.accuracy_loss(pred, true, weights)

          # set expert error
          error = self.model_assignment_error(pred, true)
          self.set_past_errors(error.detach(), index, "train")

          loss = accuracy_loss
          loss.backward()

          # clear the expert gradients since we want them frozen
          self.model.experts.zero_grad()

          gate_optim.step()
          loss_all += loss.item()

          if (i + 1) % 50 == 0:
            print(
                "\tFreeze iters: {0}/{1}, epoch: {2}  sub epoch {4} | loss: {3:.7f}"
                .format(i + 1, len(train_loader), epoch + 1, loss_all / (i + 1), e + 1))

        # update past error matrix
        self.error_scaler.fit(
            self.past_train_error.detach().cpu().numpy().flatten().reshape(
                -1, 1))
        self.past_train_error = torch.Tensor(
            self.error_scaler.transform(
                self.past_train_error.detach().cpu().numpy().flatten().reshape(
                    -1, 1))).reshape(-1, self.num_experts).to(self.device)

        # Getting different losses
        (train_loss, acc_train, utilization_train, smoothness_train,
         diversity_train) = self.vali(
             train_loader, flag="train")

        loss_train.append(train_loss)
        acc_loss_train.append(acc_train)
        utilization_loss_train.append(utilization_train)
        smoothness_loss_train.append(smoothness_train)
        diversity_loss_train.append(diversity_train)

        (val_loss, acc_val, utilization_val, smoothness_val,
         diversity_val) = self.vali(
             vali_loader, flag="val")
        loss_val.append(val_loss)
        acc_loss_val.append(acc_val)
        utilization_loss_val.append(utilization_val)
        smoothness_loss_val.append(smoothness_val)
        diversity_loss_val.append(diversity_val)

        # getting mse, oracle accuracy and upper bound
        (mse_train, upper_bound_train,
         oracle_acc_train) = self.get_upper_bound_accuracy(
             train_loader, flag="train")
        (mse_test, upper_bound_test,
         oracle_acc_test) = self.get_upper_bound_accuracy(
             test_loader, flag="test")
        (mse_val, upper_bound_val,
         oracle_acc_val) = self.get_upper_bound_accuracy(
             vali_loader, flag="val")

        mse_train_.append(mse_train)
        mse_test_.append(mse_test)
        mse_val_.append(mse_val)

        upper_bound_train_.append(upper_bound_train)
        upper_bound_test_.append(upper_bound_test)
        upper_bound_val_.append(upper_bound_val)

        oracle_acc_train_.append(oracle_acc_train)
        oracle_acc_test_.append(oracle_acc_test)
        oracle_acc_val_.append(oracle_acc_val)

        # early stopping depends on the validation accuracy loss
        early_stopping(acc_val, self.model, path)

        print(
            "Frozen Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f}"
            .format(e + 1, train_steps, train_loss, val_loss))

        if early_stopping.early_stop:
          print("Early stopping")
          break

    if self.args.plot:
      self.plot_all(setting, mse_test_, mse_train_, mse_val_, upper_bound_test_,
                    upper_bound_train_, upper_bound_val_, oracle_acc_test_,
                    oracle_acc_train_, oracle_acc_val_, loss_train, loss_val,
                    acc_loss_train, acc_loss_val, utilization_loss_train,
                    utilization_loss_val, smoothness_loss_train,
                    smoothness_loss_val, diversity_loss_train,
                    diversity_loss_val)
    # Load the best model on the validation dataset
    best_model_path = path + "/" + "checkpoint.pth"
    self.model.load_state_dict(torch.load(best_model_path))
    pickle.dump(self.error_scaler, open(path + "/" + "std_scaler.bin", "wb"))
    return self.model
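The smoothness term in this example calls `criterion_kl(torch.log(weights + eps), previous_weight)`, which matches `nn.KLDivLoss`: its first argument must be log-probabilities, hence the explicit `torch.log`. A sketch of a `_select_criterion` consistent with how the two returned losses are used above (the exact reductions are assumptions):

import torch.nn as nn

def _select_criterion(self):
    criterion = nn.MSELoss()                             # accuracy / gate loss
    criterion_kl = nn.KLDivLoss(reduction='batchmean')   # expects log-probs as input
    return criterion, criterion_kl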

Example 12

    def train(self, setting):
        print('prepare data...')
        train_data_loaders, vali_data_loaders, test_data_loaders = self._get_data(
        )
        print('Number of data loaders:', len(train_data_loaders))
        path = './checkpoints/' + setting
        if not os.path.exists(path):
            os.makedirs(path)

        time_now = time.time()

        early_stopping = EarlyStopping(patience=self.args.patience,
                                       verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []
            self.model.train()
            for index in range(len(train_data_loaders)):
                train_loader = train_data_loaders[index]
                loader_loss = []
                begin_ = time.time()
                for i, (batch_x, batch_y) in enumerate(train_loader):
                    iter_count += 1

                    model_optim.zero_grad()

                    batch_x = batch_x.double()  # .to(self.device)
                    batch_y = batch_y.double()

                    outputs = self.model(batch_x).view(-1, 24)
                    batch_y = batch_y[:, -self.args.pred_len:,
                                      -1].view(-1, 24)  # .to(self.device)

                    loss = criterion(outputs, batch_y)  # + 0.1*corr

                    loader_loss.append(loss.item())

                    loss.backward()
                    model_optim.step()
                print('INDEX Finished', index, 'train loss',
                      np.average(loader_loss), 'COST',
                      time.time() - begin_)
                train_loss.extend(loader_loss)
            train_loss = np.average(train_loss)
            vali_loss, mae, score = self.test('1')
            early_stopping(-score, self.model, path)
            print(
                "Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} score: {4:.7f}"
                .format(epoch + 1, iter_count, train_loss, vali_loss,
                        score))

            if early_stopping.early_stop:
                print("Early stopping")
                break

            adjust_learning_rate(model_optim, epoch + 1, self.args)

        best_model_path = path + '/' + 'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path))
        print('Model is saved at', best_model_path)
        self.model.eval()
        return self.model