Пример #1
0
def main(FLAG):
    """Build the model, prepare the data, then validate or train.

    ``FLAG.Mode`` selects the action:

    * ``"validation"`` -- evaluate 20 randomly sampled learning rates via
      ``Model.validation``.
    * ``"train"`` -- train the model, then print the prediction, the target
      and their mean squared error for test example 123.

    Any other mode does nothing after the data has been prepared.
    """
    optimizer = tf.train.RMSPropOptimizer(FLAG.learning_rate)
    Model = SimpleModel(FLAG.input_dim, FLAG.hidden_dim, FLAG.output_dim,
                        optimizer=optimizer)

    images, labels = load_dataset()
    images, labels = image_augmentation(images, labels,
                                        horizon_flip=True,
                                        control_brightness=True)
    # Labels are scaled down by 96 (presumably the image side length -- confirm).
    labels = labels / 96.

    train_split, valid_split, test_split = split_data(images, labels)
    train_X, train_y = train_split
    valid_X, valid_y = valid_split
    test_X, test_y = test_split

    mode = FLAG.Mode
    if mode == "validation":
        # 20 learning rates drawn log-uniformly from [1e-6, 1e-2).
        candidate_lrs = 10**np.random.uniform(-6, -2, 20)
        Model.validation(train_X, train_y, valid_X, valid_y, candidate_lrs)
        return
    if mode != "train":
        return

    Model.train(train_X, train_y, valid_X, valid_y, FLAG.batch_size,
                FLAG.Epoch, FLAG.save_graph, FLAG.save_model)

    # Sanity check on a single, fixed test example.
    sample_index = 123
    prediction = Model.predict(test_X[sample_index])
    target = test_y[sample_index]
    print(prediction)
    print(target)
    print(np.mean(np.square(prediction - target)))
Пример #2
0
model = SimpleModel().to(device)

# Loss function: CrossEntropyLoss takes raw logits and integer class labels.
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# Training
num_steps = len(train_loader)

for epoch in range(num_epochs):
    # ---- training ----
    # Put the model in training mode (affects dropout / batch-norm layers).
    model.train()

    total_loss = 0

    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)

        # Compute loss
        loss = criterion(outputs, labels)

        # Back-propagate and update; without these two calls the parameters
        # never change and the loop does no training at all.
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean per-batch loss for the epoch (guarded against an empty loader).
    print('Epoch [{}/{}], Loss: {:.4f}'.format(
        epoch + 1, num_epochs, total_loss / max(num_steps, 1)))
