Example #1
    def construct_model(self):
        """get data loader"""
        input_size, input_channels, n_classes, train_data = get_data(
            self.config.dataset,
            self.config.data_path,
            cutout_length=0,
            validation=False)

        n_train = len(train_data)
        split = n_train // 2
        indices = list(range(n_train))
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
            indices[:split])
        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(
            indices[split:])
        self.train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=self.config.batch_size,
            sampler=train_sampler,
            num_workers=self.config.workers,
            pin_memory=True)
        self.valid_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=self.config.batch_size,
            sampler=valid_sampler,
            num_workers=self.config.workers,
            pin_memory=True)
        """build model"""
        print("init model")
        self.criterion = nn.CrossEntropyLoss().to(self.device)
        model = SearchStageController(input_channels,
                                      self.config.init_channels,
                                      n_classes,
                                      self.config.layers,
                                      self.criterion,
                                      self.config.genotype,
                                      device_ids=self.config.gpus)
        self.model = model.to(self.device)
        print("init model end!")
        """build optimizer"""
        print("get optimizer")
        self.w_optim = torch.optim.SGD(self.model.weights(),
                                       self.config.w_lr,
                                       momentum=self.config.w_momentum,
                                       weight_decay=self.config.w_weight_decay)
        self.alpha_optim = torch.optim.Adam(
            self.model.alphas(),
            self.config.alpha_lr,
            betas=(0.5, 0.999),
            weight_decay=self.config.alpha_weight_decay)

        self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.w_optim, self.total_epochs, eta_min=self.config.w_lr_min)
        self.architect = Architect(self.model, self.config.w_momentum,
                                   self.config.w_weight_decay)
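
The two loaders above draw disjoint halves of the same training set, the standard DARTS arrangement that keeps the batches used for weight updates separate from those used for architecture (alpha) updates. Below is a minimal, self-contained sketch of that split on a toy TensorDataset; the dataset, batch size, and names are illustrative, not part of the example above.

# Minimal sketch of the 50/50 split performed in construct_model() above (toy data).
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

train_data = TensorDataset(torch.randn(100, 3, 32, 32), torch.randint(0, 10, (100,)))

n_train = len(train_data)
split = n_train // 2
indices = list(range(n_train))

# First half feeds the weight optimizer, second half feeds the alpha optimizer.
train_loader = DataLoader(train_data, batch_size=16,
                          sampler=SubsetRandomSampler(indices[:split]))
valid_loader = DataLoader(train_data, batch_size=16,
                          sampler=SubsetRandomSampler(indices[split:]))

xb, yb = next(iter(train_loader))
print(xb.shape, yb.shape)  # torch.Size([16, 3, 32, 32]) torch.Size([16])
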
Example #2
    def __init__(self, args):

        self.args = args
        self.console = Console()

        self.console.log('=> [1] Initial settings')
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        cudnn.benchmark = True
        cudnn.enabled = True

        self.console.log('=> [2] Initial models')
        self.metric = load_metric(args)
        self.loss_fn = get_loss_fn(args).cuda()
        self.model = Model_Search(args, get_trans_input(args),
                                  self.loss_fn).cuda()
        self.console.log(
            f'=> Supernet Parameters: {count_parameters_in_MB(self.model)}',
            style='bold red')

        self.console.log(f'=> [3] Preparing dataset')
        self.dataset = load_data(args)
        if args.pos_encode > 0:
            #! add positional encoding
            self.console.log(f'==> [3.1] Adding positional encodings')
            self.dataset._add_positional_encodings(args.pos_encode)
        self.search_data = self.dataset.train
        self.val_data = self.dataset.val
        self.test_data = self.dataset.test
        self.load_dataloader()

        self.console.log(f'=> [4] Initial optimizer')
        self.optimizer = torch.optim.SGD(params=self.model.parameters(),
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)

        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=self.optimizer,
            T_max=float(args.epochs),
            eta_min=args.lr_min)

        self.architect = Architect(self.model, self.args)
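
`count_parameters_in_MB` is the usual DARTS-style utility for reporting supernet size. A minimal sketch of what it typically computes (parameter count in millions, skipping any auxiliary-head weights), assuming the conventional implementation rather than the exact helper imported here:

# Sketch of a count_parameters_in_MB helper (assumed implementation, for illustration).
import numpy as np
import torch.nn as nn

def count_parameters_in_MB(model: nn.Module) -> float:
    """Total parameter count in millions, excluding auxiliary-head parameters."""
    return sum(np.prod(p.size()) for name, p in model.named_parameters()
               if 'auxiliary' not in name) / 1e6

print(count_parameters_in_MB(nn.Linear(128, 64)))  # ~0.008 (8256 parameters)
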
Example #3
    def _build_model(self):
        model_dict = {
            'informer': Informer,
            'informerstack': InformerStack,
        }
        if self.args.model == 'informer' or self.args.model == 'informerstack':
            e_layers = self.args.e_layers if self.args.model == 'informer' else self.args.s_layers
            model = model_dict[self.args.model](
                self.args.enc_in,
                self.args.dec_in,
                self.args.c_out,
                self.args.seq_len,
                self.args.label_len,
                self.args.pred_len,
                self.args.factor,
                self.args.d_model,
                self.args.n_heads,
                e_layers,  # self.args.e_layers,
                self.args.d_layers,
                self.args.d_ff,
                self.args.dropout,
                self.args.attn,
                self.args.embed,
                self.args.freq,
                self.args.activation,
                self.args.output_attention,
                self.args.distil,
                self.args.mix,
                self.device,
                self.args).float()
        else:
            raise NotImplementedError
        # something

        self.arch = Architect(model, self.device, self.args,
                              self._select_criterion())
        return model
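
The constructor call above passes more than twenty positional arguments, so the call only works if the order matches the Informer signature exactly. A small sketch of the same dictionary-dispatch pattern using keyword arguments instead, with toy stand-in classes (purely illustrative, not the real Informer API):

# Toy stand-ins for Informer / InformerStack; keyword args make the call order-independent.
class Informer:
    def __init__(self, enc_in, dec_in, c_out, **kwargs):
        self.enc_in, self.dec_in, self.c_out = enc_in, dec_in, c_out

class InformerStack(Informer):
    pass

model_dict = {'informer': Informer, 'informerstack': InformerStack}
args = {'model': 'informer', 'enc_in': 7, 'dec_in': 7, 'c_out': 7}

model = model_dict[args['model']](enc_in=args['enc_in'],
                                  dec_in=args['dec_in'],
                                  c_out=args['c_out'])
print(type(model).__name__)  # Informer
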
Example #4
class SearchCellTrainer():
    def __init__(self, config):
        self.config = config

        self.world_size = 1
        self.gpu = self.config.local_rank
        self.save_epoch = 1
        self.ckpt_path = self.config.path
        """get the train parameters"""
        self.total_epochs = self.config.epochs
        self.train_batch_size = self.config.batch_size
        self.val_batch_size = self.config.batch_size
        self.global_batch_size = self.world_size * self.train_batch_size
        self.max_lr = self.config.w_lr * self.world_size
        """construct the whole network"""
        self.resume_path = self.config.resume_path
        if torch.cuda.is_available():
            # self.device = torch.device(f'cuda:{self.gpu}')
            # torch.cuda.set_device(self.device)
            torch.cuda.set_device(self.config.gpus[0])
            # cudnn.benchmark = True
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        self.construct_model()

        self.steps = 0
        self.log_step = 10
        self.logger = self.config.logger
        self.writer = SummaryWriter(
            log_dir=os.path.join(self.config.path, "tb"))
        self.writer.add_text('config', config.as_markdown(), 0)

    def construct_model(self):
        """get data loader"""
        input_size, input_channels, n_classes, train_data = get_data(
            self.config.dataset,
            self.config.data_path,
            cutout_length=0,
            validation=False)

        n_train = len(train_data)
        split = n_train // 2
        indices = list(range(n_train))
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
            indices[:split])
        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(
            indices[split:])
        self.train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=self.config.batch_size,
            sampler=train_sampler,
            num_workers=self.config.workers,
            pin_memory=True)
        self.valid_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=self.config.batch_size,
            sampler=valid_sampler,
            num_workers=self.config.workers,
            pin_memory=True)
        """build model"""
        print("init model")
        self.criterion = nn.CrossEntropyLoss().to(self.device)
        model = SearchCellController(input_channels,
                                     self.config.init_channels,
                                     n_classes,
                                     self.config.layers,
                                     self.criterion,
                                     device_ids=self.config.gpus)
        self.model = model.to(self.device)
        print("init model end!")
        """build optimizer"""
        print("get optimizer")
        self.w_optim = torch.optim.SGD(self.model.weights(),
                                       self.config.w_lr,
                                       momentum=self.config.w_momentum,
                                       weight_decay=self.config.w_weight_decay)
        self.alpha_optim = torch.optim.Adam(
            self.model.alphas(),
            self.config.alpha_lr,
            betas=(0.5, 0.999),
            weight_decay=self.config.alpha_weight_decay)

        self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.w_optim, self.total_epochs, eta_min=self.config.w_lr_min)
        self.architect = Architect(self.model, self.config.w_momentum,
                                   self.config.w_weight_decay)

    def resume_model(self, model_path=None):
        if model_path is None and not self.resume_path:
            self.start_epoch = 0
            self.logger.info("--> No loaded checkpoint!")
        else:
            model_path = model_path or self.resume_path
            checkpoint = torch.load(model_path, map_location=self.device)

            self.start_epoch = checkpoint['epoch']
            self.steps = checkpoint['steps']
            self.model.load_state_dict(checkpoint['model'], strict=True)
            self.w_optim.load_state_dict(checkpoint['w_optim'])
            self.alpha_optim.load_state_dict(checkpoint['alpha_optim'])
            self.logger.info(
                f"--> Loaded checkpoint '{model_path}'(epoch {self.start_epoch})"
            )

    def save_checkpoint(self, epoch, is_best=False):
        if epoch % self.save_epoch == 0:
            state = {
                'config': self.config,
                'epoch': epoch,
                'steps': self.steps,
                'model': self.model.state_dict(),
                'w_optim': self.w_optim.state_dict(),
                'alpha_optim': self.alpha_optim.state_dict()
            }
            if is_best:
                best_filename = os.path.join(self.ckpt_path, 'best.pth.tar')
                torch.save(state, best_filename)

    def train_epoch(self, epoch, printer=print):
        top1 = AverageMeter()
        top5 = AverageMeter()
        losses = AverageMeter()

        cur_lr = self.lr_scheduler.get_last_lr()[0]

        self.model.print_alphas(self.logger)
        self.model.train()

        prefetcher_trn = data_prefetcher(self.train_loader)
        prefetcher_val = data_prefetcher(self.valid_loader)
        trn_X, trn_y = prefetcher_trn.next()
        val_X, val_y = prefetcher_val.next()
        i = 0
        while trn_X is not None:
            i += 1
            N = trn_X.size(0)
            self.steps += 1

            # architect step (alpha)
            self.alpha_optim.zero_grad()
            self.architect.unrolled_backward(trn_X, trn_y, val_X, val_y,
                                             cur_lr, self.w_optim)
            self.alpha_optim.step()

            # child network step (w)
            self.w_optim.zero_grad()
            logits = self.model(trn_X)
            loss = self.model.criterion(logits, trn_y)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.weights(),
                                     self.config.w_grad_clip)
            self.w_optim.step()

            prec1, prec5 = accuracy(logits, trn_y, topk=(1, 5))
            losses.update(loss.item(), N)
            top1.update(prec1.item(), N)
            top5.update(prec5.item(), N)

            if self.steps % self.log_step == 0:
                self.writer.add_scalar('train/lr', round(cur_lr, 5),
                                       self.steps)
                self.writer.add_scalar('train/loss', loss.item(), self.steps)
                self.writer.add_scalar('train/top1', prec1.item(), self.steps)
                self.writer.add_scalar('train/top5', prec5.item(), self.steps)

            if i % self.config.print_freq == 0 or i == len(
                    self.train_loader) - 1:
                printer(
                    f'Train: Epoch: [{epoch}][{i}/{len(self.train_loader) - 1}]\t'
                    f'Step {self.steps}\t'
                    f'lr {round(cur_lr, 5)}\t'
                    f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                    f'Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})\t')

            trn_X, trn_y = prefetcher_trn.next()
            val_X, val_y = prefetcher_val.next()

        printer("Train: [{:3d}/{}] Final Prec@1 {:.4%}".format(
            epoch, self.total_epochs - 1, top1.avg))

    def val_epoch(self, epoch, printer):
        top1 = AverageMeter()
        top5 = AverageMeter()
        losses = AverageMeter()

        self.model.eval()
        prefetcher = data_prefetcher(self.valid_loader)
        X, y = prefetcher.next()
        i = 0

        with torch.no_grad():
            while X is not None:
                N = X.size(0)
                i += 1

                logits = self.model(X)
                loss = self.criterion(logits, y)

                prec1, prec5 = accuracy(logits, y, topk=(1, 5))
                losses.update(loss.item(), N)
                top1.update(prec1.item(), N)
                top5.update(prec5.item(), N)

                if i % self.config.print_freq == 0 or i == len(
                        self.valid_loader) - 1:
                    printer(
                        f'Valid: Epoch: [{epoch}][{i}/{len(self.valid_loader)}]\t'
                        f'Step {self.steps}\t'
                        f'Loss {losses.avg:.4f}\t'
                        f'Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})')

                X, y = prefetcher.next()

        self.writer.add_scalar('val/loss', losses.avg, self.steps)
        self.writer.add_scalar('val/top1', top1.avg, self.steps)
        self.writer.add_scalar('val/top5', top5.avg, self.steps)

        printer("Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(
            epoch, self.total_epochs - 1, top1.avg))

        return top1.avg
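
train_epoch above alternates two updates per step: an architecture (alpha) step driven by the Architect on a validation batch, then a clipped SGD step on the network weights. Below is a self-contained, first-order sketch of that alternation on a toy mixed-op model; the second-order correction that `Architect.unrolled_backward` presumably applies is omitted, and all names here are illustrative.

# First-order sketch of the alternating alpha/weight update used in train_epoch() (toy model).
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySupernet(nn.Module):
    """Toy supernet: two candidate ops mixed by a softmax over alpha."""
    def __init__(self):
        super().__init__()
        self.op1 = nn.Linear(8, 4)
        self.op2 = nn.Linear(8, 4)
        self.alpha = nn.Parameter(torch.zeros(2))   # architecture parameters

    def forward(self, x):
        w = F.softmax(self.alpha, dim=0)
        return w[0] * self.op1(x) + w[1] * self.op2(x)

    def weights(self):
        return [p for n, p in self.named_parameters() if n != 'alpha']

    def alphas(self):
        return [self.alpha]

model = TinySupernet()
criterion = nn.CrossEntropyLoss()
w_optim = torch.optim.SGD(model.weights(), lr=0.025, momentum=0.9, weight_decay=3e-4)
alpha_optim = torch.optim.Adam(model.alphas(), lr=3e-4, betas=(0.5, 0.999), weight_decay=1e-3)

trn_X, trn_y = torch.randn(16, 8), torch.randint(0, 4, (16,))
val_X, val_y = torch.randn(16, 8), torch.randint(0, 4, (16,))

# 1) architecture step: minimise the validation loss with respect to alpha
alpha_optim.zero_grad()
criterion(model(val_X), val_y).backward()
alpha_optim.step()

# 2) weight step: minimise the training loss with respect to the op weights, with clipping
w_optim.zero_grad()
loss = criterion(model(trn_X), trn_y)
loss.backward()
nn.utils.clip_grad_norm_(model.weights(), 5.0)
w_optim.step()
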
Example #5
def start(args):
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True
    cudnn.enabled = True
    logging.info("args = %s", args)

    dataset = LoadData(args.data_name)
    if args.data_name == 'SBM_PATTERN':
        in_dim = 3
        num_classes = 2
    elif args.data_name == 'SBM_CLUSTER':
        in_dim = 7
        num_classes = 6
    else:
        raise ValueError(f"unsupported dataset: {args.data_name}")
    print(f"input dimension: {in_dim}, number of classes: {num_classes}")

    criterion = MyCriterion(num_classes)
    criterion = criterion.cuda()

    model = Network(args.layers, args.nodes, in_dim, args.feature_dim,
                    num_classes, criterion, args.data_type, args.readout)
    model = model.cuda()
    logging.info("param size = %fMB", count_parameters_in_MB(model))

    train_data, val_data, test_data = dataset.train, dataset.val, dataset.test

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))
    print(f"train set full size : {num_train}; split train set size : {split}")
    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=args.workers,
        collate_fn=dataset.collate)

    valid_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
        pin_memory=True,
        num_workers=args.workers,
        collate_fn=dataset.collate)

    true_valid_queue = torch.utils.data.DataLoader(
        val_data, batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.workers,
        collate_fn=dataset.collate)

    test_queue = torch.utils.data.DataLoader(
        test_data, batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.workers,
        collate_fn=dataset.collate)

    optimizer = torch.optim.SGD(model.parameters(), args.learning_rate,
                                momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)
    architect = Architect(model, args)

    # viz = Visdom(env = '{} {}'.format(args.data_name,  time.asctime(time.localtime(time.time()))  ))
    viz = None
    save_file = open(args.save_result, "w")
    for epoch in range(args.epochs):
        scheduler.step()
        lr = scheduler.get_lr()[0]
        logging.info('[LR]\t%f', lr)

        if epoch % args.save_freq == 0:
            print(model.show_genotypes())
            save_file.write(f"Epoch : {epoch}\n{model.show_genotypes()}\n")
            for i in range(args.layers):
                logging.info('layer = %d', i)
                genotype = model.show_genotype(i)
                logging.info('genotype = %s', genotype)
            '''
            w1, w2, w3 = model.show_weights(0)
            print('[1] weights in first cell\n',w1)
            print('[2] weights in middle cell\n', w2)
            print('[3] weights in last cell\n', w3)
            '''
        # training
        macro_acc, micro_acc, loss = train(train_queue, valid_queue, model,
                                           architect, criterion, optimizer,
                                           lr, epoch, viz)
        # true validation
        macro_acc, micro_acc, loss = infer(true_valid_queue, model, criterion,
                                           stage='validating')
        # testing
        macro_acc, micro_acc, loss = infer(test_queue, model, criterion,
                                           stage=' testing  ')
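
The epoch loop above steps the scheduler at the top of each epoch and reads `scheduler.get_lr()`, which matches the original DARTS script but triggers deprecation warnings in newer PyTorch. A small sketch of the currently recommended ordering (toy optimizer and hyperparameters, for illustration only):

# Recommended scheduler usage in recent PyTorch: optimizer first, then scheduler, once per epoch.
import torch

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.025, momentum=0.9)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=50, eta_min=0.001)

for epoch in range(50):
    lr = sched.get_last_lr()[0]     # current LR, without the get_lr() deprecation warning
    # ... run train()/infer() for this epoch with `opt` ...
    opt.step()                      # step the optimizer before the scheduler (PyTorch >= 1.1)
    sched.step()
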
Example #6
class Exp_M_Informer(Exp_Basic):
    def __init__(self, args):
        super(Exp_M_Informer, self).__init__(args)

    def _build_model(self):
        model_dict = {
            'informer': Informer,
            'informerstack': InformerStack,
        }
        if self.args.model == 'informer' or self.args.model == 'informerstack':
            e_layers = self.args.e_layers if self.args.model == 'informer' else self.args.s_layers
            model = model_dict[self.args.model](
                self.args.enc_in,
                self.args.dec_in,
                self.args.c_out,
                self.args.seq_len,
                self.args.label_len,
                self.args.pred_len,
                self.args.factor,
                self.args.d_model,
                self.args.n_heads,
                e_layers,  # self.args.e_layers,
                self.args.d_layers,
                self.args.d_ff,
                self.args.dropout,
                self.args.attn,
                self.args.embed,
                self.args.freq,
                self.args.activation,
                self.args.output_attention,
                self.args.distil,
                self.args.mix,
                self.device,
                self.args).float()
        else:
            raise NotImplementedError
        # something

        self.arch = Architect(model, self.device, self.args,
                              self._select_criterion())
        return model

    def _get_data(self, flag):
        args = self.args

        data_dict = {
            'ETTh1': Dataset_ETT_hour,
            'ETTh2': Dataset_ETT_hour,
            'ETTm1': Dataset_ETT_minute,
            'ETTm2': Dataset_ETT_minute,
            'WTH': Dataset_Custom,
            'ECL': Dataset_Custom,
            'Solar': Dataset_Custom,
            'custom': Dataset_Custom,
        }
        Data = data_dict[self.args.data]
        timeenc = 0 if args.embed != 'timeF' else 1

        if flag == 'test':
            shuffle_flag = False
            drop_last = True
            batch_size = args.batch_size
            freq = args.freq
        elif flag == 'pred':
            shuffle_flag = False
            drop_last = False
            batch_size = 1
            freq = args.detail_freq
            Data = Dataset_Pred
        else:
            shuffle_flag = True
            drop_last = True
            batch_size = args.batch_size
            freq = args.freq
        data_set = Data(root_path=args.root_path,
                        data_path=args.data_path,
                        flag=flag,
                        size=[args.seq_len, args.label_len, args.pred_len],
                        features=args.features,
                        target=args.target,
                        inverse=args.inverse,
                        timeenc=timeenc,
                        freq=freq,
                        cols=args.cols)
        data_loader = DataLoader(data_set,
                                 batch_size=batch_size,
                                 shuffle=shuffle_flag,
                                 num_workers=args.num_workers,
                                 drop_last=drop_last)

        return data_set, data_loader

    def _select_optimizer(self):
        W_optim = optim.Adam(self.model.W(), lr=self.args.learning_rate)
        A_optim = optim.Adam(self.model.A(),
                             self.args.A_lr,
                             betas=(0.5, 0.999),
                             weight_decay=self.args.A_weight_decay)
        return W_optim, A_optim

    def _select_criterion(self):
        criterion = nn.MSELoss()
        return criterion

    def vali(self, vali_data, vali_loader, criterion):
        self.model.eval()
        total_loss = []
        for i, val_d in enumerate(vali_loader):
            pred, true = self._process_one_batch(vali_data, val_d)
            loss = criterion(pred.detach().cpu(), true.detach().cpu())
            total_loss.append(loss)
        total_loss = np.average(total_loss)
        self.model.train()
        return total_loss

    def train(self, ii, logger):
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        next_data, next_loader = self._get_data(flag='train')
        test_data, test_loader = self._get_data(flag='test')
        if self.args.rank == 1:
            train_data, train_loader = self._get_data(flag='train')

        path = os.path.join(self.args.path, str(ii))
        try:
            os.mkdir(path)
        except FileExistsError:
            pass
        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience,
                                       verbose=True,
                                       rank=self.args.rank)

        W_optim, A_optim = self._select_optimizer()
        criterion = self._select_criterion()

        if self.args.use_amp:
            scaler = torch.cuda.amp.GradScaler()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []
            rate_counter = AverageMeter()
            Ag_counter, A_counter = AverageMeter(), AverageMeter()
            Wg_counter, W_counter = AverageMeter(), AverageMeter()

            self.model.train()
            epoch_time = time.time()
            for i, (trn_data, val_data, next_data) in enumerate(
                    zip(train_loader, vali_loader, next_loader)):
                # move each tensor of the three batches to the device; use a
                # separate index so the outer batch counter `i` is not clobbered
                for j in range(len(trn_data)):
                    trn_data[j] = trn_data[j].float().to(self.device)
                    val_data[j] = val_data[j].float().to(self.device)
                    next_data[j] = next_data[j].float().to(self.device)
                iter_count += 1
                A_optim.zero_grad()
                rate = self.arch.unrolled_backward(
                    self.args, trn_data, val_data, next_data,
                    W_optim.param_groups[0]['lr'], W_optim)
                rate_counter.update(rate)
                # for r in range(1, self.args.world_size):
                #     for n, h in self.model.named_H():
                #         if "proj.{}".format(r) in n:
                #             if self.args.rank <= r:
                #                 with torch.no_grad():
                #                     dist.all_reduce(h.grad)
                #                     h.grad *= self.args.world_size/r+1
                #             else:
                #                 z = torch.zeros(h.shape).to(self.device)
                #                 dist.all_reduce(z)
                for a in self.model.A():
                    with torch.no_grad():
                        dist.all_reduce(a.grad)
                a_g_norm = 0
                a_norm = 0
                n = 0
                for a in self.model.A():
                    a_g_norm += a.grad.mean()
                    a_norm += a.mean()
                    n += 1
                Ag_counter.update(a_g_norm / n)
                A_counter.update(a_norm / n)

                A_optim.step()

                W_optim.zero_grad()
                pred, true = self._process_one_batch(train_data, trn_data)
                loss = criterion(pred, true)
                train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    logger.info(
                        "\tR{0} iters: {1}, epoch: {2} | loss: {3:.7f}".format(
                            self.args.rank, i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * (
                        (self.args.train_epochs - epoch) * train_steps - i)
                    logger.info(
                        '\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(
                            speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                if self.args.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(W_optim)
                    scaler.update()
                else:
                    loss.backward()

                    w_g_norm = 0
                    w_norm = 0
                    n = 0
                    for w in self.model.W():
                        w_g_norm += w.grad.mean()
                        w_norm += w.mean()
                        n += 1
                    Wg_counter.update(w_g_norm / n)
                    W_counter.update(w_norm / n)

                    W_optim.step()

            logger.info("R{} Epoch: {} W:{} Wg:{} A:{} Ag:{} rate{}".format(
                self.args.rank, epoch + 1, W_counter.avg, Wg_counter.avg,
                A_counter.avg, Ag_counter.avg, rate_counter.avg))

            logger.info("R{} Epoch: {} cost time: {}".format(
                self.args.rank, epoch + 1,
                time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(vali_data, vali_loader, criterion)
            test_loss = self.vali(test_data, test_loader, criterion)

            logger.info(
                "R{0} Epoch: {1}, Steps: {2} | Train Loss: {3:.7f} Vali Loss: {4:.7f} Test Loss: {5:.7f}"
                .format(self.args.rank, epoch + 1, train_steps, train_loss,
                        vali_loss, test_loss))
            early_stopping(vali_loss, self.model, path)

            flag = torch.tensor(
                [1]) if early_stopping.early_stop else torch.tensor([0])
            flag = flag.to(self.device)
            flags = [
                torch.tensor([1]).to(self.device),
                torch.tensor([1]).to(self.device)
            ]
            dist.all_gather(flags, flag)
            if flags[0].item() == 1 and flags[1].item() == 1:
                logger.info("Early stopping")
                break

            adjust_learning_rate(W_optim, epoch + 1, self.args)

        best_model_path = path + '/' + '{}_checkpoint.pth'.format(
            self.args.rank)
        self.model.load_state_dict(torch.load(best_model_path))

        return self.model

    def test(self, setting, logger):
        test_data, test_loader = self._get_data(flag='test')

        self.model.eval()

        preds = []
        trues = []

        for i, test_d in enumerate(test_loader):
            pred, true = self._process_one_batch(test_data, test_d)
            preds.append(pred.detach().cpu().numpy())
            trues.append(true.detach().cpu().numpy())

        preds = np.array(preds)
        trues = np.array(trues)
        logger.info('test shape: {} {}'.format(preds.shape, trues.shape))
        preds = preds.reshape((-1, preds.shape[-2], preds.shape[-1]))
        trues = trues.reshape((-1, trues.shape[-2], trues.shape[-1]))
        logger.info('test shape: {} {}'.format(preds.shape, trues.shape))

        # result save
        folder_path = './results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        mae, mse, rmse, mape, mspe = metric(preds, trues)
        logger.info('R{} mse:{}, mae:{}'.format(self.args.rank, mse, mae))

        np.save(folder_path + 'metrics.npy',
                np.array([mae, mse, rmse, mape, mspe]))
        np.save(folder_path + 'pred.npy', preds)
        np.save(folder_path + 'true.npy', trues)

        return

    def predict(self, setting, load=False):
        pred_data, pred_loader = self._get_data(flag='pred')

        if load:
            path = os.path.join(self.args.checkpoints, setting)
            best_model_path = path + '/' + 'checkpoint.pth'
            self.model.load_state_dict(torch.load(best_model_path))

        self.model.eval()

        preds = []

        for i, pred_d in enumerate(pred_loader):
            pred, true = self._process_one_batch(pred_data, pred_d)
            preds.append(pred.detach().cpu().numpy())

        preds = np.array(preds)
        preds = preds.reshape((-1, preds.shape[-2], preds.shape[-1]))

        # result save
        folder_path = './results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        np.save(folder_path + 'real_prediction.npy', preds)

        return

    def _process_one_batch(self, dataset_object, data):
        batch_x = data[0].float().to(self.device)
        batch_y = data[1].float().to(self.device)

        batch_x_mark = data[2].float().to(self.device)
        batch_y_mark = data[3].float().to(self.device)

        # decoder input
        if self.args.padding == 0:
            dec_inp = torch.zeros(
                [batch_y.shape[0], self.args.pred_len,
                 batch_y.shape[-1]]).float().to(self.device)
        elif self.args.padding == 1:
            dec_inp = torch.ones(
                [batch_y.shape[0], self.args.pred_len,
                 batch_y.shape[-1]]).float().to(self.device)
        dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp],
                            dim=1).float().to(self.device)
        # encoder - decoder
        if self.args.use_amp:
            with torch.cuda.amp.autocast():
                if self.args.output_attention:
                    # the model also returns attention maps; keep only the forecast
                    outputs = self.model(batch_x, batch_x_mark, dec_inp,
                                         batch_y_mark)[0]
                else:
                    outputs = self.model(batch_x, batch_x_mark, dec_inp,
                                         batch_y_mark)
        else:
            if self.args.output_attention:
                outputs = self.model(batch_x, batch_x_mark, dec_inp,
                                     batch_y_mark)[0]
            else:
                outputs = self.model(batch_x, batch_x_mark, dec_inp,
                                     batch_y_mark)
        if self.args.inverse:
            outputs = dataset_object.inverse_transform(outputs)
        f_dim = -1 if self.args.features == 'MS' else 0
        batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)

        return outputs, batch_y
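
`_process_one_batch` builds the decoder input by padding the prediction horizon with zeros (or ones) and prepending the last `label_len` steps of the ground truth. A self-contained shape check of that construction, using toy sizes chosen only for illustration:

# Shape check of the decoder-input construction used above (toy sizes).
import torch

batch, label_len, pred_len, n_features = 32, 48, 24, 7
batch_y = torch.randn(batch, label_len + pred_len, n_features)

# zero-padded placeholder for the horizon to be predicted (args.padding == 0 case)
dec_inp = torch.zeros(batch, pred_len, n_features)
# prepend the known `label_len` steps so the decoder gets some real context
dec_inp = torch.cat([batch_y[:, :label_len, :], dec_inp], dim=1)

print(dec_inp.shape)  # torch.Size([32, 72, 7]) == (batch, label_len + pred_len, features)
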
Example #7
class Searcher(object):
    def __init__(self, args):

        self.args = args
        self.console = Console()

        self.console.log('=> [1] Initial settings')
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        cudnn.benchmark = True
        cudnn.enabled = True

        self.console.log('=> [2] Initial models')
        self.metric = load_metric(args)
        self.loss_fn = get_loss_fn(args).cuda()
        self.model = Model_Search(args, get_trans_input(args),
                                  self.loss_fn).cuda()
        self.console.log(
            f'=> Supernet Parameters: {count_parameters_in_MB(self.model)}',
            style='bold red')

        self.console.log(f'=> [3] Preparing dataset')
        self.dataset = load_data(args)
        if args.pos_encode > 0:
            #! add positional encoding
            self.console.log(f'==> [3.1] Adding positional encodings')
            self.dataset._add_positional_encodings(args.pos_encode)
        self.search_data = self.dataset.train
        self.val_data = self.dataset.val
        self.test_data = self.dataset.test
        self.load_dataloader()

        self.console.log(f'=> [4] Initial optimizer')
        self.optimizer = torch.optim.SGD(params=self.model.parameters(),
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)

        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=self.optimizer,
            T_max=float(args.epochs),
            eta_min=args.lr_min)

        self.architect = Architect(self.model, self.args)

    def load_dataloader(self):

        num_search = int(len(self.search_data) * self.args.data_clip)
        indices = list(range(num_search))
        split = int(np.floor(self.args.portion * num_search))
        self.console.log(
            f'=> Para set size: {split}, Arch set size: {num_search - split}')

        self.para_queue = torch.utils.data.DataLoader(
            dataset=self.search_data,
            batch_size=self.args.batch,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                indices[:split]),
            pin_memory=True,
            num_workers=self.args.nb_workers,
            collate_fn=self.dataset.collate)

        self.arch_queue = torch.utils.data.DataLoader(
            dataset=self.search_data,
            batch_size=self.args.batch,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                indices[split:]),
            pin_memory=True,
            num_workers=self.args.nb_workers,
            collate_fn=self.dataset.collate)

        num_valid = int(len(self.val_data) * self.args.data_clip)
        indices = list(range(num_valid))

        self.val_queue = torch.utils.data.DataLoader(
            dataset=self.val_data,
            batch_size=self.args.batch,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices),
            pin_memory=True,
            num_workers=self.args.nb_workers,
            collate_fn=self.dataset.collate)

        num_test = int(len(self.test_data) * self.args.data_clip)
        indices = list(range(num_test))

        self.test_queue = torch.utils.data.DataLoader(
            dataset=self.test_data,
            batch_size=self.args.batch,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices),
            pin_memory=True,
            num_workers=self.args.nb_workers,
            collate_fn=self.dataset.collate)

    def run(self):

        self.console.log(f'=> [5] Search & Train')
        for i_epoch in range(self.args.epochs):
            self.scheduler.step()
            self.lr = self.scheduler.get_lr()[0]
            if i_epoch % self.args.report_freq == 0:
                geno = genotypes(
                    args=self.args,
                    arch_paras=self.model.group_arch_parameters(),
                    arch_topos=self.model.cell_arch_topo,
                )
                with open(
                        f'{self.args.arch_save}/{self.args.data}/{i_epoch}.yaml',
                        "w") as f:
                    yaml.dump(geno, f)

                # => report genotype
                self.console.log(geno)
                for i in range(self.args.nb_layers):
                    for p in self.model.group_arch_parameters()[i]:
                        self.console.log(p.softmax(0).detach().cpu().numpy())

            search_result = self.search()
            self.console.log(
                f"[green]=> search result [{i_epoch}] - loss: {search_result['loss']:.4f} - metric : {search_result['metric']:.4f}",
            )
            # DecayScheduler().step(i_epoch)

            with torch.no_grad():
                val_result = self.infer(self.val_queue)
                self.console.log(
                    f"[yellow]=> valid result  [{i_epoch}] - loss: {val_result['loss']:.4f} - metric : {val_result['metric']:.4f}"
                )

                test_result = self.infer(self.test_queue)
                self.console.log(
                    f"[red]=> test  result  [{i_epoch}] - loss: {test_result['loss']:.4f} - metric : {test_result['metric']:.4f}"
                )

    def search(self):

        self.model.train()
        epoch_loss = 0
        epoch_metric = 0
        desc = '=> searching'
        device = torch.device('cuda')

        with tqdm(self.para_queue, desc=desc, leave=False) as t:
            for i_step, (batch_graphs, batch_targets) in enumerate(t):
                #! 1. preparing training datasets
                G = batch_graphs.to(device)
                V = batch_graphs.ndata['feat'].to(device)
                # E = batch_graphs.edata['feat'].to(device)
                batch_targets = batch_targets.to(device)
                #! 2. preparing validating datasets
                batch_graphs_search, batch_targets_search = next(
                    iter(self.arch_queue))
                GS = batch_graphs_search.to(device)
                VS = batch_graphs_search.ndata['feat'].to(device)
                # ES = batch_graphs_search.edata['feat'].to(device)
                batch_targets_search = batch_targets_search.to(device)
                #! 3. optimizing architecture topology parameters
                self.architect.step(input_train={
                    'G': G,
                    'V': V
                },
                                    target_train=batch_targets,
                                    input_valid={
                                        'G': GS,
                                        'V': VS
                                    },
                                    target_valid=batch_targets_search,
                                    eta=self.lr,
                                    network_optimizer=self.optimizer,
                                    unrolled=self.args.unrolled)
                #! 4. optimizing model parameters
                self.optimizer.zero_grad()
                batch_scores = self.model({'G': G, 'V': V})
                loss = self.loss_fn(batch_scores, batch_targets)
                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.detach().item()
                epoch_metric += self.metric(batch_scores, batch_targets)
                t.set_postfix(lr=self.lr,
                              loss=epoch_loss / (i_step + 1),
                              metric=epoch_metric / (i_step + 1))

        return {
            'loss': epoch_loss / (i_step + 1),
            'metric': epoch_metric / (i_step + 1)
        }

    def infer(self, dataloader):

        self.model.eval()
        epoch_loss = 0
        epoch_metric = 0
        desc = '=> inferring'
        device = torch.device('cuda')

        with tqdm(dataloader, desc=desc, leave=False) as t:
            for i_step, (batch_graphs, batch_targets) in enumerate(t):
                G = batch_graphs.to(device)
                V = batch_graphs.ndata['feat'].to(device)
                # E = batch_graphs.edata['feat'].to(device)
                batch_targets = batch_targets.to(device)
                batch_scores = self.model({'G': G, 'V': V})
                loss = self.loss_fn(batch_scores, batch_targets)

                epoch_loss += loss.detach().item()
                epoch_metric += self.metric(batch_scores, batch_targets)
                t.set_postfix(loss=epoch_loss / (i_step + 1),
                              metric=epoch_metric / (i_step + 1))

        return {
            'loss': epoch_loss / (i_step + 1),
            'metric': epoch_metric / (i_step + 1)
        }
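
`search()` above draws its architecture batch with `next(iter(self.arch_queue))`, which creates a fresh DataLoader iterator (and re-spawns workers) on every step. A small sketch of a cycling-iterator helper that avoids that cost, offered as an optional refinement rather than part of the original class; the data here is a toy placeholder:

# Optional helper: reuse one iterator over the architecture queue instead of rebuilding it each step.
import torch
from torch.utils.data import DataLoader, TensorDataset

def cycled(loader):
    """Yield batches forever, restarting the loader when it is exhausted."""
    while True:
        for batch in loader:
            yield batch

arch_queue = DataLoader(TensorDataset(torch.randn(10, 4), torch.randint(0, 2, (10,))),
                        batch_size=4)
arch_iter = cycled(arch_queue)

for step in range(5):            # more steps than one pass over the loader
    xb, yb = next(arch_iter)     # no new iterator or worker processes per step
    print(step, xb.shape)
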
class SearchDistributionTrainer(SearchStageTrainer):
    def __init__(self, config):
        super().__init__(config)

    def construct_model(self):
        """get data loader"""
        input_size, input_channels, n_classes, train_data = get_data(
            self.config.dataset,
            self.config.data_path,
            cutout_length=0,
            validation=False)

        n_train = len(train_data)
        split = n_train // 2
        indices = list(range(n_train))
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
            indices[:split])
        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(
            indices[split:])
        self.train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=self.config.batch_size,
            sampler=train_sampler,
            num_workers=self.config.workers,
            pin_memory=True)
        self.valid_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=self.config.batch_size,
            sampler=valid_sampler,
            num_workers=self.config.workers,
            pin_memory=True)
        """build model"""
        print("init model")
        self.criterion = nn.CrossEntropyLoss().to(self.device)
        model = SearchDistributionController(input_channels,
                                             self.config.init_channels,
                                             n_classes,
                                             self.config.layers,
                                             self.criterion,
                                             self.config.genotype,
                                             device_ids=self.config.gpus)
        self.model = model.to(self.device)
        print("init model end!")
        """build optimizer"""
        print("get optimizer")
        self.w_optim = torch.optim.SGD(self.model.weights(),
                                       self.config.w_lr,
                                       momentum=self.config.w_momentum,
                                       weight_decay=self.config.w_weight_decay)
        self.alpha_optim = torch.optim.Adam(
            self.model.alphas(),
            self.config.alpha_lr,
            betas=(0.5, 0.999),
            weight_decay=self.config.alpha_weight_decay)

        self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.w_optim, self.total_epochs, eta_min=self.config.w_lr_min)
        self.architect = Architect(self.model, self.config.w_momentum,
                                   self.config.w_weight_decay)

    def cal_depth(self, alpha, n_nodes, SW, beta):
        assert len(
            alpha
        ) == n_nodes, "the length of alpha must be the same as n_nodes"

        d = [0, 0]
        for i, edges in enumerate(alpha):
            edge_max, _ = torch.topk(edges[:, :-1], 1)
            edge_max = F.softmax(edge_max, dim=0)
            if i < SW - 2:
                dd = 0
                for j in range(i + 2):
                    dd += edge_max[j][0] * (d[j] + 1)
                dd /= (i + 2)
            else:
                dd = 0
                for s, j in enumerate(range(i - 1, i + 2)):
                    dd += edge_max[s][0] * (d[j] + 1)
                dd /= SW
            if i >= 3:
                dd *= (1 + i * beta[i - 3])[0]
            d.append(dd)
        return sum(d) / n_nodes

    def concat_param_loss(self, beta):
        loss = sum([beta[i][j] * (j + 4) for i in range(3) for j in range(5)])
        return loss

    def train_epoch(self, epoch, printer):
        top1 = AverageMeter()
        top5 = AverageMeter()
        losses = AverageMeter()

        cur_lr = self.lr_scheduler.get_last_lr()[0]

        self.model.print_alphas(self.logger)
        self.model.train()

        prefetcher_trn = data_prefetcher(self.train_loader)
        prefetcher_val = data_prefetcher(self.valid_loader)
        trn_X, trn_y = prefetcher_trn.next()
        val_X, val_y = prefetcher_val.next()
        i = 0
        while trn_X is not None:
            i += 1
            N = trn_X.size(0)
            self.steps += 1

            # architect step (alpha)
            self.alpha_optim.zero_grad()
            self.architect.unrolled_backward(trn_X, trn_y, val_X, val_y,
                                             cur_lr, self.w_optim)
            self.alpha_optim.step()

            self.alpha_optim.zero_grad()
            alpha = self.architect.net.alpha_DAG
            beta = [
                F.softmax(be, dim=0) for be in self.architect.net.alpha_concat
            ]
            self.n_nodes = self.config.layers // 3
            d_depth1 = self.cal_depth(alpha[0 * self.n_nodes:1 * self.n_nodes],
                                      self.n_nodes, 3, beta[0])
            d_depth2 = self.cal_depth(alpha[1 * self.n_nodes:2 * self.n_nodes],
                                      self.n_nodes, 3, beta[1])
            d_depth3 = self.cal_depth(alpha[2 * self.n_nodes:3 * self.n_nodes],
                                      self.n_nodes, 3, beta[2])
            depth_loss = -1 * (d_depth1 + d_depth2 + d_depth3)
            param_loss = self.concat_param_loss(beta)
            new_loss = depth_loss + 0.4 * param_loss
            new_loss.backward()
            self.alpha_optim.step()

            # child network step (w)
            self.w_optim.zero_grad()
            logits = self.model(trn_X)
            loss = self.model.criterion(logits, trn_y)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.weights(),
                                     self.config.w_grad_clip)
            self.w_optim.step()

            prec1, prec5 = accuracy(logits, trn_y, topk=(1, 5))
            losses.update(loss.item(), N)
            top1.update(prec1.item(), N)
            top5.update(prec5.item(), N)

            if self.steps % self.log_step == 0:
                self.writer.add_scalar('train/lr', round(cur_lr, 5),
                                       self.steps)
                self.writer.add_scalar('train/loss', loss.item(), self.steps)
                self.writer.add_scalar('train/top1', prec1.item(), self.steps)
                self.writer.add_scalar('train/top5', prec5.item(), self.steps)

            if i % self.config.print_freq == 0 or i == len(
                    self.train_loader) - 1:
                printer(
                    f'Train: Epoch: [{epoch}][{i}/{len(self.train_loader) - 1}]\t'
                    f'Step {self.steps}\t'
                    f'lr {round(cur_lr, 5)}\t'
                    f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                    f'Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})\t')

            trn_X, trn_y = prefetcher_trn.next()
            val_X, val_y = prefetcher_val.next()

        printer("Train: [{:3d}/{}] Final Prec@1 {:.4%}".format(
            epoch, self.total_epochs - 1, top1.avg))
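
The distribution trainer adds a second alpha update that rewards expected depth while penalizing concat width (`new_loss = depth_loss + 0.4 * param_loss`). Below is a toy, self-contained sketch of that pattern: an auxiliary loss built purely from architecture parameters and backpropagated through its own optimizer step. The real `cal_depth` computation is specific to the controller above and is replaced here by a placeholder surrogate; only the concat-width term follows the same form as `concat_param_loss`.

# Sketch of an auxiliary alpha-only regularization step (toy parameters, surrogate depth term).
import torch
import torch.nn.functional as F

# Toy architecture parameters: 3 stages, 4 alpha rows of 8 candidate edges each,
# plus one 5-way concat-width parameter (beta) per stage.
alphas = [torch.zeros(4, 8, requires_grad=True) for _ in range(3)]
betas = [torch.zeros(5, requires_grad=True) for _ in range(3)]
alpha_optim = torch.optim.Adam(alphas + betas, lr=3e-4, betas=(0.5, 0.999))

alpha_optim.zero_grad()

# Placeholder surrogate for cal_depth(): expected position of the strongest edge per stage.
depth_terms = [(F.softmax(a.max(dim=1).values, dim=0) * torch.arange(1., 5.)).sum()
               for a in alphas]
# Same form as concat_param_loss(): expected concat width, sizes 4..8 weighted by softmax(beta).
param_loss = sum((F.softmax(b, dim=0) * torch.arange(4., 9.)).sum() for b in betas)

depth_loss = -1 * sum(depth_terms)          # reward deeper cells
new_loss = depth_loss + 0.4 * param_loss    # trade depth against concat width
new_loss.backward()
alpha_optim.step()
print(float(new_loss))
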