コード例 #1
0
    def train(self, epoch):
        """Run one optimisation pass over ``self.train_loader``.

        ``epoch`` is accepted for the caller's convenience but is not read
        anywhere in this method.

        Returns:
            tuple: ``(train_loss.average, train_acc.accuracy)`` as exposed by
            the ``MovingAverageMeter`` / ``AccuracyMeter`` instances that are
            fed once per batch.
        """
        self.model.train()

        loss_meter = MovingAverageMeter()
        acc_meter = AccuracyMeter()

        for inputs, targets in self.train_loader:
            inputs, targets = Variable(inputs), Variable(targets)
            if self.use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()

            logits = self.model(inputs)
            batch_loss = F.cross_entropy(logits, targets)

            # Reset gradients before backprop, then apply the update.
            self.optimizer.zero_grad()
            batch_loss.backward()
            self.optimizer.step()

            loss_meter.update(float(batch_loss.data))

            # max(dim=1) yields (values, indices); the indices are the
            # predicted class labels.
            _, predicted = logits.data.max(dim=1)
            hits = predicted.eq(targets.data).cpu().sum()
            acc_meter.update(int(hits), inputs.size(0))

        return loss_meter.average, acc_meter.accuracy
コード例 #2
0
    def validate():
        """Evaluate ``net`` on ``valid_loader`` with autograd disabled.

        Relies on ``net``, ``valid_loader`` and ``device`` from the enclosing
        scope.

        Returns:
            tuple: ``(valid_loss.average, valid_acc.accuracy)`` read from the
            meters after the last batch.
        """
        net.eval()

        loss_meter = AverageMeter()
        acc_meter = AccuracyMeter()
        with torch.no_grad():
            for inputs, targets in valid_loader:
                inputs, targets = inputs.to(device), targets.to(device)

                logits = net(inputs)
                batch_loss = F.cross_entropy(logits, targets)

                # The index of the largest score along dim 1 is the prediction.
                _, predicted = logits.data.max(dim=1)
                hits = int(predicted.eq(targets.data).cpu().sum())

                batch_size = inputs.size(0)
                loss_meter.update(float(batch_loss.data), number=batch_size)
                acc_meter.update(hits, number=batch_size)

        return loss_meter.average, acc_meter.accuracy
コード例 #3
0
    def validate(self):
        """Evaluate ``self.model`` on ``self.valid_loader`` without gradients.

        Returns:
            tuple: ``(valid_loss.average, valid_acc.accuracy)`` read from the
            meters after every batch has been fed to them together with its
            batch size.
        """
        self.model.eval()

        loss_meter = AverageMeter()
        acc_meter = AccuracyMeter()

        with torch.no_grad():
            for inputs, targets in self.valid_loader:
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                batch_size = inputs.size(0)

                logits = self.model(inputs)
                batch_loss = F.cross_entropy(logits, targets)
                loss_meter.update(float(batch_loss.data), batch_size)

                # Predicted label = index of the largest score along dim 1.
                _, predicted = logits.data.max(dim=1)
                hits = predicted.eq(targets.data).cpu().sum()
                acc_meter.update(int(hits), batch_size)

        return loss_meter.average, acc_meter.accuracy
コード例 #4
0
    def train():
        """Run one optimisation pass over ``train_loader``.

        Relies on ``net``, ``train_loader``, ``device`` and ``optimizer`` from
        the enclosing scope.

        Returns:
            tuple: ``(train_loss.average, train_acc.accuracy)`` read from the
            meters after the last batch.
        """
        net.train()

        loss_meter = AverageMeter()
        acc_meter = AccuracyMeter()
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            logits = net(inputs)
            batch_loss = F.cross_entropy(logits, targets)

            # Reset gradients before backprop, then apply the update.
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            # The index of the largest score along dim 1 is the prediction.
            _, predicted = logits.data.max(dim=1)
            hits = int(predicted.eq(targets.data).cpu().sum())

            batch_size = inputs.size(0)
            loss_meter.update(float(batch_loss.data), number=batch_size)
            acc_meter.update(hits, number=batch_size)

        return loss_meter.average, acc_meter.accuracy