Пример #3
0
class MultiTaskSingleObjectiveAgent(SingleTaskSingleObjectiveAgent):
    """Multi-task variant of the single-objective mask-search agent.

    One shared supernet (``SimpleModel``) is pretrained on batches from all
    tasks. The final stage builds one submodel per task from the best mask
    found by the search stage (``self.queue[0]``). Stages that are not
    overridden here (e.g. ``train`` and ``_search``) come from
    ``SingleTaskSingleObjectiveAgent``.
    """

    def __init__(self, architecture, search_space, task_info):
        """Create the shared supernet and its size estimator.

        Args:
            architecture: architecture description forwarded to ``SimpleModel``.
            search_space: searchable choices; its length is stored as
                ``search_size``.
            task_info: provides ``num_tasks``, ``num_channels`` and a
                per-task sequence ``num_classes`` (it is summed below).
        """
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.search_size = len(search_space)
        self.num_tasks = task_info.num_tasks

        self.model = SimpleModel(architecture=architecture,
                                 search_space=search_space,
                                 in_channels=task_info.num_channels,
                                 num_classes=task_info.num_classes)
        # The size estimator sees all task heads as one combined output.
        self.compute_model_size = SimpleModelSize(architecture,
                                                  search_space,
                                                  task_info.num_channels,
                                                  sum(task_info.num_classes),
                                                  batchnorm=True)

        self._init()

    def _pretrain(self,
                  train_data,
                  test_data,
                  configs,
                  save_model=False,
                  save_history=False,
                  path='saved_models/default/pretrain/',
                  verbose=False):
        """Pretrain the shared supernet with randomly masked sub-networks.

        Every batch uses a freshly sampled mask whose dropout rate grows
        linearly from 0 toward ``configs.dropout`` over the epochs. Training
        resumes at ``self.epoch['pretrain']``.

        Args:
            train_data: object whose ``get_loader()`` yields
                ``(inputs, labels, task)`` batches.
            test_data: evaluated with the all-ones mask for the history.
            configs: hyper-parameters (``lr``, ``momentum``, ``weight_decay``,
                ``lr_decay_epoch``, ``lr_decay``, ``dropout``, ``num_epochs``,
                ``save_epoch``).
            save_model: checkpoint model and epoch counter under ``path``.
            save_history: record per-epoch test accuracy in ``self.accuracy``.
            path: checkpoint directory.
            verbose: print per-epoch test accuracy.
        """
        self.model.train()

        dataloader = train_data.get_loader()
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(self.model.parameters(),
                              lr=configs.lr,
                              momentum=configs.momentum,
                              weight_decay=configs.weight_decay)
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=configs.lr_decay_epoch,
            gamma=configs.lr_decay)

        # Replay the LR schedule for epochs completed in a previous run.
        for _ in range(self.epoch['pretrain']):
            scheduler.step()

        for epoch in range(self.epoch['pretrain'], configs.num_epochs):
            dropout = configs.dropout * epoch / configs.num_epochs

            for inputs, labels, task in dataloader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                masks = self.mask_sampler.rand(dropout=dropout)
                outputs = self.model(inputs,
                                     self.mask_sampler.make_batch(masks), task)
                loss = criterion(outputs, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Since PyTorch 1.1 the scheduler steps after the epoch's optimizer
            # updates; stepping first would apply every decay one epoch early.
            scheduler.step()

            if verbose or save_history:
                masks = self.mask_sampler.ones()
                self.accuracy['pretrain'].append(
                    self._eval_model(test_data, masks))

            if verbose:
                print('[Pretrain][Epoch {}] Accuracy: {}'.format(
                    epoch + 1, self.accuracy['pretrain'][-1]))

            if epoch % configs.save_epoch == 0 and save_model:
                self._save_pretrain(path)
                self.epoch['pretrain'] = epoch + 1
                self._save_epoch('pretrain', path)

        if save_model:
            self._save_pretrain(path)
            self.epoch['pretrain'] = configs.num_epochs
            self._save_epoch('pretrain', path)

    def _finaltrain(self,
                    train_data,
                    test_data,
                    configs,
                    save_model=False,
                    save_history=False,
                    path='saved_models/default/final/',
                    verbose=False):
        """Train one submodel per task, all built from the best searched mask.

        On first use the per-task submodels are created from ``self.queue[0]``
        and wrapped in ``nn.DataParallel``. Each task has its own optimizer
        and LR scheduler; every batch is routed to the model of its task.

        Args:
            train_data: object whose ``get_loader()`` yields
                ``(inputs, labels, task)`` batches.
            test_data: evaluated with ``_eval_final`` for the history.
            configs: hyper-parameters (``lr``, ``momentum``, ``weight_decay``,
                ``lr_decay_epoch``, ``lr_decay``, ``num_epochs``,
                ``save_epoch``).
            save_model: checkpoint models and epoch counter under ``path``.
            save_history: record and save per-epoch test accuracy.
            path: checkpoint directory.
            verbose: print per-epoch test accuracy.
        """
        if self.finalmodel is None:
            # The best mask from the search stage is shared by all tasks.
            self.finalmodel_mask = self.queue[0]
            submodels = [
                self.submodel(self.finalmodel_mask, task)
                for task in range(self.num_tasks)
            ]
            self.finalmodel = [
                nn.DataParallel(m).to(self.device) for m in submodels
            ]

        for model in self.finalmodel:
            model.train()

        dataloader = train_data.get_loader()
        criterion = nn.CrossEntropyLoss()
        optimizers = [
            optim.SGD(model.parameters(),
                      lr=configs.lr,
                      momentum=configs.momentum,
                      weight_decay=configs.weight_decay)
            for model in self.finalmodel
        ]
        schedulers = [
            optim.lr_scheduler.MultiStepLR(optimizer,
                                           milestones=configs.lr_decay_epoch,
                                           gamma=configs.lr_decay)
            for optimizer in optimizers
        ]

        # Replay the LR schedule for epochs completed in a previous run.
        for _ in range(self.epoch['final']):
            for scheduler in schedulers:
                scheduler.step()

        for epoch in range(self.epoch['final'], configs.num_epochs):
            for inputs, labels, task in dataloader:
                # Each batch carries a single task id: use that task's model.
                model = self.finalmodel[task]
                optimizer = optimizers[task]

                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Step after the epoch's optimizer updates (PyTorch >= 1.1 order).
            for scheduler in schedulers:
                scheduler.step()

            if verbose or save_history:
                self.accuracy['final'].append(self._eval_final(test_data))

            if verbose:
                print('[Final][Epoch {}] Accuracy: {}'.format(
                    epoch + 1, self.accuracy['final'][-1]))

            if epoch % configs.save_epoch == 0:
                if save_model:
                    self._save_final(path)
                    self.epoch['final'] = epoch + 1
                    self._save_epoch('final', path)

                if save_history:
                    self._save_accuracy('final', path)

        if save_model:
            self._save_final(path)
            self.epoch['final'] = configs.num_epochs
            self._save_epoch('final', path)

        if save_history:
            self._save_accuracy('final', path)

    def _eval_model(self, data, masks):
        """Mean per-task accuracy of the supernet with ``masks`` applied."""
        batched_masks = self.mask_sampler.make_batch(masks)

        def masked_forward(x, t):
            return self.model(x, batched_masks, t)

        return self._eval(data, masked_forward)

    def _eval_final(self, data):
        """Mean per-task accuracy of the final submodels.

        The submodels are switched to eval mode for the measurement and then
        restored to their previous mode, even if evaluation raises.
        """
        previous_modes = [model.training for model in self.finalmodel]
        for model in self.finalmodel:
            model.eval()

        try:
            return self._eval(data, lambda x, t: self.finalmodel[t](x))
        finally:
            for model, was_training in zip(self.finalmodel, previous_modes):
                model.train(was_training)

    def _eval(self, data, model):
        """Return the mean of the per-task top-1 accuracies.

        Args:
            data: object whose ``get_loader(t)`` yields ``(inputs, labels)``
                for task ``t``.
            model: callable ``model(inputs, task) -> logits``.

        Returns:
            Mean accuracy over tasks that have at least one sample, or 0.0 if
            no task has any sample (instead of dividing by zero).
        """
        correct = [0] * self.num_tasks
        total = [0] * self.num_tasks

        with torch.no_grad():
            for task in range(self.num_tasks):
                for inputs, labels in data.get_loader(task):
                    inputs = inputs.to(self.device)
                    labels = labels.to(self.device)
                    outputs = model(inputs, task)
                    _, predict_labels = torch.max(outputs.detach(), 1)

                    total[task] += labels.size(0)
                    correct[task] += (predict_labels == labels).sum().item()

        accuracies = [c / n for c, n in zip(correct, total) if n > 0]
        if not accuracies:
            return 0.0
        return np.mean(accuracies)

    def _save_final(self, path='saved_models/default/final/'):
        """Write the final mask (``masks.json``) and each task's weights."""
        # exist_ok avoids the check-then-create race of isdir + makedirs.
        os.makedirs(path, exist_ok=True)

        with open(os.path.join(path, 'masks.json'), 'w') as f:
            json.dump(self.finalmodel_mask.tolist(), f)

        for t, model in enumerate(self.finalmodel):
            torch.save(model.state_dict(),
                       os.path.join(path, 'model{}'.format(t)))

    def _load_final(self, path='saved_models/default/final/'):
        """Restore the final mask and per-task submodels from ``path``.

        Best-effort: if any expected file is missing, nothing is loaded and
        ``finalmodel``/``finalmodel_mask`` keep their current values, so the
        agent is never left with a half-loaded final model.
        """
        try:
            with open(os.path.join(path, 'masks.json'), 'r') as f:
                mask = torch.tensor(json.load(f), dtype=torch.uint8)

            submodels = [
                self.submodel(mask, task) for task in range(self.num_tasks)
            ]
            models = [nn.DataParallel(m).to(self.device) for m in submodels]

            for t, model in enumerate(models):
                filename = os.path.join(path, 'model{}'.format(t))
                # map_location lets a checkpoint saved on GPU load on a CPU host.
                model.load_state_dict(
                    torch.load(filename, map_location=self.device))
        except FileNotFoundError:
            return

        self.finalmodel_mask = mask
        self.finalmodel = models
Пример #4
0
def train(args):
    """Train a prototypical network with episodic k-way n-shot batches.

    Terminology: k-way n-shot means k classes, n shots per class.

    Each batch produced by ``Sampler`` contains ``args.shot`` support images
    followed by ``args.query`` query images for each sampled class. A class
    prototype is the mean support embedding of that class. Query images are
    scored against every prototype with ``euclidean``.

    After each epoch the model is evaluated on the test episodes. The
    training log is then written to ``<args.save>/training_log`` and the
    model is checkpointed as ``model``. It is also saved as ``max-acc`` when
    the validation accuracy improves.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    renew_path(args.save)

    shots = args.shot + args.query
    train_set = IkeaSet(TRAIN_MAT_PATH)
    train_sampler = Sampler(train_set.label,
                            args.batch_num_train,
                            args.train_way,
                            shots,
                            limit_class=args.limit_class)
    train_loader = DataLoader(train_set,
                              batch_sampler=train_sampler,
                              num_workers=4,
                              pin_memory=True)

    test_set = IkeaSet(TEST_MAT_PATH)
    test_sampler = Sampler(test_set.label,
                           args.batch_num_test,
                           args.test_way,
                           shots,
                           limit_class=args.limit_class)
    test_loader = DataLoader(test_set,
                             batch_sampler=test_sampler,
                             num_workers=4,
                             pin_memory=True)

    model = SimpleModel().to(device)
    model = load_model(model, 'model', args.save)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Learning-rate scheduler: halve the LR every 20 epochs. It is stepped
    # once per epoch below, since StepLR's step_size counts scheduler steps.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=20,
                                                   gamma=0.5)
    loss_fn = F.cross_entropy

    def run_episode(batch, way):
        """Return ``(loss, accuracy)`` for one ``way``-class episode."""
        num_support = args.shot * way
        support_x = batch[0][:num_support].to(device).float()
        query_x = batch[0][num_support:].to(device).float()

        # Prototype per class: mean of its support embeddings. The reshape
        # implies support images are ordered shot-major, i.e. (shot, way).
        prototypes = model(support_x).reshape(args.shot, way,
                                              -1).mean(dim=0)

        # Using labels 0..way-1 instead of the real labels looks odd, but it
        # follows from how the Sampler builds a batch: an image's real label
        # does not correspond to the index of its prototype. Each query is
        # therefore labelled with the relative index of its class within the
        # episode, which matches the index of the closest prototype.
        label = torch.arange(way, device=device).repeat(args.query)

        scores = euclidean(model(query_x), prototypes)
        # F.cross_entropy applies log-softmax itself, so it gets the raw
        # scores (applying softmax first would normalise twice).
        # NOTE(review): this treats higher scores as "closer"; confirm that
        # euclidean returns a similarity (e.g. negative distance), not a
        # positive distance.
        loss = loss_fn(scores, label)
        prob = F.softmax(scores, dim=1)
        return loss, get_accuracy(label, prob)

    # Training log
    training_log = {}
    training_log['args'] = []
    training_log['train_loss'] = []
    training_log['val_loss'] = []
    training_log['train_acc'] = []
    training_log['val_acc'] = []
    training_log['max_acc'] = 0.0

    for epoch in range(1, args.epoch + 1):
        time_a = datetime.datetime.now()
        model.train()
        average_loss = 0
        average_accuracy = 0
        print("Start epoch: ", epoch)
        for i, batch in enumerate(train_loader, 1):
            loss, acc = run_episode(batch, args.train_way)
            if i % 30 == 0:
                print('epoch{}, {}/{}, lost={:.4f} acc={:.4f}'.format(
                    epoch, i, len(train_loader), loss.item(), acc))
            # NOTE(review): enumerate already starts at 1 but update_avg gets
            # i + 1; confirm which count update_avg expects.
            average_loss = update_avg(i + 1, average_loss, loss.item())
            average_accuracy = update_avg(i + 1, average_accuracy, acc)

            # Clear the previous episode's gradients before back-propagating.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        lr_scheduler.step()

        model.eval()
        average_loss_val = 0
        average_accuracy_val = 0

        # Evaluate after each epoch on test episodes (test_way classes).
        with torch.no_grad():
            for i, batch in enumerate(test_loader, 1):
                loss, acc = run_episode(batch, args.test_way)
                average_loss_val = update_avg(i + 1, average_loss_val,
                                              loss.item())
                average_accuracy_val = update_avg(i + 1, average_accuracy_val,
                                                  acc)

        print("epoch {} validation: loss={:4f} acc={:4f}".format(
            epoch, average_loss_val, average_accuracy_val))
        if average_accuracy_val > training_log['max_acc']:
            training_log['max_acc'] = average_accuracy_val
            save_model(model, 'max-acc', args.save)

        training_log['train_loss'].append(average_loss)
        training_log['train_acc'].append(average_accuracy)
        training_log['val_loss'].append(average_loss_val)
        training_log['val_acc'].append(average_accuracy_val)

        torch.save(training_log, os.path.join(args.save, 'training_log'))
        save_model(model, 'model', args.save)

        time_b = datetime.datetime.now()
        print('ETA:{}s/{}s'.format(
            (time_b - time_a).seconds,
            (time_b - time_a).seconds * (args.epoch - epoch)))
Пример #5
0
def train(args):
    """Train a plain image classifier on IkeaSet and validate every epoch.

    After each epoch the training log is written to
    ``<args.save>/training_log`` and the model is checkpointed as ``model``.
    It is also saved as ``max-acc`` when the validation accuracy improves.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    renew_path(args.save)

    # NOTE(review): no batch_size/shuffle is given, so DataLoader defaults to
    # batch_size=1 in a fixed order; confirm that is intended for training.
    train_set = IkeaSet(TRAIN_MAT_PATH)
    train_loader = DataLoader(train_set, num_workers=4, pin_memory=True)

    test_set = IkeaSet(TEST_MAT_PATH)
    test_loader = DataLoader(test_set, num_workers=4, pin_memory=True)

    model = SimpleModel(n_class=len(train_set.classes)).to(device)
    model = load_model(model, 'model', args.save)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Learning-rate scheduler: halve the LR every 20 epochs (stepped once per
    # epoch below).
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    loss_fn = F.cross_entropy

    # Training log
    training_log = {}
    training_log['args'] = []
    training_log['train_loss'] = []
    training_log['val_loss'] = []
    training_log['train_acc'] = []
    training_log['val_acc'] = []
    training_log['max_acc'] = 0.0

    for epoch in range(1, args.epoch + 1):
        time_a = datetime.datetime.now()
        model.train()
        average_loss = 0
        average_accuracy = 0
        for i, batch in enumerate(train_loader, 1):
            # Tensor.to() is not in-place: the moved tensors must be kept.
            inputs = batch[0].to(device)
            label = batch[1].to(device=device, dtype=torch.long)

            pred = model(inputs)

            loss = loss_fn(pred, label)
            acc = get_accuracy(label, pred)
            if i % 20 == 0:
                print('epoch{}, {}/{}, lost={:.4f} acc={:.4f}'.format(epoch, i, len(train_loader), loss.item(), acc))
            # NOTE(review): enumerate already starts at 1 but update_avg gets
            # i + 1; confirm which count update_avg expects.
            average_loss = update_avg(i + 1, average_loss, loss.item())
            average_accuracy = update_avg(i + 1, average_accuracy, acc)

            # Clear the previous batch's gradients before back-propagating.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        lr_scheduler.step()

        model.eval()
        average_loss_val = 0
        average_accuracy_val = 0

        # Evaluate after each epoch.
        with torch.no_grad():
            for i, batch in enumerate(test_loader, 1):
                inputs = batch[0].to(device)
                label = batch[1].to(device=device, dtype=torch.long)

                pred = model(inputs)

                loss = loss_fn(pred, label)
                acc = get_accuracy(label, pred)
                average_loss_val = update_avg(i + 1, average_loss_val, loss.item())
                average_accuracy_val = update_avg(i + 1, average_accuracy_val, acc)

        print("epoch {} validation: loss={:4f} acc={:4f}".format(epoch, average_loss_val, average_accuracy_val))
        if average_accuracy_val > training_log['max_acc']:
            training_log['max_acc'] = average_accuracy_val
            save_model(model, 'max-acc', args.save)

        training_log['train_loss'].append(average_loss)
        training_log['train_acc'].append(average_accuracy)
        training_log['val_loss'].append(average_loss_val)
        training_log['val_acc'].append(average_accuracy_val)

        torch.save(training_log, os.path.join(args.save, 'training_log'))
        save_model(model, 'model', args.save)

        time_b = datetime.datetime.now()
        print('ETA:{}s/{}s'.format((time_b - time_a).seconds, (time_b - time_a).seconds * (args.epoch - epoch)))
Пример #6
0
model_pytorch = SimpleModel(
    input_size=input_size,
    hidden_sizes=hidden_sizes,
    output_size=output_size,
).to(device)

# Loss and optimizer. Binary cross-entropy because there are only two
# classes; nn.BCELoss expects the model to emit probabilities in [0, 1]
# (presumably SimpleModel ends with a sigmoid -- confirm).
criterion = nn.BCELoss()
optimizer = optim.Adam(model_pytorch.parameters(), lr=1e-3)

num_epochs = 20

# Train model, timing the whole run.
time_start = time.time()

for epoch in range(num_epochs):
    model_pytorch.train()
    train_loss_total = 0

    for data, target in train_loader:
        data = data.to(device)
        target = target.float().to(device)

        optimizer.zero_grad()
        output = model_pytorch(data)
        train_loss = criterion(output, target)
        train_loss.backward()
        optimizer.step()

        # Weight by batch size so dividing by train_size gives a per-sample mean.
        train_loss_total += train_loss.item() * data.size(0)

    epoch_loss = train_loss_total / train_size
    print('Epoch {} completed. Train loss is {:.3f}'.format(epoch + 1, epoch_loss))

elapsed_minutes = (time.time() - time_start) / 60
print('Time taken to completed {} epochs: {:.2f} minutes'.format(num_epochs, elapsed_minutes))
Пример #7
0
class SingleTaskSingleObjectiveAgent(BaseAgent):
    def __init__(self, architecture, search_space, task_info):
        """Create the supernet and its size estimator, then reset state.

        Args:
            architecture: architecture description forwarded to ``SimpleModel``.
            search_space: searchable choices; its length is stored as
                ``search_size``.
            task_info: provides ``num_channels`` and ``num_classes``.
        """
        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')
        self.search_size = len(search_space)

        num_channels = task_info.num_channels
        num_classes = task_info.num_classes

        # SimpleModel presumably takes one class count per task, hence the
        # single-element list for this single-task agent.
        self.model = SimpleModel(
            architecture=architecture,
            search_space=search_space,
            in_channels=num_channels,
            num_classes=[num_classes],
        )
        self.compute_model_size = SimpleModelSize(
            architecture, search_space, num_channels, num_classes,
            batchnorm=True,
        )

        self._init()


    def _init(self):
        """Wrap the model in DataParallel and reset all per-stage state."""
        # Read attributes of the bare model before DataParallel wraps it
        # (the wrapper only exposes them through ``.module``).
        self.submodel = self.model.submodel
        self.mask_sampler = MaskSampler(mask_size=self.model.mask_size)
        self.model = nn.DataParallel(self.model).to(self.device)

        # Record: last completed epoch and accuracy history for each stage.
        stages = ('pretrain', 'search', 'final')
        self.epoch = {stage: 0 for stage in stages}
        self.accuracy = {stage: [] for stage in stages}

        # Search: cached accuracies keyed by mask string, plus the population
        # of masks and their validation accuracies.
        self.accuracy_dict_valid = {}
        self.accuracy_dict_test = {}
        self.queue = []
        self.queue_acc = []

        # Final: the chosen mask and its submodel, created lazily later.
        self.finalmodel_mask = None
        self.finalmodel = None



    def train(self, train_data, valid_data, test_data, configs, save_model, save_history, path, verbose):
        """Run pretraining, mask search and final training in order.

        Each stage runs only while its recorded epoch is below the stage's
        configured ``num_epochs``, so a restored agent resumes where it
        stopped. Stage outputs go to ``path/pretrain``, ``path/search`` and
        ``path/final``.
        """
        common = dict(save_model=save_model,
                      save_history=save_history,
                      verbose=verbose)

        # Pretrain
        if self.epoch['pretrain'] < configs.pretrain.num_epochs:
            self._pretrain(train_data=train_data,
                           test_data=test_data,
                           configs=configs.pretrain,
                           path=os.path.join(path, 'pretrain'),
                           **common)

        # Select final model
        if self.epoch['search'] < configs.search.num_epochs:
            self._search(valid_data=valid_data,
                         test_data=test_data,
                         configs=configs.search,
                         path=os.path.join(path, 'search'),
                         **common)

        # Train final model
        if self.epoch['final'] < configs.final.num_epochs:
            self._finaltrain(train_data=train_data,
                             test_data=test_data,
                             configs=configs.final,
                             path=os.path.join(path, 'final'),
                             **common)


    def _pretrain(self,
                  train_data,
                  test_data,
                  configs,
                  save_model=False,
                  save_history=False,
                  path='saved_models/default/pretrain/',
                  verbose=False
                  ):
        """Pretrain the supernet with randomly masked sub-networks.

        Every batch uses a freshly sampled mask whose dropout rate grows
        linearly from 0 toward ``configs.dropout`` over the epochs. Training
        resumes at ``self.epoch['pretrain']``.

        Args:
            train_data: iterable of ``(inputs, labels)`` batches.
            test_data: evaluated with the all-ones mask for the history.
            configs: hyper-parameters (``lr``, ``momentum``, ``weight_decay``,
                ``lr_decay_epoch``, ``lr_decay``, ``dropout``, ``num_epochs``,
                ``save_epoch``).
            save_model: checkpoint model and epoch counter under ``path``.
            save_history: record per-epoch test accuracy in ``self.accuracy``.
            path: checkpoint directory.
            verbose: print per-epoch test accuracy.
        """
        self.model.train()

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(self.model.parameters(), lr=configs.lr, momentum=configs.momentum, weight_decay=configs.weight_decay)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=configs.lr_decay_epoch, gamma=configs.lr_decay)

        # Replay the LR schedule for epochs completed in a previous run.
        for _ in range(self.epoch['pretrain']):
            scheduler.step()

        for epoch in range(self.epoch['pretrain'], configs.num_epochs):
            dropout = configs.dropout * epoch / configs.num_epochs

            for inputs, labels in train_data:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                masks = self.mask_sampler.rand(dropout=dropout)
                outputs = self.model(inputs, self.mask_sampler.make_batch(masks))
                loss = criterion(outputs, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Since PyTorch 1.1 the scheduler steps after the epoch's optimizer
            # updates; stepping first would apply every decay one epoch early.
            scheduler.step()

            if verbose or save_history:
                masks = self.mask_sampler.ones()
                self.accuracy['pretrain'].append(self._eval_model(test_data, masks))

            if verbose:
                print('[Pretrain][Epoch {}] Accuracy: {}'.format(epoch + 1, self.accuracy['pretrain'][-1]))

            if epoch % configs.save_epoch == 0 and save_model:
                self._save_pretrain(path)
                self.epoch['pretrain'] = epoch + 1
                self._save_epoch('pretrain', path)

        if save_model:
            self._save_pretrain(path)
            self.epoch['pretrain'] = configs.num_epochs
            self._save_epoch('pretrain', path)


    def _search(self,
                valid_data,
                test_data,
                configs,
                save_model=False,
                save_history=False,
                path='saved_models/default/search/',
                verbose=False
                ):
        """Evolve a population of masks, ranked by validation accuracy.

        On a fresh start the population (``self.queue``) is seeded with
        ``configs.num_samples`` random masks whose dropout rates are spread
        evenly over [0, 1]. Each epoch mutates every mask, pools parents and
        children, and keeps the ``num_samples`` most accurate. Accuracies are
        cached per mask string, so a repeated mask is never re-evaluated.

        Args:
            valid_data: data used to rank masks.
            test_data: data used to report the best mask's accuracy.
            configs: ``num_samples``, ``num_epochs``, ``mutate_prob``,
                ``save_epoch``.
            save_model: checkpoint the search state under ``path``.
            save_history: record and save the best mask's test accuracy.
            path: checkpoint directory.
            verbose: print the best mask's test accuracy each epoch.
        """
        # Initialization

        if self.epoch['search'] == 0:
            # Spread dropout evenly over [0, 1]; with a single sample the
            # denominator would be 0, so that sample gets dropout 0.
            denominator = max(configs.num_samples - 1, 1)
            self.queue = [self.mask_sampler.rand(dropout=i / denominator) for i in range(configs.num_samples)]
            self.queue_acc = []

            for masks in self.queue:
                masks_str = masks2str(masks)
                accuracy = self._eval_model(valid_data, masks)
                self.accuracy_dict_valid[masks_str] = accuracy
                self.queue_acc.append(accuracy)

        # Search

        for epoch in range(self.epoch['search'], configs.num_epochs):
            generated = []
            generated_acc = []

            for old_masks in self.queue:
                new_masks = self.mask_sampler.mutate(old_masks, configs.mutate_prob)
                new_masks_str = masks2str(new_masks)

                if new_masks_str not in self.accuracy_dict_valid:
                    self.accuracy_dict_valid[new_masks_str] = self._eval_model(valid_data, new_masks)

                generated.append(new_masks)
                generated_acc.append(self.accuracy_dict_valid[new_masks_str])

            # Keep the num_samples best of parents + children (descending accuracy).
            candidates = self.queue + generated
            candidates_acc = self.queue_acc + generated_acc
            order = np.argsort(candidates_acc)[::-1][:configs.num_samples]
            self.queue = [candidates[i] for i in order]
            self.queue_acc = [candidates_acc[i] for i in order]
            best_masks = self.queue[0]
            best_masks_str = masks2str(best_masks)

            if verbose or save_history:
                if best_masks_str not in self.accuracy_dict_test:
                    self.accuracy_dict_test[best_masks_str] = self._eval_model(test_data, best_masks)
                self.accuracy['search'].append(self.accuracy_dict_test[best_masks_str])

            if verbose:
                print('[Search][Epoch {}] Accuracy: {}'.format(epoch + 1, self.accuracy['search'][-1]))

            if epoch % configs.save_epoch == 0:
                if save_model:
                    self._save_search(path)
                    self.epoch['search'] = epoch + 1
                    self._save_epoch('search', path)

                if save_history:
                    self._save_accuracy('search', path)

        if save_model:
            self._save_search(path)
            self.epoch['search'] = configs.num_epochs
            self._save_epoch('search', path)

        if save_history:
            self._save_accuracy('search', path)


    def _finaltrain(self,
                    train_data,
                    test_data,
                    configs,
                    save_model=False,
                    save_history=False,
                    path='saved_models/default/final/',
                    verbose=False
                    ):
        """Train the submodel built from the best searched mask.

        On first use the submodel is created from ``self.queue[0]`` and
        wrapped in ``nn.DataParallel``. Training resumes at
        ``self.epoch['final']``.

        Args:
            train_data: iterable of ``(inputs, labels)`` batches.
            test_data: evaluated with ``_eval_final`` for the history.
            configs: hyper-parameters (``lr``, ``momentum``, ``weight_decay``,
                ``lr_decay_epoch``, ``lr_decay``, ``num_epochs``,
                ``save_epoch``).
            save_model: checkpoint model and epoch counter under ``path``.
            save_history: record and save per-epoch test accuracy.
            path: checkpoint directory.
            verbose: print per-epoch test accuracy.
        """
        if self.finalmodel is None:
            self.finalmodel_mask = self.queue[0]
            self.finalmodel = self.submodel(self.finalmodel_mask)
            self.finalmodel = nn.DataParallel(self.finalmodel).to(self.device)

        self.finalmodel.train()

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(self.finalmodel.parameters(), lr=configs.lr, momentum=configs.momentum, weight_decay=configs.weight_decay)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=configs.lr_decay_epoch, gamma=configs.lr_decay)

        # Replay the LR schedule for epochs completed in a previous run.
        for _ in range(self.epoch['final']):
            scheduler.step()

        for epoch in range(self.epoch['final'], configs.num_epochs):
            for inputs, labels in train_data:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.finalmodel(inputs)
                loss = criterion(outputs, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Step after the epoch's optimizer updates (PyTorch >= 1.1 order).
            scheduler.step()

            if verbose or save_history:
                self.accuracy['final'].append(self._eval_final(test_data))

            if verbose:
                print('[Final][Epoch {}] Accuracy: {}'.format(epoch + 1, self.accuracy['final'][-1]))

            if epoch % configs.save_epoch == 0:
                if save_model:
                    self._save_final(path)
                    self.epoch['final'] = epoch + 1
                    self._save_epoch('final', path)

                if save_history:
                    self._save_accuracy('final', path)

        if save_model:
            self._save_final(path)
            self.epoch['final'] = configs.num_epochs
            self._save_epoch('final', path)

        if save_history:
            self._save_accuracy('final', path)


    def eval(self, data):
        """Return ``(accuracy, model_size)`` of the final model.

        Accuracy is measured on ``data`` via ``_eval_final``. The size comes
        from ``compute_model_size`` applied to the final mask.
        """
        return (self._eval_final(data),
                self.compute_model_size.compute(self.finalmodel_mask))


    def _eval_model(self, data, masks):
        """Top-1 accuracy of the supernet on ``data`` with ``masks`` applied."""
        batched_masks = self.mask_sampler.make_batch(masks)

        def masked_forward(x):
            return self.model(x, batched_masks)

        return self._eval(data, masked_forward)


    def _eval_final(self, data):
        """Top-1 accuracy of the final model on ``data``.

        The model is put in eval mode for the measurement. Its previous mode
        is restored afterwards, even if evaluation raises.
        """
        was_training = self.finalmodel.training
        self.finalmodel.eval()
        try:
            return self._eval(data, self.finalmodel)
        finally:
            self.finalmodel.train(was_training)


    def _eval(self, data, model):
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, labels in data:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = model(inputs)
                _, predict_labels = torch.max(outputs.detach(), 1)

                total += labels.size(0)
                correct += (predict_labels == labels).sum().item()

            return correct / total


    def _save_pretrain(self, path='saved_models/default/pretrain/'):
        if not os.path.isdir(path):
            os.makedirs(path)
        torch.save(self.model.state_dict(), os.path.join(path, 'model'))


    def _save_search(self, path='saved_models/default/search/'):
        if not os.path.isdir(path):
            os.makedirs(path)

        with open(os.path.join(path, 'accuracy_dict_valid.json'), 'w') as f:
            json.dump(self.accuracy_dict_valid, f)
        with open(os.path.join(path, 'accuracy_dict_test.json'), 'w') as f:
            json.dump(self.accuracy_dict_test, f)
        with open(os.path.join(path, 'queue.json'), 'w') as f:
            json.dump([masks.tolist() for masks in self.queue], f)
        with open(os.path.join(path, 'queue_acc.json'), 'w') as f:
            json.dump(self.queue_acc, f)


    def _save_final(self, path='saved_models/default/final/'):
        if not os.path.isdir(path):
            os.makedirs(path)

        with open(os.path.join(path, 'masks.json'), 'w') as f:
            json.dump(self.finalmodel_mask.tolist(), f)

        torch.save(self.finalmodel.state_dict(), os.path.join(path, 'model'))


    def _save_epoch(self, key, path='saved_models/default/'):
        if not os.path.isdir(path):
            os.makedirs(path)

        with open(os.path.join(path, 'last_epoch.json'), 'w') as f:
            json.dump(self.epoch[key], f)


    def _save_accuracy(self, key, path='saved_models/default/'):
        if not os.path.isdir(path):
            os.makedirs(path)
        filename = os.path.join(path, 'history.json')

        with open(filename, 'w') as f:
            json.dump(self.accuracy[key], f)


    def load(self, path='saved_models/default/'):
        """Restore all saved agent state found under ``path``.

        First loads the pretrain, search and final stages from the
        ``pretrain``, ``search`` and ``final`` subdirectories of ``path``.
        Then, for each stage in that order, loads its last epoch and its
        accuracy history from the same subdirectory.
        """
        stages = ('pretrain', 'search', 'final')
        stage_loaders = (self._load_pretrain, self._load_search, self._load_final)

        for stage, stage_loader in zip(stages, stage_loaders):
            stage_loader(os.path.join(path, stage))

        for stage in stages:
            stage_dir = os.path.join(path, stage)
            self._load_epoch(stage, stage_dir)
            self._load_accuracy(stage, stage_dir)


    def _load_pretrain(self, path='saved_models/default/pretrain/'):
        try:
            filename = os.path.join(path, 'model')
            self.model.load_state_dict(torch.load(filename))

        except FileNotFoundError:
            pass


    def _load_search(self, path='saved_models/default/search/'):
        try:
            with open(os.path.join(path, 'accuracy_dict_valid.json')) as f:
                self.accuracy_dict_valid = json.load(f)
            with open(os.path.join(path, 'accuracy_dict_test.json')) as f:
                self.accuracy_dict_test = json.load(f)
            with open(os.path.join(path, 'queue.json')) as f:
                self.queue = json.load(f)
                self.queue = [torch.tensor(masks, dtype=torch.uint8) for masks in self.queue]
            with open(os.path.join(path, 'queue_acc.json')) as f:
                self.queue_acc = json.load(f)

        except FileNotFoundError:
            self.queue = []
            self.queue_acc = []


    def _load_final(self, path='saved_models/default/final/'):
        try:
            with open(os.path.join(path, 'masks.json'), 'r') as f:
                self.finalmodel_mask = json.load(f)
            self.finalmodel_mask = torch.tensor(self.finalmodel_mask, dtype=torch.uint8)
            self.finalmodel = self.submodel(self.finalmodel_mask)
            self.finalmodel = nn.DataParallel(self.finalmodel).to(self.device)

            filename = os.path.join(path, 'model')
            self.finalmodel.load_state_dict(torch.load(filename))

        except FileNotFoundError:
            pass


    def _load_epoch(self, key, path='saved_models/default/'):
        try:
            filename = os.path.join(path, 'last_epoch.json')
            with open(filename, 'r') as f:
                self.epoch[key] = json.load(f)

        except FileNotFoundError:
            self.epoch[key] = 0


    def _load_accuracy(self, key, path='saved_models/default/'):
        try:
            with open(os.path.join(path, 'history.json'), 'r') as f:
                self.accuracy[key] = json.load(f)

        except FileNotFoundError:
            pass