Example #1
import torch
from typing import Optional

# Recorder and progress_bar are utilities from this repository (their imports are not shown in the original).

def train(
    _net, _train_loader, _optimizer, _criterion, _device='cpu',
    _recorder: Optional['Recorder'] = None,
    _weight_quantization_error_collection=None, _input_quantization_error_collection=None,
    _weight_bit_allocation_collection=None, _input_bit_allocation_collection=None
):

    _net.train()
    _train_loss = 0
    _correct = 0
    _total = 0

    for batch_idx, (inputs, targets) in enumerate(_train_loader):

        inputs, targets = inputs.to(_device), targets.to(_device)

        _optimizer.zero_grad()
        outputs = _net(inputs)
        losses = _criterion(outputs, targets)
        losses.backward()
        _optimizer.step()

        _train_loss += losses.data.item()
        _, predicted = torch.max(outputs.data, 1)
        _total += targets.size(0)
        _correct += predicted.eq(targets.data).cpu().sum().item()

        progress_bar(
            batch_idx, len(_train_loader),
            'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (_train_loss / (batch_idx + 1), 100. * _correct / _total, _correct, _total)
        )

        if _recorder is not None:
            _recorder.update(loss=losses.data.item(), acc=[_correct / _total], batch_size=inputs.size(0), is_train=True)

        if _weight_quantization_error_collection is not None and _input_quantization_error_collection is not None:
            for name, layer in _net.quantized_layer_collections.items():
                _weight_quantization_error = torch.abs(layer.quantized_weight - layer.pre_quantized_weight).mean().item()
                _input_quantization_error = torch.abs(layer.quantized_input - layer.pre_quantized_input).mean().item()
                _weight_quantization_error_collection[name].write('%.8e\n' % _weight_quantization_error)
                _input_quantization_error_collection[name].write('%.8e\n' % _input_quantization_error)

        if _weight_bit_allocation_collection is not None and _input_bit_allocation_collection is not None:
            for name, layer in _net.quantized_layer_collections.items():
                _weight_bit_allocation_collection[name].write('%.2f\n' % (torch.abs(layer.quantized_weight_bit).mean().item()))
                _input_bit_allocation_collection[name].write('%.2f\n' % (torch.abs(layer.quantized_input_bit).mean().item()))

    return _train_loss / len(_train_loader), _correct / _total
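
For context, a minimal driver for the train() function above. This is a hedged sketch: the model factory (resnet20_cifar), the loader helper (get_dataloader), and the hyperparameters are assumptions borrowed from Example #5, not part of this snippet.

import torch
import torch.nn as nn
from torch.optim import SGD

device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = resnet20_cifar().to(device)                        # model factory assumed from this repo
train_loader = get_dataloader('CIFAR10', 'train', 128)   # loader helper assumed from this repo
optimizer = SGD(net.parameters(), lr=0.1, momentum=0.9)  # illustrative hyperparameters
criterion = nn.CrossEntropyLoss()

for epoch in range(100):
    epoch_loss, epoch_acc = train(net, train_loader, optimizer, criterion, _device=device)
    print('[Epoch %d] loss: %.3f | acc: %.3f' % (epoch, epoch_loss, epoch_acc))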
Example #2
            _, predicted = torch.max(out_s.data, dim=1)
            correct_s += predicted.eq(y_s.data).cpu().sum().item()
            _, predicted = torch.max(out_t.data, dim=1)
            correct_t += predicted.eq(y_t.data).cpu().sum().item()

            loss_t += losses_t.item()

            total += y_s.size(0)

            progress_bar(
                batch_idx, min(len(source_loader), len(target_loader)),
                "[Training] Source acc: %.3f%% | Target acc: %.3f%%" % (100.0 * correct_s / total, 100.0 * correct_t / total)
            )

            #######################
            # Record Training log #
            #######################
            source_recorder.update(loss=losses_s.item(), acc=accuracy(out_s.data, y_s.data, (1, 5)),
                                   batch_size=out_s.shape[0], cur_lr=optimizer_s.param_groups[0]['lr'], end=end)

            target_recorder.update(loss=losses_t.item(), acc=accuracy(out_t.data, y_t.data, (1, 5)),
                                   batch_size=out_t.shape[0], cur_lr=optimizer_t.param_groups[0]['lr'], end=end)

        # Test target acc
        test_acc = mask_test(target_net, target_mask_dict, target_test_loader)
        print('\n[Epoch %d] Test Acc: %.3f' % (epoch, test_acc))
        target_recorder.update(loss=None, acc=test_acc, batch_size=0, end=None, is_train=False)

        if best_test_acc < test_acc:
            best_test_acc = test_acc
            if not os.path.isdir('%s/checkpoint' % save_root):
                os.makedirs('%s/checkpoint' % save_root)
            torch.save(source_net.state_dict(), '%s/checkpoint/%s-temp.pth' % (save_root, source_dataset_name))
            torch.save(target_net.state_dict(), '%s/checkpoint/%s-temp.pth' % (save_root, target_dataset_name))
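
This example (and Example #3) calls an accuracy(output, target, topk) helper that is not shown. A plausible minimal implementation, consistent with how it is called here (one percentage per requested k), might be:

import torch

def accuracy(output, target, topk=(1,)):
    # Assumed top-k accuracy helper; returns one percentage per k in topk.
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()                                        # shape: (maxk, batch)
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size).item())
    return res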
Example #3
                layer_idx = layer_info[1]
                layer = get_layer(net, layer_idx)
                layer.weight.grad.data = (layer.calibration *
                                          layer.pre_quantized_grads)
                # layer.weight.grad.data.copy_(layer.calibration * meta_grad_dict[layer_name][1].data)

        # Get refined gradients for the next computation
        optimizee.get_refine_gradient()

        # These gradients should be saved for the next iteration's inference
        if len(meta_grad_dict) != 0:
            update_parameters(net, lr=optimizee.param_groups[0]['lr'])

        recorder.update(loss=losses.data.item(),
                        acc=accuracy(outputs.data, targets.data, (1, 5)),
                        batch_size=outputs.shape[0],
                        cur_lr=optimizee.param_groups[0]['lr'],
                        end=end)

        recorder.print_training_result(batch_idx, len(train_loader))
        end = time.time()

    test_acc = test(net,
                    quantized_type=quantized_type,
                    test_loader=test_loader,
                    dataset_name=dataset_name,
                    n_batches_used=None)
    recorder.update(loss=None,
                    acc=test_acc,
                    batch_size=0,
                    end=None,
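
update_parameters(net, lr=...) is referenced above but not defined in the snippet. Assuming it simply applies the refined gradients as a manual SGD step, a sketch could look like:

import torch

def update_parameters(net, lr):
    # Assumed behavior: apply the (refined) gradients in-place with a fixed lr.
    with torch.no_grad():
        for param in net.parameters():
            if param.grad is not None:
                param.add_(param.grad, alpha=-lr)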
Example #4
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        model.zero_grad()
        # global_step += 1

        # ------
        # Record
        # ------
        preds = logits.data.cpu().numpy()
        preds = np.argmax(preds, axis=1)
        out_label_ids = inputs["labels"].data.cpu().numpy()
        result = glue_compute_metrics(
            task_name, preds, out_label_ids)  # ['acc', 'f1', 'acc_and_f1']
        if recorder is not None:
            recorder.update(losses.item(),
                            acc=[result['acc_and_f1']],
                            batch_size=args.train_batch_size,
                            is_train=True)
            recorder.print_training_result(batch_idx=step,
                                           n_batch=len(train_dataloader))
        else:
            train_loss += losses.item()
            progress_bar(step, len(train_dataloader),
                         "Loss: %.3f" % (train_loss / (step + 1)))

    result = evaluate(task_name, model, eval_dataloader, model_type)
    print(result)
    if recorder is not None:
        recorder.update(acc=result['acc_and_f1'], is_train=False)

    if recorder is not None:
        recorder.close()
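
Several of these examples rely on a progress_bar(current, total, msg) utility common to CIFAR training repositories; its implementation is not shown. A minimal stand-in, assuming that signature:

import sys

def progress_bar(current, total, msg=None):
    # Minimal stand-in for the repo's progress_bar helper (assumed signature).
    bar_len = 30
    filled = int(bar_len * (current + 1) / total)
    line = '[%s%s] %d/%d' % ('=' * filled, '-' * (bar_len - filled), current + 1, total)
    if msg:
        line += ' | ' + msg
    sys.stdout.write('\r' + line)
    if current + 1 == total:
        sys.stdout.write('\n')
    sys.stdout.flush()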
Example #5
class Task:

    def __init__(self, task_name, task_type='prune', optimizer_type='adam',
                 save_root=None, SummaryPath=None, use_cuda=True, **kwargs):

        self.task_name = task_name
        self.task_type = task_type # prune, soft-quantize
        self.model_name, self.dataset_name = task_name.split('-')
        self.ratio = 'sample' if self.dataset_name in ['CIFARS'] else -1

        #######
        # Net #
        #######
        if task_type == 'prune':
            if self.model_name == 'ResNet20':
                if self.dataset_name in ['CIFAR10', 'CIFARS']:
                    self.net = resnet20_cifar()
                elif self.dataset_name == 'STL10':
                    self.net = resnet20_stl()
                else:
                    raise NotImplementedError
            elif self.model_name == 'ResNet32':
                if self.dataset_name in ['CIFAR10', 'CIFARS']:
                    self.net = resnet32_cifar()
                elif self.dataset_name == 'STL10':
                    self.net = resnet32_stl()
                else:
                    raise NotImplementedError
            elif self.model_name == 'ResNet56':
                if self.dataset_name in ['CIFAR10', 'CIFARS']:
                    self.net = resnet56_cifar()
                elif self.dataset_name == 'CIFAR100':
                    self.net = resnet56_cifar(num_classes=100)
                elif self.dataset_name == 'STL10':
                    self.net = resnet56_stl()
                else:
                    raise NotImplementedError
            elif self.model_name == 'ResNet18':
                if self.dataset_name == 'ImageNet':
                    self.net = resnet18()
                else:
                    raise NotImplementedError
            elif self.model_name == 'vgg11':
                self.net = vgg11() if self.dataset_name == 'CIFAR10' else vgg11_stl10()
            else:
                print(self.model_name, self.dataset_name)
                raise NotImplementedError
        elif task_type == 'soft-quantize':
            if self.model_name == 'ResNet20':
                if self.dataset_name in ['CIFAR10', 'CIFARS']:
                    self.net = soft_quantized_resnet20_cifar()
                elif self.dataset_name in ['STL10']:
                    self.net = soft_quantized_resnet20_stl()
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError('Task type not defined.')


        self.meta_opt_flag = True  # True to enable meta learning

        ##############
        # Meta Prune #
        ##############
        self.mask_dict = dict()
        self.meta_grad_dict = dict()
        self.meta_hidden_state_dict = dict()

        ######################
        # Meta Soft Quantize #
        ######################
        self.quantized = 0 # Quantized type
        self.alpha_dict = dict()
        self.alpha_hidden_dict = dict()
        self.sq_rate = 0
        self.s_rate = 0
        self.q_rate = 0

        ##########
        # Record #
        ##########
        self.dataset_type = 'large' if self.dataset_name in ['ImageNet'] else 'small'
        self.SummaryPath = SummaryPath
        self.save_root = save_root

        self.recorder = Recorder(self.SummaryPath, self.dataset_name, self.task_name)

        ####################
        # Load Pre-trained #
        ####################
        self.pretrain_path = '%s/%s-pretrain.pth' %(self.save_root, self.task_name)
        self.net.load_state_dict(torch.load(self.pretrain_path))
        print('Loaded pre-trained model from %s' % self.pretrain_path)

        if use_cuda:
            self.net.cuda()

        # Optimizer for this task
        if optimizer_type in ['Adam', 'adam']:
            self.optimizer = Adam(self.net.parameters(), lr=1e-3)
        else:
            self.optimizer = SGD(self.net.parameters())

        if self.dataset_name == 'ImageNet':
            try:
                self.train_loader = get_lmdb_imagenet('train', 128)
                self.test_loader = get_lmdb_imagenet('test', 100)
            except Exception:  # fall back to the standard loader if LMDB is unavailable
                self.train_loader = get_dataloader(self.dataset_name, 'train', 128)
                self.test_loader = get_dataloader(self.dataset_name, 'test', 100)
        else:
            self.train_loader = get_dataloader(self.dataset_name, 'train', 128, ratio=self.ratio)
            self.test_loader = get_dataloader(self.dataset_name, 'test', 128)

        self.iter_train_loader = yielder(self.train_loader)
        # (Legacy manual counters and per-file log writers removed; Recorder
        # now tracks loss, accuracy, learning rate, and the log files.)

    def train(self):
        self.net.train()

    def eval(self):
        self.net.eval()

    def zero_grad(self):
        self.optimizer.zero_grad()

    def step(self):
        self.optimizer.step()

    def update_record_performance(self, loss, acc, batch_size=0, lr=1e-3, end=None, is_train=True):

        self.recorder.update(loss=loss, acc=acc, batch_size=batch_size, cur_lr=lr, end=end, is_train=is_train)

        # (Legacy per-dataset-type bookkeeping removed; Recorder handles it.)


    def reset_performance(self):
        # (Legacy manual counter resets removed; Recorder tracks these.)
        self.recorder.reset_performance()


    def save(self, save_root):
        torch.save(self.net.state_dict(), '%s/%s-net.pth' %(save_root, self.task_name))

    def get_best_test_acc(self):
        return self.recorder.get_best_test_acc()

    def flush(self, file_list=None):
        for file in file_list or []:
            file.flush()

    def close(self):
        self.recorder.close()

    def adjust_lr(self, adjust_type):
        # (Legacy loss-plateau LR decay removed; Recorder implements the policy.)
        self.recorder.adjust_lr(self.optimizer)
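
A hypothetical instantiation of this Task class. The task name, the paths, and the assumption that yielder() re-yields (inputs, targets) batches indefinitely are illustrative, not confirmed by the snippet:

import torch

task = Task('ResNet20-CIFAR10',
            task_type='prune',
            optimizer_type='adam',
            save_root='./pretrained',          # expects ./pretrained/ResNet20-CIFAR10-pretrain.pth
            SummaryPath='./runs/meta-prune',
            use_cuda=torch.cuda.is_available())

task.train()                                    # put the underlying net in training mode
inputs, targets = next(task.iter_train_loader)  # yielder() is assumed to cycle through batches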
Example #6
        # output = [(trg len - 1) * batch size, output dim]

        losses = criterion(output, trg)

        losses.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        optimizer.step()

        # ------
        # Record
        # ------
        if recoder is not None:
            recoder.update(losses.item(),
                           batch_size=args.batch_size,
                           cur_lr=optimizer.param_groups[0]['lr'])
            recoder.print_training_result(batch_idx, len(train_loader))
        else:
            train_loss += losses.item()
            progress_bar(batch_idx, len(train_loader),
                         "Loss: %.3f" % (train_loss / (batch_idx + 1)))

    # -----
    # Test
    # -----
    eval_loss = evaluate(model, test_loader, criterion)
    if recoder is not None:
        recoder.update(eval_loss, is_train=False)
    print('[%2d] Test loss: %.3f' % (epoch_idx, eval_loss))
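
The evaluate() helper called at the end returns an average test loss but is not shown. A minimal sketch consistent with that call, assuming batches yield (src, trg) and the same seq2seq output/target flattening as the training loop above:

import torch

def evaluate(model, test_loader, criterion):
    # Assumed: the model/criterion follow the same
    # [(trg len - 1) * batch size, output dim] flattening as in training.
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for src, trg in test_loader:
            output = model(src, trg)
            output = output[1:].view(-1, output.shape[-1])
            trg = trg[1:].reshape(-1)
            total_loss += criterion(output, trg).item()
    model.train()
    return total_loss / len(test_loader)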