コード例 #5
0
    def validate(self):
        """Evaluate ``self.model`` on ``self.valid_loader`` without gradients.

        The loop runs under ``torch.no_grad()``. The previous
        ``Variable(x, volatile=True)`` idiom is ignored (with a warning) on
        PyTorch >= 0.4, which meant an autograd graph was built for every
        validation batch.

        Returns:
            tuple: ``(valid_loss.average, valid_acc.accuracy)`` read from the
            meters after every batch has been fed to them together with its
            batch size.
        """
        # Local import: a module-level ``import torch`` is not visible from
        # this method, only ``Variable`` and ``F`` are known to be in scope.
        import torch

        self.model.eval()

        valid_loss = AverageMeter()
        valid_acc = AccuracyMeter()

        with torch.no_grad():
            # Tensors no longer need a ``Variable`` wrapper on the PyTorch
            # versions that provide ``torch.no_grad``.
            for x, y in self.valid_loader:
                if self.use_cuda:
                    x = x.cuda()
                    y = y.cuda()

                output = self.model(x)
                loss = F.cross_entropy(output, y)

                valid_loss.update(float(loss.data), x.size(0))

                # Predicted label = index of the largest score along dim 1.
                y_pred = output.data.max(dim=1)[1]
                correct = int(y_pred.eq(y.data).cpu().sum())
                valid_acc.update(correct, x.size(0))

        return valid_loss.average, valid_acc.accuracy
