Beispiel #1
0
    train_loader = DataLoader(get_dataset(args.dataset, "train"),
                              shuffle=True,
                              batch_size=args.batch_size,
                              num_workers=args.num_workers,
                              pin_memory=False)

    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=1e-4,
                          nesterov=True)
    annealer = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.num_epochs)

    loss_meter = meter.AverageValueMeter()
    time_meter = meter.TimeMeter(unit=False)

    noise = parse_noise_from_args(args, device=args.device, dim=get_dim(args.dataset))

    train_losses = []

    for epoch in range(args.num_epochs):

        for i, (x, y) in enumerate(train_loader):

            x, y = x.to(args.device), y.to(args.device)

            if args.adversarial:
                eps = min(args.eps, epoch * args.eps / (args.num_epochs // 2))
                x, loss = pgd_attack_smooth(model, x, y, eps, noise, sample_size=4, adv=args.adv)
            elif args.stability:
def main():
    """Train the model described by the module-level ``opt`` with an Engine.

    Builds the model through ``models.__dict__[opt.model]`` (passing ``N``/``J``
    when ``opt.scat > 0``), optionally resumes from ``opt.resume`` (either a
    ``.pt7`` checkpoint file or a directory holding ``latest.pt7``), then trains
    for ``opt.epochs`` epochs. At the end of an epoch whose iteration counter is
    a positive multiple of ``opt.frequency_test`` the test split is evaluated
    and a JSON line is appended to ``<opt.save>/log.txt``; checkpoints are
    written every ``opt.frequency_save`` epochs.

    Relies on module-level names: ``opt``, ``models``, ``create_dataset``,
    ``Scattering``, ``Engine``, ``meter`` and ``epoch_step``.
    """
    if not os.path.exists(opt.save):
        os.mkdir(opt.save)

    if opt.scat > 0:
        model, params, stats = models.__dict__[opt.model](N=opt.N, J=opt.scat)
    else:
        model, params, stats = models.__dict__[opt.model]()

    def create_optimizer(opt, lr):
        # A brand-new SGD instance is returned, so any momentum state held by a
        # previous optimizer is discarded.
        print('creating optimizer with lr = %f' % lr)
        return torch.optim.SGD(params.values(),
                               lr,
                               opt.momentum,
                               weight_decay=opt.weightDecay)

    def get_iterator(mode):
        # `mode` is True for the training split, which is also the only
        # shuffled one.
        ds = create_dataset(opt, mode)
        return ds.parallel(batch_size=opt.batchSize,
                           shuffle=mode,
                           num_workers=opt.nthread,
                           pin_memory=False)

    optimizer = create_optimizer(opt, opt.lr)

    iter_test = get_iterator(False)
    iter_train = get_iterator(True)

    if opt.scat > 0:
        scat = Scattering(M=opt.N, N=opt.N, J=opt.scat, pre_pad=False).cuda()

    epoch = 0
    if opt.resume != '':
        resumeFile = opt.resume
        if not resumeFile.endswith('pt7'):
            # A directory was given: follow its pointer file to the newest checkpoint.
            resumeFile = torch.load(os.path.join(opt.resume, 'latest.pt7'))['latest_file']
        # Restore in both cases; previously an explicit .pt7 path was ignored
        # because this code was nested under the directory branch.
        state_dict = torch.load(resumeFile)
        epoch = state_dict['epoch']
        params_tensors, stats = state_dict['params'], state_dict['stats']
        for k, v in params.items():
            v.data.copy_(params_tensors[k])
        optimizer.load_state_dict(state_dict['optimizer'])
        print('model was restored from epoch:', epoch)

    print('\nParameters:')
    print(
        pd.DataFrame([(key, v.size(), torch.typename(v.data))
                      for key, v in params.items()]))
    print('\nAdditional buffers:')
    print(
        pd.DataFrame([(key, v.size(), torch.typename(v))
                      for key, v in stats.items()]))
    n_parameters = sum(
        [p.numel() for p in list(params.values()) + list(stats.values())])
    print('\nTotal number of parameters: %f' % n_parameters)

    meter_loss = meter.AverageValueMeter()
    # accuracy=False: value() reports error rates (log() converts via 100 - value).
    classacc = meter.ClassErrorMeter(topk=[1, 5], accuracy=False)
    timer_data = meter.TimeMeter('s')
    timer_sample = meter.TimeMeter('s')
    timer_train = meter.TimeMeter('s')
    timer_test = meter.TimeMeter('s')

    def h(sample):
        # Forward pass used by the engine; sample[2] is the train flag appended
        # in on_sample.
        inputs = sample[0].cuda()
        if opt.scat > 0:
            inputs = scat(inputs)
        inputs = Variable(inputs)
        targets = Variable(sample[1].cuda().long())
        if sample[2]:
            model.train()
        else:
            model.eval()
        y = torch.nn.parallel.data_parallel(model, inputs,
                                            np.arange(opt.ngpu).tolist())
        return F.cross_entropy(y, targets), y

    def log(t, state):
        # Periodically checkpoint, then append the stats dict `t` (merged with
        # the options) to log.txt and echo it.
        if t['epoch'] > 0 and t['epoch'] % opt.frequency_save == 0:
            ckpt_path = os.path.join(opt.save, 'epoch_%i_model.pt7' % t['epoch'])
            # Pass paths to torch.save: it opens them in binary mode and closes
            # them, unlike the previous open(..., 'w') handles.
            torch.save(
                dict(params={k: v.data.cpu()
                             for k, v in params.items()},
                     stats=stats,
                     optimizer=state['optimizer'].state_dict(),
                     epoch=t['epoch']),
                ckpt_path)
            torch.save(dict(latest_file=ckpt_path),
                       os.path.join(opt.save, 'latest.pt7'))

        z = vars(opt).copy()
        z.update(t)
        logname = os.path.join(opt.save, 'log.txt')
        with open(logname, 'a') as f:
            f.write('json_stats: ' + json.dumps(z) + '\n')
        print(z)

    def on_sample(state):
        global data_time
        data_time = timer_data.value()
        timer_sample.reset()
        state['sample'].append(state['train'])

    def on_forward(state):
        # torchnet's ClassErrorMeter accumulates misclassification counts in
        # sum[k], so both the per-batch and running values are error rates (%).
        prev_sum5 = classacc.sum[5]
        prev_sum1 = classacc.sum[1]
        classacc.add(state['output'].data,
                     torch.LongTensor(state['sample'][1]))
        meter_loss.add(state['loss'].data[0])

        next_sum5 = classacc.sum[5]
        next_sum1 = classacc.sum[1]
        n = state['output'].data.size(0)
        curr_top5 = 100.0 * (next_sum5 - prev_sum5) / n
        curr_top1 = 100.0 * (next_sum1 - prev_sum1) / n
        sample_time = timer_sample.value()
        timer_data.reset()
        if state['train']:
            txt = 'Train:'
        else:
            txt = 'Test'
        if state['t'] % opt.frequency_print == 0 and state['t'] > 0:
            print(
                '%s [%i,%i/%i] ; loss: %.3f (%.3f) ; err5: %.2f (%.2f) ; err1: %.2f (%.2f) ; data %.3f ; time %.3f'
                % (txt, state['epoch'], state['t'] % len(state['iterator']),
                   len(state['iterator']), state['loss'].data[0],
                   meter_loss.value()[0], curr_top5, classacc.value(5),
                   curr_top1, classacc.value(1), data_time, sample_time))

    def on_start(state):
        # Start counting from the (possibly resumed) epoch.
        state['epoch'] = epoch

    def on_start_epoch(state):
        classacc.reset()
        meter_loss.reset()
        timer_train.reset()

        state['iterator'] = iter_train

        next_epoch = state['epoch'] + 1
        if next_epoch in epoch_step:
            print('changing LR')
            lr = state['optimizer'].param_groups[0]['lr']
            state['optimizer'] = create_optimizer(opt, lr * opt.lr_decay_ratio)

    def on_end_epoch(state):
        if state['t'] % opt.frequency_test == 0 and state['t'] > 0:
            # Snapshot training statistics before the test pass resets the meters.
            train_loss = meter_loss.value()
            train_acc = classacc.value()
            train_time = timer_train.value()
            meter_loss.reset()
            classacc.reset()
            timer_test.reset()

            engine.test(h, iter_test)

            log(
                {
                    "train_loss": train_loss[0],
                    "train_acc": 100 - train_acc[0],
                    "test_loss": meter_loss.value()[0],
                    "test_acc": 100 - classacc.value()[0],
                    "epoch": state['epoch'],
                    "n_parameters": n_parameters,
                    "train_time": train_time,
                    "test_time": timer_test.value(),
                }, state)

    engine = Engine()
    engine.hooks['on_sample'] = on_sample
    engine.hooks['on_forward'] = on_forward
    engine.hooks['on_start_epoch'] = on_start_epoch
    engine.hooks['on_end_epoch'] = on_end_epoch
    engine.hooks['on_start'] = on_start
    engine.train(h, iter_train, opt.epochs, optimizer)
Beispiel #3
0
def main():
    """Evaluate a scattering-based classifier on the test split.

    Builds the model from ``models.__dict__[opt.model](N=opt.N, J=opt.scat)``,
    optionally restores weights from ``opt.resume`` (a ``.pt7`` file, or a
    directory whose ``latest.pt7`` points at one), runs a single
    ``engine.test`` pass over the test iterator and prints
    ``classacc.value()`` (top-1 / top-5 error in percent, since the meter is
    built with ``accuracy=False``).

    Relies on module-level names: ``opt``, ``models``, ``get_iterator``,
    ``Scattering``, ``Engine`` and ``meter``.
    """
    model, params, stats = models.__dict__[opt.model](N=opt.N, J=opt.scat)

    iter_test = get_iterator(False, opt)

    # h() only applies the scattering transform when opt.scat > 0, so only
    # build it in that case (consistent with the training script).
    scat = None
    if opt.scat > 0:
        scat = Scattering(M=opt.N, N=opt.N, J=opt.scat, pre_pad=False).cuda()

    epoch = 0
    if opt.resume != '':
        resumeFile = opt.resume
        if not resumeFile.endswith('pt7'):
            # A directory was given: follow its pointer file to the newest checkpoint.
            resumeFile = torch.load(opt.resume + '/latest.pt7')['latest_file']
        state_dict = torch.load(resumeFile)
        # Report the epoch recorded in the checkpoint when there is one; the
        # message previously always printed 0.
        epoch = state_dict.get('epoch', epoch)
        model.load_state_dict(state_dict['state_dict'])
        print('model was restored from epoch:', epoch)

    print('\nParameters:')
    print(
        pd.DataFrame([(key, v.size(), torch.typename(v.data))
                      for key, v in params.items()]))
    print('\nAdditional buffers:')
    print(
        pd.DataFrame([(key, v.size(), torch.typename(v))
                      for key, v in stats.items()]))
    n_parameters = sum(
        [p.numel() for p in list(params.values()) + list(stats.values())])
    print('\nTotal number of parameters: %f' % n_parameters)

    meter_loss = meter.AverageValueMeter()
    classacc = meter.ClassErrorMeter(topk=[1, 5], accuracy=False)
    timer_data = meter.TimeMeter('s')
    timer_sample = meter.TimeMeter('s')
    timer_train = meter.TimeMeter('s')
    timer_test = meter.TimeMeter('s')

    def h(sample):
        # Forward pass used by the engine; sample[2] is the train flag appended
        # in on_sample.
        inputs = sample[0].cuda()
        if opt.scat > 0:
            inputs = scat(inputs)
        inputs = Variable(inputs)
        targets = Variable(sample[1].cuda().long())
        if sample[2]:
            model.train()
        else:
            model.eval()
        y = torch.nn.parallel.data_parallel(model, inputs,
                                            np.arange(opt.ngpu).tolist())
        return F.cross_entropy(y, targets), y

    def on_sample(state):
        global data_time
        data_time = timer_data.value()
        timer_sample.reset()
        state['sample'].append(state['train'])

    def on_forward(state):
        # Per-batch error is derived from the growth of the meter's per-k sums.
        prev_sum5 = classacc.sum[5]
        prev_sum1 = classacc.sum[1]
        classacc.add(state['output'].data,
                     torch.LongTensor(state['sample'][1]))
        meter_loss.add(state['loss'].data[0])

        next_sum5 = classacc.sum[5]
        next_sum1 = classacc.sum[1]
        n = state['output'].data.size(0)
        curr_top5 = 100.0 * (next_sum5 - prev_sum5) / n
        curr_top1 = 100.0 * (next_sum1 - prev_sum1) / n
        sample_time = timer_sample.value()
        timer_data.reset()
        if state['train']:
            txt = 'Train:'
        else:
            txt = 'Test'

        print(
            '%s [%i,%i/%i] ; loss: %.3f (%.3f) ; err5: %.2f (%.2f) ; err1: %.2f (%.2f) ; data %.3f ; time %.3f'
            % (txt, state['epoch'], state['t'] % len(state['iterator']),
               len(state['iterator']), state['loss'].data[0],
               meter_loss.value()[0], curr_top5, classacc.value(5), curr_top1,
               classacc.value(1), data_time, sample_time))

    def on_start(state):
        # Expose the (possibly restored) epoch to the progress print.
        state['epoch'] = epoch

    def on_start_epoch(state):
        classacc.reset()
        meter_loss.reset()
        timer_train.reset()

    def on_end_epoch(state):
        # Only relevant for engine.train(); this script only calls engine.test().
        meter_loss.reset()
        classacc.reset()
        timer_test.reset()

        engine.test(h, iter_test)

    engine = Engine()
    engine.hooks['on_sample'] = on_sample
    engine.hooks['on_forward'] = on_forward
    engine.hooks['on_start_epoch'] = on_start_epoch
    engine.hooks['on_end_epoch'] = on_end_epoch
    engine.hooks['on_start'] = on_start
    engine.test(h, iter_test)
    print(classacc.value())
Beispiel #4
0
def construct_engine(engine_args, num_classes, checkpoint_iter_freq=5000, checkpoint_epoch_freq=1,
                     checkpoint_save_path='checkpoints',
                     iter_log_freq=30, environment='main', lr_points=None):
    """Build an ``Engine`` wired with segmentation training/validation hooks.

    Args:
        engine_args: keyword arguments forwarded to ``Engine``; its
            ``'validate_iterator'`` entry must expose ``.dataset.ignore_lbl``.
        num_classes: number of classes for the confusion matrix.
        checkpoint_iter_freq: save a checkpoint every this many iterations
            (falsy disables).
        checkpoint_epoch_freq: save a checkpoint every this many epochs
            (falsy disables).
        checkpoint_save_path: directory checkpoint files are written into.
        iter_log_freq: plot per-iteration meters every this many iterations.
        environment: Visdom environment name used for every plot.
        lr_points: epochs after which every param group's lr is multiplied by
            0.1. ``None`` (default) means no such epochs.

    Returns:
        The configured ``Engine`` instance.
    """
    # Avoid a shared mutable default; lr_points is only used for membership tests.
    lr_points = () if lr_points is None else lr_points
    engine = Engine(**engine_args)

    # ***************************Meter Setting******************************

    class Meterhelper(object):
        """Pair a torchnet meter with one Visdom plot logger per title key."""

        # titles: dict of {key: plot title} pairs

        def __init__(self, meter, titles, plot_type='line'):
            self.meter = meter

            assert isinstance(titles, dict)
            self.loggers = dict()
            for key in titles:
                self.loggers[key] = VisdomPlotLogger(plot_type, opts={'title': titles[key]}, env=environment)

        def log(self, key, x, y_arg=None):
            assert key in self.loggers
            if y_arg is None:
                y = self.meter.value()
            else:
                y = self.meter.value(y_arg)
            # Meters such as AverageValueMeter return (mean, std); plot the first entry.
            if isinstance(y, tuple):
                y = y[0]
            self.loggers[key].log(x, y)

        def add(self, *arg, **args):
            return self.meter.add(*arg, **args)

        def reset(self):
            return self.meter.reset()

    class SegmentationHelper(Meterhelper):
        """Confusion-matrix meter that plots derived segmentation metrics."""

        def __init__(self):
            super(SegmentationHelper, self).__init__(meter.ConfusionMeter(num_classes),
                                                     dict(miu='Mean IoU', pacc='Pixel Accuracy', macc='Mean Accuracy',
                                                          fwiu='f.w.Iou'))
            self.ignore_lbl = engine_args['validate_iterator'].dataset.ignore_lbl

        def log(self, x):
            confusion_matrix = self.meter.value()
            values = utilities.segmentation_meter.compute_segmentation_meters(confusion_matrix)

            for key in values:
                self.loggers[key].log(x, values[key])

        def add(self, opt, target):
            opt, target = utilities.segmentation_meter.preprocess_for_confusion(opt, target, self.ignore_lbl)
            self.meter.add(opt, target)

    time_meter = meter.TimeMeter(1)

    windowsize = 100
    meters = dict(
        data_loading_meter=Meterhelper(meter.MovingAverageValueMeter(windowsize=windowsize),
                                       dict(data_t='Data Loading Time')),
        gpu_time_meter=Meterhelper(meter.MovingAverageValueMeter(windowsize=windowsize),
                                   dict(gpu_t='Gpu Computing Time')),
        train_loss_meter=Meterhelper(meter.MovingAverageValueMeter(windowsize=windowsize),
                                     dict(train_loss_iteration='Training Loss(Iteration)',
                                          train_loss_epoch='Training Loss(Epoch)')),
        test_loss_meter=Meterhelper(meter.AverageValueMeter(), dict(test_loss='Test Loss')),
        segmentation_meter=SegmentationHelper())

    # ***************************Auxiliaries******************************

    def reset_meters():
        time_meter.reset()
        for key in meters:
            meters[key].reset()

    def prepare_network(state):
        # switch model between training and evaluation mode
        if state['train']:
            state['network'].train()
        else:
            state['network'].eval()

    def wrap_data(state):
        # Move targets to the first listed GPU (inputs are left where they are)
        # and wrap both in Variables; evaluation batches are marked volatile.
        if state['gpu_ids'] is not None:
            # `async` is a reserved word since Python 3.7 (a SyntaxError here);
            # PyTorch renamed the argument to `non_blocking`.
            state['sample'][1] = state['sample'][1].cuda(device=state['gpu_ids'][0], non_blocking=True)

        volatile = not state['train']
        state['sample'][0] = Variable(data=state['sample'][0], volatile=volatile)
        state['sample'][1] = Variable(data=state['sample'][1], volatile=volatile)

    def save_model(state, filename):
        # Save a CPU copy of the weights so the live network stays on its device.
        model = state['network']
        torch.save({'model': copy.deepcopy(model).cpu().state_dict(), 'optimizer': state['optimizer'].state_dict()},
                   filename)
        print('==>Model {} saved.'.format(filename))

    def adjust_learning_rate(state):
        # NOTE(review): the LambdaLR scheduler created in on_start recomputes
        # lr from its base_lrs on every step, so this manual decay is likely
        # overwritten on the next iteration — confirm the intended schedule.
        optimizer = state['optimizer']
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.1

        print('~~~~~~~~~~~~~~~~~~adjust learning rate~~~~~~~~~~~~~~~~~~~~')

    # ***************************Callback Setting******************************

    def on_start(state):
        # wrap network
        if state['gpu_ids'] is None:
            print('Training/Validating without gpus ...')
        else:
            if not torch.cuda.is_available():
                raise RuntimeError('Cuda is not available')

            state['network'].cuda(state['gpu_ids'][0])
            print('Training/Validating on gpu: {}'.format(state['gpu_ids']))

        if state['train']:
            print('*********************Start Training at {}***********************'.format(time.strftime('%c')))
            if state['t'] == 0:
                filename = os.path.join(checkpoint_save_path, 'init_model.pth.tar')
                save_model(state, filename)
            # Guard against an empty iterator to avoid dividing by zero.
            max_iter = max(len(state['train_iterator']) * state['maxepoch'], 1)
            # "Poly" decay factor; clamped at 0 so iterations beyond max_iter
            # never raise a negative base to a fractional power.
            poly_lambda = lambda iteration: max(0.0, 1.0 - float(iteration) / max_iter) ** 0.9
            state['scheduler'] = torch.optim.lr_scheduler.LambdaLR(state['optimizer'], poly_lambda)
        else:
            print('-------------Start Validation at {} For Epoch{}--------------'.format(time.strftime('%c'),
                                                                                         state['epoch']))
        prepare_network(state)
        reset_meters()

    def on_start_epoch(state):
        # change state of the network
        reset_meters()
        print('--------------Start Training at {} for Epoch{}-----------------'.format(time.strftime('%c'),
                                                                                       state['epoch']))
        time_meter.reset()
        prepare_network(state)

    def on_end_sample(state):
        # wrap data; the train flag is appended so it travels with the sample
        state['sample'].append(state['train'])
        wrap_data(state)
        meters['data_loading_meter'].add(time_meter.value())

    def on_start_forward(state):
        # timing
        time_meter.reset()

    def on_end_forward(state):
        # loss meters (and, in training, the per-iteration poly LR step)
        if state['train']:
            meters['train_loss_meter'].add(state['loss'].data[0])
            state['scheduler'].step(state['t'])
        else:
            meters['test_loss_meter'].add(state['loss'].data[0])
            meters['segmentation_meter'].add(state['output'], state['sample'][1])

    def on_end_update(state):
        # logging info and saving model
        meters['gpu_time_meter'].add(time_meter.value())
        if state['t'] % iter_log_freq == 0 and state['t'] != 0:
            meters['data_loading_meter'].log('data_t', x=state['t'])
            meters['gpu_time_meter'].log('gpu_t', x=state['t'])
            meters['train_loss_meter'].log('train_loss_iteration', x=state['t'])

        if checkpoint_iter_freq and state['t'] % checkpoint_iter_freq == 0:
            filename = os.path.join(checkpoint_save_path,
                                    'e' + str(state['epoch']) + 't' + str(state['t']) + '.pth.tar')
            save_model(state, filename)
        time_meter.reset()

    def on_end_epoch(state):
        # logging info and saving model

        meters['train_loss_meter'].log('train_loss_epoch', x=state['epoch'])
        print('***************Epoch {} done: loss {}*****************'.format(state['epoch'],
                                                                              meters['train_loss_meter'].meter.value()))
        if checkpoint_epoch_freq and state['epoch'] % checkpoint_epoch_freq == 0:
            filename = os.path.join(checkpoint_save_path,
                                    'e' + str(state['epoch']) + 't' + str(state['t']) + '.pth.tar')
            save_model(state, filename)

        # adjust learning rate
        if state['epoch'] in lr_points:
            adjust_learning_rate(state)

        reset_meters()

        # do validation at the end of epoch
        state['train'] = False
        engine.validate()
        state['train'] = True

    def on_end_test(state):
        # plot and report validation metrics for this epoch
        meters['test_loss_meter'].log('test_loss', x=state['epoch'])
        meters['segmentation_meter'].log(x=state['epoch'])
        print('----------------Test epoch {} done: loss {}------------------'.format(state['epoch'], meters[
            'test_loss_meter'].meter.value()))
        reset_meters()

    def on_end(state):
        # logging
        t = time.strftime('%c')
        if state['train']:
            print('*********************Training done at {}***********************'.format(t))
        else:
            print('*********************Validation done at {}***********************'.format(t))

    engine.hooks['on_start'] = on_start
    engine.hooks['on_start_epoch'] = on_start_epoch
    engine.hooks['on_end_sample'] = on_end_sample
    engine.hooks['on_start_forward'] = on_start_forward
    engine.hooks['on_end_forward'] = on_end_forward
    engine.hooks['on_end_update'] = on_end_update
    engine.hooks['on_end_epoch'] = on_end_epoch
    engine.hooks['on_end_test'] = on_end_test
    engine.hooks['on_end'] = on_end

    return engine
Beispiel #5
0
def train(args, model, dataloaders, criterion, optimizer, scheduler, logger, epochs=25, is_inception=False):
    """Run a train/val loop with checkpointing, resuming and best-model tracking.

    Args:
        args: parsed command-line arguments (reads ``epochs``, ``pretrained``,
            ``feature``, ``output``, ``arch`` and ``seed``).
        model: the model to train.
        dataloaders: mapping with ``"train"`` and ``"val"`` data loaders.
        criterion: loss function.
        optimizer: optimizer used during training.
        scheduler: learning-rate scheduler, stepped once per training epoch.
        logger: logger receiving one summary line per phase and epoch.
        epochs: number of training epochs; ``args.epochs`` is used when falsy.
        is_inception: whether the model is an Inception network, whose
            training-mode forward pass also returns auxiliary logits.

    The latest checkpoint is written to ``<output>/<arch>/<mode>/model.pt`` and
    copied to ``bestmodel.pt`` whenever the validation top-1 accuracy improves.
    If ``model.pt`` already exists, training resumes from it. Pressing Ctrl+C
    saves a snapshot and returns.
    """
    # Number of training epochs.
    epochs = epochs or args.epochs

    # The training mode selects the output sub-directory.
    if args.pretrained and args.feature:
        mode = "feature_extractor"  # pretrained=True, feature=True
    elif args.pretrained and not args.feature:
        mode = "fine_tuning"  # pretrained=True, feature=False
    else:
        mode = "from_scratch"  # pretrained=False
    run_dir = Path(args.output) / args.arch / mode
    # Make sure the first torch.save does not fail on a missing directory.
    run_dir.mkdir(parents=True, exist_ok=True)
    # Latest checkpoint, and the copy of the best-accuracy checkpoint.
    model_path = run_dir / "model.pt"
    best_modelpath = run_dir / "bestmodel.pt"

    # Resume from an existing checkpoint.
    if model_path.exists():
        state = torch.load(str(model_path))

        epoch = state["epoch"]
        model.load_state_dict(state["model"])
        best_acc = state["best_acc"]

        logger.info("Loading epoch {} checkpoint ...".format(epoch))
        print("Restored model, epoch {}".format(epoch))
    else:
        epoch = 0
        # Accuracy is maximised, so start below any achievable value; the
        # previous float('inf') meant no model could ever become "best".
        best_acc = float('-inf')

    def save(ep):
        # Checkpoint the model for epoch `ep` together with the current best_acc.
        torch.save({
            "model": model.state_dict(),
            "epoch": ep,
            "best_acc": best_acc,
        }, str(model_path))

    # Training metrics.
    running_loss_meter = meter.AverageValueMeter()  # mean loss
    running_acc_meter = meter.ClassErrorMeter(topk=[1], accuracy=True)  # top-1 accuracy
    time_meter = meter.TimeMeter(unit=True)  # measures total training time

    # Results file.
    resultpath = run_dir / "train_result.pkl"
    result_writer = ResultsWriter(resultpath, overwrite=False)

    use_cuda = torch.cuda.is_available()

    for epoch in range(epoch, epochs):
        print("Epoch {}/{}".format(epoch, epochs - 1))
        print("-" * 10)

        # Every epoch has a training phase followed by a validation phase.
        for phase in ["train", "val"]:
            if phase == "train":
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            # Meters are reset separately for each phase.
            running_loss_meter.reset()
            running_acc_meter.reset()

            # NOTE(review): re-seeding every phase makes Python-level random
            # augmentation repeat identically each epoch — confirm intended.
            random.seed(args.seed)
            # DataLoader exposes its dataset as `.dataset` (not `.datasets`).
            tq = tqdm.tqdm(total=len(dataloaders[phase].dataset))
            tq.set_description("{} for Epoch {}/{}".format(phase, epoch + 1, epochs))

            try:
                # Iterate over the data.
                for inputs, labels in dataloaders[phase]:
                    if use_cuda:
                        inputs = inputs.cuda()
                        labels = labels.cuda()

                    # Zero the parameter gradients.
                    optimizer.zero_grad()

                    # Forward pass; track history only in training.
                    with torch.set_grad_enabled(phase == "train"):
                        # Inception returns auxiliary logits only while training.
                        if is_inception and phase == "train":
                            outputs, aux_outputs = model(inputs)
                            loss1 = criterion(outputs, labels)
                            loss2 = criterion(aux_outputs, labels)
                            loss = loss1 + 0.4 * loss2
                        else:
                            outputs = model(inputs)
                            loss = criterion(outputs, labels)

                        # backward + optimize only in the training phase
                        if phase == "train":
                            loss.backward()
                            optimizer.step()

                    tq.update(inputs.size(0))

                    # Per-step meter updates.
                    running_loss_meter.add(loss.item())
                    running_acc_meter.add(F.softmax(outputs.detach(), dim=1), labels.detach())

                # Learning-rate schedule is advanced once per training epoch.
                if phase == "train":
                    scheduler.step()
                    save(epoch + 1)

                tq.close()
                # ClassErrorMeter.value() returns a list (one entry per k);
                # value(1) is the scalar top-1 accuracy.
                epoch_loss = running_loss_meter.value()[0]
                epoch_acc = running_acc_meter.value(1)
                print("{} Loss: {:.4f} Acc: {:.4f}".format(phase, epoch_loss, epoch_acc))

                # Keep a copy of the best model.
                if phase == "val" and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    # Re-save first so the copied checkpoint records the new best_acc.
                    save(epoch + 1)
                    shutil.copy(str(model_path), str(best_modelpath))

                # Per-epoch (not per-step) loss and accuracy go to the log.
                # Read the lr from the optimizer: scheduler.get_lr() is
                # deprecated and can misreport after step().
                current_lr = [group["lr"] for group in optimizer.param_groups]
                logger.info("\n phase: {phase}, epoch: {epoch}, lr: {lr}, loss: {loss}, acc: {acc}".format(
                    phase=phase, epoch=epoch + 1, lr=current_lr,
                    loss=epoch_loss, acc=epoch_acc))

                # Record in the ResultsWriter.
                result_writer.update(epoch, {"phase": phase, "loss": epoch_loss,
                                             "acc": epoch_acc})

            except KeyboardInterrupt:
                tq.close()
                print("Ctrl+C, saving snapshot")
                save(epoch)
                # Stop training instead of continuing with the next phase/epoch.
                return

        print()

    # Total training time (seconds).
    time_elapsed = time_meter.value()
    print("Training complete in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60))
    print("Best val Acc: {:.4f}".format(best_acc))
Beispiel #6
0
def construct_engine(*engine_args, checkpoint_iter_freq=None, checkpoint_epoch_freq=1, checkpoint_save_path='checkpoints',
                     iter_log_freq=100, topk=(1, 5), num_classes=1000,
                     lambda_error=0.7, environment='main', lr_points=(), server='localhost'):
    """Create an ``Engine`` and register training/validation/boosting hooks on it.

    The hooks collect timing, classification-error, loss and per-class AP
    statistics, plot them to Visdom, write checkpoints, decay the learning rate
    at ``lr_points`` and, via ``on_update_distribution``, save a weak learner and
    re-weight the per-class ``state['distribution']``.

    NOTE(review): the hooks read the ``state`` keys 'network', 'train',
    'gpu_ids', 'sample', 'distribution', 'output', 'loss', 't', 'epoch' and
    'optimizer'. Presumably ``Engine`` or its caller populates them; confirm.

    Args:
        *engine_args: Positional arguments forwarded unchanged to ``Engine``.
        checkpoint_iter_freq: If truthy, save a checkpoint whenever ``state['t']``
            is a multiple of it.
        checkpoint_epoch_freq: If truthy, save a checkpoint and plot per-class AP
            every ``checkpoint_epoch_freq`` epochs.
        checkpoint_save_path: Directory for checkpoints and the weak-learner file.
        iter_log_freq: Plot iteration-level metrics every this many iterations.
        topk: The k values tracked by the classification-error meter.
        num_classes: Number of classes. This is the AP meter size and the value
            the updated distribution is rescaled to sum to.
        lambda_error: Factor applied to the weighted AP error before beta is
            computed.
        environment: Visdom environment name.
        lr_points: Epochs at whose end the learning rate is multiplied by 0.1.
        server: Visdom server host. ``http://`` is prepended for the raw
            ``visdom.Visdom`` client.

    Returns:
        The ``Engine`` with all hooks installed.
    """
    engine = Engine(*engine_args)

    # meters
    # unit=False: value() is the raw time since the last reset(). With a truthy
    # unit, torchnet divides by the add() count, which nothing here increments.
    time_meter = meter.TimeMeter(unit=False)
    data_loading_meter = meter.MovingAverageValueMeter(windowsize=100)
    gpu_time_meter = meter.MovingAverageValueMeter(windowsize=100)

    classerr_meter = meter.ClassErrorMeter(topk)
    train_loss_meter = meter.MovingAverageValueMeter(windowsize=100)
    test_loss_meter = meter.AverageValueMeter()
    ap_meter = APMeter(num_classes)

    # loggers associated with meters
    def line_logger(title):
        """Return a Visdom line-plot logger with ``title`` on the configured server/env."""
        return VisdomPlotLogger('line', server=server, opts={'title': title}, env=environment)

    data_loading_logger = line_logger('Data Loading Time')
    gpu_time_logger = line_logger('Gpu Computing Time')
    classerr_meter_iter_loggers = [line_logger('Classification Top {} Error Along Iterations'.format(k))
                                   for k in topk]
    classerr_meter_epoch_loggers = [line_logger('Classification Top {} Error Along Epochs'.format(k))
                                    for k in topk]
    loss_meter_iter_logger = line_logger('Loss in One Iteration')
    loss_meter_epoch_logger = line_logger('Loss with Epoch')
    test_loss_logger = line_logger('test loss')
    test_error_logger = line_logger('test error')
    weighted_error_log = line_logger('weighted test error')
    # Raw Visdom client, used for the per-class AP line plots and the distribution bar plot.
    ap_logger = visdom.Visdom(env=environment, server='http://' + server)

    def unwrap_model(state):
        """Return ``state['network']`` with any ``DataParallel`` wrapper removed."""
        model = state['network']
        if isinstance(model, torch.nn.DataParallel):
            model = model.module
        return model

    def checkpoint_filename(state):
        """Return ``<checkpoint_save_path>/e<epoch>t<t>.pth.tar`` for the current state."""
        return os.path.join(checkpoint_save_path,
                            'e' + str(state['epoch']) + 't' + str(state['t']) + '.pth.tar')

    def weighted_ap_error(pre_distribution):
        """Return ``sum(w * (1 - AP))`` with ``w`` = ``pre_distribution`` normalized to sum 1."""
        return (pre_distribution / pre_distribution.sum() * (1 - ap_meter.value())).sum()

    def plot_ap(distribution, title, win):
        """Plot per-class AP, with classes ordered by ascending ``distribution`` weight."""
        sort_indexes = numpy.argsort(distribution)
        ap_logger.line(X=numpy.linspace(0, num_classes, num=num_classes, endpoint=False),
                       Y=ap_meter.value()[sort_indexes], opts={'title': title}, win=win)

    def prepare_network(state):
        """Switch the network to train or eval mode according to ``state['train']``."""
        if state['train']:
            state['network'].train()
        else:
            state['network'].eval()

    def wrap_data(state):
        """Move the first two sample entries to the first GPU (if any) and wrap them in ``Variable``."""
        sample = state['sample']
        if state['gpu_ids'] is not None:
            device = state['gpu_ids'][0]
            # `async` is a reserved word since Python 3.7. PyTorch renamed the
            # argument to `non_blocking`.
            sample[0] = sample[0].cuda(device=device, non_blocking=False)
            sample[1] = sample[1].cuda(device=device, non_blocking=True)
        # torch.no_grad() only affects operations executed inside it, so wrapping
        # the Variable construction in it (in place of the old `volatile=True`)
        # had no effect. If gradients must be disabled during validation, the
        # forward pass itself has to run under no_grad.
        sample[0] = Variable(data=sample[0])
        sample[1] = Variable(data=sample[1])

    def on_start(state):
        """Move network/distribution to the GPU if requested, save the initial model, reset meters."""
        if state['gpu_ids'] is None:
            print('Training/Validating without gpus ...')
        else:
            if not torch.cuda.is_available():
                raise RuntimeError('Cuda is not available')
            state['network'].cuda(state['gpu_ids'][0])
            state['distribution'] = state['distribution'].cuda(state['gpu_ids'][0])
            print('Training/Validating on gpu: {}'.format(state['gpu_ids']))

        if state['train']:
            print('*********************Start Training at {}***********************'.format(time.strftime('%c')))
            if state['t'] == 0:
                save_model(state, os.path.join(checkpoint_save_path, 'init_model.pth.tar'))
        else:
            print('-------------Start Validation at {} For Epoch{}--------------'.format(time.strftime('%c'),
                                                                                         state['epoch']))
        prepare_network(state)
        reset_meters()

    def on_start_epoch(state):
        """Reset meters and the timer, then set the network mode for the new epoch."""
        reset_meters()
        print('--------------Start Training at {} for Epoch{}-----------------'.format(time.strftime('%c'),
                                                                                       state['epoch']))
        time_meter.reset()
        prepare_network(state)

    def on_end_sample(state):
        """Append the train flag to the sample, move/wrap the data, and record data-loading time."""
        state['sample'].append(state['train'])
        wrap_data(state)
        data_loading_meter.add(time_meter.value())

    def on_start_forward(state):
        # Restart the timer so that on_end_update measures forward+update time.
        time_meter.reset()

    def on_end_forward(state):
        """Feed output/target into the error and AP meters and the loss into the train or test meter."""
        classerr_meter.add(state['output'].data, state['sample'][1].data)
        ap_meter.add(state['output'].data, state['sample'][1].data)
        if state['train']:
            train_loss_meter.add(state['loss'].data.item())
        else:
            test_loss_meter.add(state['loss'].data.item())

    def on_end_update(state):
        """Record compute time, plot iteration metrics periodically, and checkpoint if configured."""
        gpu_time_meter.add(time_meter.value())
        t = state['t']
        if t % iter_log_freq == 0 and t != 0:
            data_loading_logger.log(t, data_loading_meter.value()[0])
            gpu_time_logger.log(t, gpu_time_meter.value()[0])
            loss_meter_iter_logger.log(t, train_loss_meter.value()[0])
            for k, plotter in zip(topk, classerr_meter_iter_loggers):
                plotter.log(t, classerr_meter.value(k))
        if checkpoint_iter_freq and t % checkpoint_iter_freq == 0:
            save_model(state, checkpoint_filename(state))
        time_meter.reset()

    def on_end_epoch(state):
        """Plot/print epoch metrics, checkpoint, decay the LR at ``lr_points``, then run validation."""
        epoch = state['epoch']
        # value(k) looks the error up by k. value()[i] follows the meter's sorted
        # topk order, which would mislabel the plots if topk is not given sorted.
        for k, plotter in zip(topk, classerr_meter_epoch_loggers):
            plotter.log(epoch, classerr_meter.value(k))
        loss_meter_epoch_logger.log(epoch, train_loss_meter.value()[0])
        print('***************Epoch {} done: class error {}, loss {}*****************'.format(epoch,
                                                                                              classerr_meter.value(),
                                                                                              train_loss_meter.value()))
        if checkpoint_epoch_freq and epoch % checkpoint_epoch_freq == 0:
            save_model(state, checkpoint_filename(state))
            plot_ap(state['distribution'].cpu().numpy(), 'AP Change E{}(Training)'.format(epoch),
                    'trainap{}'.format(epoch))
        if epoch in lr_points:
            adjust_learning_rate(state)

        reset_meters()

        # do validation at the end of epoch
        state['train'] = False
        engine.validate()
        state['train'] = True

    def on_end_test(state):
        """Plot/print validation loss, distribution-weighted AP error and class errors, then reset meters."""
        epoch = state['epoch']
        test_loss_logger.log(epoch, test_loss_meter.value()[0])
        pre_distribution = state['distribution'].cpu().numpy()
        weighted_error_log.log(epoch, weighted_ap_error(pre_distribution))
        if checkpoint_epoch_freq and epoch % checkpoint_epoch_freq == 0:
            plot_ap(pre_distribution, 'AP Change E{}(Test)'.format(epoch), 'testap{}'.format(epoch))
        for v in classerr_meter.value():
            test_error_logger.log(epoch, v)
        print('----------------Test epoch {} done: class error {}, loss {}------------------'.format(epoch,
                                                                                                     classerr_meter.value(),
                                                                                                     test_loss_meter.value()))
        reset_meters()

    def on_end(state):
        """Print a completion banner for training or validation."""
        t = time.strftime('%c')
        if state['train']:
            print('*********************Training done at {}***********************'.format(t))
        else:
            print('*********************Validation done at {}***********************'.format(t))

    def on_update_distribution(state):
        """Save the current model as a weak learner and re-weight the class distribution.

        ``error = lambda_error * sum(w * (1 - AP))``, where ``w`` is the current
        distribution normalized to sum 1, and ``beta = error / (1 - error)``.
        Each class weight is multiplied by ``beta ** AP``, and the result is
        rescaled to sum to ``num_classes``. After saving, the meters are reset,
        the network is re-initialized, and the new distribution is stored as a
        float32 tensor.
        """
        save_file_name = 'weak-learner.pth.tar'

        # calculate distribution w.r.t ap
        pre_distribution = state['distribution'].cpu().numpy()
        ap = ap_meter.value()
        error = lambda_error * weighted_ap_error(pre_distribution)
        beta = error / (1 - error)
        distribution = pre_distribution * numpy.power(beta, ap)
        distribution = distribution / distribution.sum() * num_classes
        print('==> Calculating distribution done.')

        ap_logger.bar(X=distribution, opts={'title': 'Distribution'})

        weak_learner = {'beta': beta,
                        'model': unwrap_model(state).state_dict(),
                        'distribution': state['distribution'],
                        'ap': ap,
                        'loss': test_loss_meter.value(),
                        'classerr': classerr_meter.value()}
        torch.save(weak_learner, os.path.join(checkpoint_save_path, save_file_name))
        print('==>Loss: {}'.format(weak_learner['loss']))
        print('==>Class Error: {}'.format(weak_learner['classerr']))
        print('==>Beta: {}'.format(beta))
        print('==>{} saved.'.format(save_file_name))

        reset_meters()
        init_network(state['network'])

        # Keep state['distribution'] a tensor on every path. The other hooks call
        # .cpu()/.cuda() on it, which a bare numpy array does not support.
        new_distribution = torch.from_numpy(distribution.astype(numpy.float32))
        if state['gpu_ids'] is not None:
            new_distribution = new_distribution.cuda(state['gpu_ids'][0])
        state['distribution'] = new_distribution
        state.pop('beta', None)

    def reset_meters():
        """Reset the timer and the error/loss/AP meters. The moving timing averages are kept."""
        time_meter.reset()
        classerr_meter.reset()
        train_loss_meter.reset()
        test_loss_meter.reset()
        ap_meter.reset()

    def save_model(state, filename):
        """Save the unwrapped model ``state_dict`` and the current distribution to ``filename``."""
        torch.save({'model': unwrap_model(state).state_dict(), 'distribution': state['distribution']}, filename)
        print('==>Model {} saved.'.format(filename))

    def adjust_learning_rate(state):
        """Multiply the learning rate of every optimizer param group by 0.1."""
        for param_group in state['optimizer'].param_groups:
            param_group['lr'] *= 0.1
        print('~~~~~~~~~~~~~~~~~~adjust learning rate~~~~~~~~~~~~~~~~~~~~')

    engine.hooks['on_start'] = on_start
    engine.hooks['on_start_epoch'] = on_start_epoch
    engine.hooks['on_end_sample'] = on_end_sample
    engine.hooks['on_start_forward'] = on_start_forward
    engine.hooks['on_end_forward'] = on_end_forward
    engine.hooks['on_end_update'] = on_end_update
    engine.hooks['on_end_epoch'] = on_end_epoch
    engine.hooks['on_end_test'] = on_end_test
    engine.hooks['on_end'] = on_end
    engine.hooks['on_update_distribution'] = on_update_distribution

    return engine