コード例 #6
0
ファイル: main.py プロジェクト: ravinkohli/resnet
def model_train(model, config, criterion, trainloader, testloader, validloader,
                model_name):
    """Train ``model`` with SGD (optionally wrapped in SWA) and cosine LR decay.

    Per epoch the function trains via ``train``, steps the scheduler,
    optionally updates the SWA average, and evaluates via ``infer``. After the
    last epoch it evaluates once more, saves the state dict, writes the meter
    plots and an LR-schedule plot, and returns a summary.

    Args:
        model: network whose parameters are optimised and whose state dict is
            saved.
        config: mapping read for ``budget``, ``base_lr``, ``weight_decay``,
            ``momentum``, ``swa``, ``grad_clip`` and ``prefetch``;
            ``swa_start`` and ``swa_step`` are read only when ``swa`` is
            truthy.
        criterion: loss object forwarded to ``train`` and ``infer``.
        trainloader: loader forwarded to ``train`` and, with SWA enabled, to
            ``bn_update``.
        testloader: loader used for the per-epoch evaluation and for the final
            evaluation.
        validloader: currently unused.
            NOTE(review): the per-epoch "valid" metrics are computed on
            ``testloader``; confirm whether ``validloader`` was intended.
        model_name: forwarded to ``train`` / ``infer``.

    Returns:
        tuple: ``(return_dict, model)``. ``return_dict`` holds ``test_acc``,
        ``save_model_str``, ``training_time_per_step``, ``total_train_time``,
        ``total_time``, ``device_used``, ``train_acc`` and, if the per-epoch
        accuracy ever reached 94, ``time_to_94``.
    """
    num_epochs = config['budget']
    success = False
    time_to_94 = None

    lrs = []
    logging.info(f"weight decay:\t{config['weight_decay']}")
    logging.info(f"momentum :\t{config['momentum']}")

    base_optimizer = optim.SGD(model.parameters(),
                               lr=config['base_lr'],
                               weight_decay=config['weight_decay'],
                               momentum=config['momentum'])
    if config['swa']:
        optimizer = torchcontrib.optim.SWA(base_optimizer)
    else:
        optimizer = base_optimizer

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, num_epochs)

    # One timestamped run directory under ./models with a summary/ subfolder;
    # makedirs creates any missing parents and tolerates existing ones.
    save_model_str = f'./models/model_({datetime.datetime.now()})'
    summary_dir = f'{save_model_str}/summary'
    os.makedirs(summary_dir, exist_ok=True)

    started_at = datetime.datetime.now()
    train_meter = AccuracyMeter(model_dir=summary_dir, name='train')
    test_meter = AccuracyMeter(model_dir=summary_dir, name='test')
    valid_meter = AccuracyMeter(model_dir=summary_dir, name='valid')

    for epoch in range(num_epochs):
        # Read the LR the optimizer will actually use this epoch. Calling
        # ``lr_scheduler.get_lr()`` outside ``step()`` is discouraged on
        # recent PyTorch and may not return the value currently in effect.
        lr = optimizer.param_groups[0]['lr']
        lrs.append(lr)

        logging.info('epoch %d, lr %e', epoch, lr)

        train_acc, train_obj, train_time = train(trainloader, model,
                                                 criterion, optimizer,
                                                 model_name,
                                                 config['grad_clip'],
                                                 config['prefetch'])

        train_meter.update({
            'acc': train_acc,
            'loss': train_obj
        }, train_time.total_seconds())
        lr_scheduler.step()

        # From epoch ``swa_start`` onwards, fold the weights into the SWA
        # average every ``swa_step`` epochs.
        if config['swa'] and ((epoch + 1) >= config['swa_start']) and (
            (epoch + 1 - config['swa_start']) % config['swa_step'] == 0):
            optimizer.update_swa()

        valid_acc, valid_obj, valid_time = infer(testloader,
                                                 model,
                                                 criterion,
                                                 name=model_name,
                                                 prefetch=config['prefetch'])
        valid_meter.update({
            'acc': valid_acc,
            'loss': valid_obj
        }, valid_time.total_seconds())

        # Record only the FIRST epoch that reaches the target; previously the
        # value was overwritten on every later epoch that was still >= 94.
        if valid_acc >= 94 and not success:
            success = True
            time_to_94 = train_meter.time
            logging.info(f'Time to reach 94% {time_to_94}')

    elapsed = datetime.datetime.now() - started_at
    if config['swa']:
        # Switch the model to the averaged weights and refresh batch-norm
        # statistics before the final evaluation.
        optimizer.swap_swa_sgd()
        optimizer.bn_update(trainloader, model)
    test_acc, test_obj, test_time = infer(testloader,
                                          model,
                                          criterion,
                                          name=model_name,
                                          prefetch=config['prefetch'])
    test_meter.update({
        'acc': test_acc,
        'loss': test_obj
    }, test_time.total_seconds())
    torch.save(model.state_dict(), f'{save_model_str}/state')
    train_meter.plot(save_model_str)
    valid_meter.plot(save_model_str)

    plt.plot(lrs)
    plt.title('LR vs epochs')
    plt.xlabel('Epochs')
    plt.ylabel('LR')
    plt.xticks(np.arange(0, num_epochs, 5))
    plt.savefig(f'{save_model_str}/lr_schedule.png')
    plt.close()

    device = get('device')
    if device.type == 'cpu':
        cpu_info = cpuinfo.get_cpu_info()
        # Newer py-cpuinfo releases expose the CPU name as 'brand_raw';
        # fall back to the older 'brand' key so both keep working.
        device_name = cpu_info.get('brand_raw', cpu_info.get('brand'))
    else:
        device_name = torch.cuda.get_device_name(0)
    total_time = round(elapsed.total_seconds(), 2)
    logging.info(
        f'test_acc: {test_acc}, save_model_str:{save_model_str}, total time :{total_time} and device used {device_name}'
    )
    _, cnt, total_train_time = train_meter.get()
    time_per_step = round(total_train_time / cnt, 2)
    return_dict = {
        'test_acc': test_acc,
        'save_model_str': save_model_str,
        'training_time_per_step': time_per_step,
        'total_train_time': total_train_time,
        'total_time': total_time,
        'device_used': device_name,
        'train_acc': train_acc
    }
    if success:
        return_dict['time_to_94'] = time_to_94
    return return_dict, model