Example No. 1
def evaluate(model, batches):
    model.eval()
    meters = collections.defaultdict(lambda: AverageMeter())
    with torch.no_grad():
        for inputs, targets in batches:
            losses = model.autoenc(inputs, targets)
            for k, v in losses.items():
                meters[k].update(v.item(), inputs.size(1))
    loss = model.loss({k: meter.avg for k, meter in meters.items()})
    meters['loss'].update(loss)
    return meters
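Every example on this page leans on an AverageMeter helper that each project defines for itself. A minimal sketch compatible with the calls used here (update(val, n), .val, .sum, .avg, and the reset()/clear() methods some later examples call) might look like this; the real classes vary slightly from project to project.

class AverageMeter:
    """Track the latest value and a running, count-weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    clear = reset  # some projects on this page call it clear()

    def update(self, val, n=1):
        self.val = val           # most recent value
        self.sum += val * n      # n weights a batch as n samples
        self.count += n
        self.avg = self.sum / self.count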
Example No. 2
def main(args):
    # NTU Dataset
    dataset = NTU(
        root=args.root,
        w=args.width,
        h=args.height,
        t=args.time,
        dataset='train',
        train=True,
        avi_dir=args.avi_dir,
        usual_transform=False,
    )

    # Pytorch dataloader
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size,
                                             num_workers=args.workers,
                                             pin_memory=args.cuda,
                                             collate_fn=my_collate)
    # Loop
    data_time = AverageMeter()
    start_data = time.time()
    for i, dict_input in enumerate(dataloader):
        duration_data = time.time() - start_data
        data_time.update(duration_data)

        # Get the data
        clip = dict_input['clip']          # (B, C, T, 224, 224)
        skeleton = dict_input['skeleton']  # (B, T, 2, 25, 2)
        # Show
        show_one_img(clip[0, :, 0], skeleton[0, 0])

        print("{}/{} : {time.val:.3f} ({time.avg:.3f}) sec/batch".format(
            i + 1, len(dataloader), time=data_time))
        sys.stdout.flush()
        start_data = time.time()
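my_collate is not shown in this project. A common reason for a custom collate_fn with video datasets is to drop samples that failed to decode before batching; a hypothetical sketch along those lines (the real my_collate may differ):

from torch.utils.data.dataloader import default_collate

def my_collate(batch):
    # Drop samples the dataset returned as None, then collate the rest as usual.
    batch = [sample for sample in batch if sample is not None]
    return default_collate(batch)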
Example No. 3
def evaluate(model, batches):
    model.eval()
    meters = collections.defaultdict(lambda: AverageMeter())
    with torch.no_grad():
        for inputs, targets in batches:
            # mu, logvar, z, logits = model(inputs, True)

            # losses = model.loss(logits, targets, mu, logvar)
            losses = model.autoenc(inputs, targets)

            for k, v in losses.items():
                meters[k].update(v.item(), inputs.size(1))

    # losses = {k: meter.avg for k, meter in meters.items()}
    # loss = losses['rec'] + args.lambda_kl*losses['kl']
    loss = model.loss({k: meter.avg for k, meter in meters.items()})
    meters['loss'].update(loss)
    return meters
Example No. 4
    def init_train(self, con_weight: float = 1.0):

        test_img = self.get_test_image()
        meter = AverageMeter("Loss")
        self.writer.flush()
        lr_scheduler = OneCycleLR(self.optimizer_G,
                                  max_lr=0.9999,
                                  steps_per_epoch=len(self.dataloader),
                                  epochs=self.init_train_epoch)

        for g in self.optimizer_G.param_groups:
            g['lr'] = self.init_lr

        for epoch in tqdm(range(self.init_train_epoch)):

            meter.reset()

            for i, (style, smooth, train) in enumerate(self.dataloader, 0):
                # train = transform(test_img).unsqueeze(0)
                self.G.zero_grad(set_to_none=self.grad_set_to_none)
                train = train.to(self.device)

                generator_output = self.G(train)
                # content_loss = loss.reconstruction_loss(generator_output, train) * con_weight
                content_loss = self.loss.content_loss(generator_output,
                                                      train) * con_weight
                # content_loss = F.mse_loss(train, generator_output) * con_weight
                content_loss.backward()
                self.optimizer_G.step()
                lr_scheduler.step()

                meter.update(content_loss.detach())

            self.writer.add_scalar(f"Loss : {self.init_time}",
                                   meter.sum.item(), epoch)
            self.write_weights(epoch + 1, write_D=False)
            self.eval_image(epoch, f'{self.init_time} reconstructed img',
                            test_img)

        for g in self.optimizer_G.param_groups:
            g['lr'] = self.G_lr
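A note on Example 4's scheduler: OneCycleLR drives the learning rate itself and must be stepped once per batch, which the inner loop above does. A self-contained sketch of that wiring (the toy model and hyperparameters are assumptions, not taken from the project):

import torch
from torch.optim.lr_scheduler import OneCycleLR

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
steps_per_epoch, epochs = 100, 5
scheduler = OneCycleLR(optimizer, max_lr=1e-3,
                       steps_per_epoch=steps_per_epoch, epochs=epochs)

for epoch in range(epochs):
    for step in range(steps_per_epoch):
        optimizer.zero_grad()
        loss = model(torch.randn(2, 8)).pow(2).mean()
        loss.backward()
        optimizer.step()
        scheduler.step()  # OneCycleLR steps once per batch, not per epoch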
Example No. 5
def validate(val_loader, model, criterion, epoch, start_time):
    timer = TimeMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    eval_start_time = time.time()

    for i, (input, target) in enumerate(val_loader):
        if args.short_epoch and (i > 10): break
        batch_num = i + 1
        timer.batch_start()
        if args.distributed:
            top1acc, top5acc, loss, batch_total = distributed_predict(
                input, target, model, criterion)
        else:
            with torch.no_grad():
                output = model(input)
                loss = criterion(output, target).data
            batch_total = input.size(0)
            top1acc, top5acc = accuracy(output.data, target, topk=(1, 5))

        # Eval batch done. Logging results
        timer.batch_end()
        losses.update(to_python_float(loss), to_python_float(batch_total))
        top1.update(to_python_float(top1acc), to_python_float(batch_total))
        top5.update(to_python_float(top5acc), to_python_float(batch_total))
        should_print = (batch_num % args.print_freq == 0
                        or batch_num == len(val_loader))
        if args.local_rank == 0 and should_print:
            output = (
                f'Test:  [{epoch}][{batch_num}/{len(val_loader)}]\t'
                f'Time {timer.batch_time.val:.3f} ({timer.batch_time.avg:.3f})\t'
                f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                f'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                f'Acc@5 {top5.val:.3f} ({top5.avg:.3f})')
            log.verbose(output)

    tb.log_eval(top1.avg, top5.avg, time.time() - eval_start_time)
    tb.log('epoch', epoch)

    return top1.avg, top5.avg
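The accuracy helper used here and in most of the examples below is usually the top-k routine from the official PyTorch ImageNet example; a compatible sketch:

import torch

def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each requested k."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)  # indices of the k largest logits
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res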
Example No. 6
    def process(self):
        acc = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        losses = AverageMeter()
        log_file = os.path.join(self.data_folder, 'test.log')
        logger = Logger('test', log_file)
        # switch to evaluate mode
        self.model.eval()

        start_time = time.perf_counter()  # time.clock() was removed in Python 3.8
        print("Begin testing")
        predicted, probs = [], []
        for i, (images, labels) in enumerate(self.test_loader):

            if check_gpu() > 0:
                images = images.cuda(non_blocking=True)  # `async` became a keyword in Python 3.7
                labels = labels.cuda(non_blocking=True)
            images = torch.autograd.Variable(images)
            labels = torch.autograd.Variable(labels)

            if self.tencrop:
                # Due to ten-cropping, input batch is a 5D Tensor
                batch_size, number_of_crops, number_of_channels, height, width = \
                    images.size()

                # Fuse batch size and crops
                images = images.view(-1, number_of_channels, height, width)

                # Compute model output
                output_batch_crops = self.model(images)

                # Average predictions for each set of crops
                output_batch = output_batch_crops.view(batch_size,
                                                       number_of_crops,
                                                       -1).mean(1)
                label_repeated = labels.repeat(10, 1).transpose(
                    1, 0).contiguous().view(-1, 1).squeeze()
                loss = self.criterion(output_batch_crops, label_repeated)
            else:
                output_batch = self.model(images)
                loss = self.criterion(output_batch, labels)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output_batch.data, labels, topk=(1, 5))
            #     print(prec1, prec5)
            losses.update(loss.item(), images.size(0))
            acc.update(prec1.item(), images.size(0))
            top1.update(prec1.item(), images.size(0))
            top5.update(prec5.item(), images.size(0))

            if i % self.print_freq == 0:
                print('TestVal: [{0}/{1}]\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i,
                          len(self.test_loader),
                          loss=losses,
                          top1=top1,
                          top5=top5))

        print(
            ' * Accuracy {acc.avg:.3f}  Acc@5 {top5.avg:.3f} Loss {loss.avg:.3f}'
            .format(acc=acc, top5=top5, loss=losses))

        end_time = time.perf_counter()
        print("Total testing time %.2gs" % (end_time - start_time))
        logger.info("Total testing time %.2gs" % (end_time - start_time))
        logger.info(' * Accuracy {acc.avg:.3f} Loss {loss.avg:.3f}'.format(
            acc=acc, loss=losses))
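The 5D (batch, crops, channels, height, width) input that the tencrop branch expects typically comes from torchvision's TenCrop transform; a sketch of a pipeline that would produce it (the sizes are assumptions):

import torch
from torchvision import transforms

tencrop_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.TenCrop(224),  # 4 corners + center, each with its horizontal flip
    transforms.Lambda(lambda crops: torch.stack(
        [transforms.ToTensor()(crop) for crop in crops])),  # -> (10, C, H, W)
])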
Example No. 7
def train(trn_loader, model, criterion, optimizer, scheduler, epoch):
    net_meter = NetworkMeter()
    timer = TimeMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()
    for i, (input, target) in enumerate(trn_loader):
        if args.short_epoch and (i > 10): break
        batch_num = i + 1
        timer.batch_start()
        scheduler.update_lr(epoch, i + 1, len(trn_loader))

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # compute gradient and do SGD step
        if args.fp16:
            loss = loss * args.loss_scale
            model.zero_grad()
            loss.backward()
            model_grads_to_master_grads(model_params, master_params)
            for param in master_params:
                param.grad.data = param.grad.data / args.loss_scale
            optimizer.step()
            master_params_to_model_params(model_params, master_params)
            loss = loss / args.loss_scale
        else:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Train batch done. Logging results
        timer.batch_end()
        corr1, corr5 = correct(output.data, target, topk=(1, 5))
        reduced_loss = to_python_float(loss.data)
        batch_total = to_python_float(input.size(0))
        # Track the global batch size: machines may receive unequal batches at the end of an epoch.
        if args.distributed:
            metrics = torch.tensor([batch_total, reduced_loss, corr1,
                                    corr5]).float().cuda()
            batch_total, reduced_loss, corr1, corr5 = dist_utils.sum_tensor(
                metrics).cpu().numpy()
            reduced_loss = reduced_loss / dist_utils.env_world_size()
        top1acc = to_python_float(corr1) * (100.0 / batch_total)
        top5acc = to_python_float(corr5) * (100.0 / batch_total)

        losses.update(reduced_loss, batch_total)
        top1.update(top1acc, batch_total)
        top5.update(top5acc, batch_total)

        should_print = (batch_num % args.print_freq == 0
                        or batch_num == len(trn_loader))
        if args.local_rank == 0 and should_print:
            tb.log_memory()
            tb.log_trn_times(timer.batch_time.val, timer.data_time.val,
                             input.size(0))
            tb.log_trn_loss(losses.val, top1.val, top5.val)

            recv_gbit, transmit_gbit = net_meter.update_bandwidth()
            tb.log("sizes/batch_total", batch_total)
            tb.log('net/recv_gbit', recv_gbit)
            tb.log('net/transmit_gbit', transmit_gbit)

            output = (
                f'Epoch: [{epoch}][{batch_num}/{len(trn_loader)}]\t'
                f'Time {timer.batch_time.val:.3f} ({timer.batch_time.avg:.3f})\t'
                f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                f'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                f'Acc@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                f'Data {timer.data_time.val:.3f} ({timer.data_time.avg:.3f})\t'
                f'BW {recv_gbit:.3f} {transmit_gbit:.3f}')
            log.verbose(output)

        tb.update_step_count(batch_total)
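The manual fp16 branch above (static loss scale plus separate master weights) predates torch.cuda.amp. A rough modern equivalent with dynamic loss scaling, sketched on a toy model and assuming a CUDA device:

import torch

model = torch.nn.Linear(16, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

inputs = torch.randn(4, 16, device='cuda')
targets = torch.randint(0, 10, (4,), device='cuda')

optimizer.zero_grad()
with torch.cuda.amp.autocast():   # run the forward pass in mixed precision
    loss = criterion(model(inputs), targets)
scaler.scale(loss).backward()     # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)            # unscales gradients; skips the step on inf/nan
scaler.update()                   # adapts the scale factor for the next iteration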
Example No. 8
    def train(self, logger, epoch):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        acc = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        rate = get_learning_rate(self.optimizer)[0]
        # switch to train mode
        self.model.train()

        end = time.time()
        for i, (images, target) in enumerate(self.train_loader):
            # adjust learning rate scheduler step
            self.scheduler.batch_step()

            # measure data loading time
            data_time.update(time.time() - end)
            if check_gpu() > 0:
                images = images.cuda(non_blocking=True)  # `async` became a keyword in Python 3.7
                target = target.cuda(non_blocking=True)
            image_var = torch.autograd.Variable(images)
            label_var = torch.autograd.Variable(target)

            self.optimizer.zero_grad()

            # compute y_pred
            y_pred = self.model(image_var)
            if self.model_type == 'I3D':
                y_pred = y_pred[0]

            loss = self.criterion(y_pred, label_var)
            # measure accuracy and record loss
            prec1, prec5 = accuracy(y_pred.data, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            acc.update(prec1.item(), images.size(0))
            top1.update(prec1.item(), images.size(0))
            top5.update(prec5.item(), images.size(0))
            # compute gradient and do SGD step

            loss.backward()
            self.optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % self.print_freq == 0:
                print('Epoch: [{0}/{1}][{2}/{3}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Lr {rate:.5f}\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          epoch,
                          self.epochs,
                          i,
                          len(self.train_loader),
                          batch_time=batch_time,
                          data_time=data_time,
                          rate=rate,
                          loss=losses,
                          top1=top1,
                          top5=top5))

        logger.info('Epoch: [{0}/{1}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Lr {rate:.5f}\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                    'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                        epoch,
                        self.epochs,
                        batch_time=batch_time,
                        data_time=data_time,
                        rate=rate,
                        loss=losses,
                        top1=top1,
                        top5=top5))
        return losses, acc
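get_learning_rate is a small project helper; a plausible definition compatible with the [0] indexing above (an assumption, not the project's code):

def get_learning_rate(optimizer):
    # Return the current learning rate of every param group.
    return [group['lr'] for group in optimizer.param_groups]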
Example No. 9
    def validate(self, logger):
        batch_time = AverageMeter()
        losses = AverageMeter()
        acc = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        # switch to evaluate mode
        self.model.eval()

        end = time.time()
        for i, (images, labels) in enumerate(self.val_loader):
            if check_gpu() > 0:
                images = images.cuda(non_blocking=True)  # `async` became a keyword in Python 3.7
                labels = labels.cuda(non_blocking=True)

            image_var = torch.autograd.Variable(images)
            label_var = torch.autograd.Variable(labels)

            # compute y_pred
            y_pred = self.model(image_var)
            if self.model_type == 'I3D':
                y_pred = y_pred[0]

            loss = self.criterion(y_pred, label_var)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(y_pred.data, labels, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            acc.update(prec1.item(), images.size(0))
            top1.update(prec1.item(), images.size(0))
            top5.update(prec5.item(), images.size(0))
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % self.print_freq == 0:
                print('TrainVal: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i,
                          len(self.val_loader),
                          batch_time=batch_time,
                          loss=losses,
                          top1=top1,
                          top5=top5))

        print(' * Accuracy {acc.avg:.3f}  Loss {loss.avg:.3f}'.format(
            acc=acc, loss=losses))
        logger.info(' * Accuracy {acc.avg:.3f}  Loss {loss.avg:.3f}'.format(
            acc=acc, loss=losses))

        return losses, acc
Example No. 10
    def __init__(self,
                 MovementModule='default',
                 EnvModule='default',
                 TextModule='default',
                 WindowHeight=720,
                 WindowWidth=1080,
                 MaxTextNum=15,
                 DataStoragePath='../../../GeneratedData/DataFraction_1',
                 camera_anchor_filepath='./camera_anchors/urbancity.txt',
                 EnvName='',
                 anchor_freq=10,
                 max_emissive=5,
                 FontSize=[8, 16],
                 use_real_img=0.1,
                 is_debug=True,
                 languages=["Latin"],
                 HighResFactor=2.0,
                 UnrealProjectName="./",
                 **kwargs):
        self.client = WrappedClient(UnrealCVClient, DataStoragePath,
                                    HighResFactor, UnrealProjectName)
        self.DataStoragePath = DataStoragePath
        self.UnrealProjectName = UnrealProjectName
        self.WindowHeight = WindowHeight
        self.WindowWidth = WindowWidth
        self.MaxTextNum = MaxTextNum
        self.is_debug = is_debug
        self.camera_anchor_filepath = camera_anchor_filepath
        self.anchor_freq = anchor_freq
        self.HighResFactor = HighResFactor

        self.RootPath = opa(DataStoragePath)
        while os.path.isdir(self.RootPath):
            root_path, count = self.RootPath.rsplit('_', 1)  # split on the last '_' only
            self.RootPath = root_path + '_' + str(int(count) + 1)
        print(f"Data will be saved to: {self.RootPath}")
        self.LabelPath = osp.join(self.RootPath, 'Label.json')
        self.DataLabel = None
        self.ImgFolder = osp.join(self.RootPath, 'imgs')
        self.LabelFolder = osp.join(self.RootPath, 'labels')
        self.WordFolder = osp.join(self.RootPath, 'WordCrops')
        self.DataCount = 0
        self.isConnected = False
        self.SaveFreq = 100

        # step 1
        self._InitializeDataStorage()
        # step 2
        if len(EnvName) > 0:
            StartEngine(EnvName)
        self._ConnectToGame()
        # step 3 set resolution & rotation
        self.client.setres(self.WindowWidth, self.WindowHeight)
        # self.client.setCameraRotation(0, 0, 0)
        self.EnvDepth = kwargs.get('EnvDepth', 100)
        # load modules
        self.Wanderer = CameraSet[MovementModule](
            client=self.client,
            camera_anchor_filepath=self.camera_anchor_filepath,
            anchor_freq=self.anchor_freq)
        self.EnvRenderer = EnvSet[EnvModule](client=self.client)
        self.TextPlacer = TextPlacement[TextModule](
            client=self.client,
            MaxTextCount=self.MaxTextNum,
            ContentPath=osp.join(self.RootPath, 'WordCrops'),
            max_emissive=max_emissive,
            FontSize=FontSize,
            is_debug=is_debug,
            use_real_img=use_real_img,
            languages=languages,
            HighResFactor=HighResFactor)

        # initializer meters
        self.camera_meter = AverageMeter()
        self.env_meter = AverageMeter()
        self.text_meter = AverageMeter()
        self.retrieve_label_meter = AverageMeter()
        self.save_label_meter = AverageMeter()
        self.save_meter = AverageMeter()

        self._cleanup()
Example No. 11
    def validate(self, logger):
        batch_time = AverageMeter()
        losses = AverageMeter()
        acc = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        # switch to evaluate mode
        self.model.eval()

        end = time.time()
        for i, (images, labels) in enumerate(self.val_loader):
            if check_gpu() > 0:
                images = images.cuda(non_blocking=True)  # `async` became a keyword in Python 3.7
                labels = labels.cuda(non_blocking=True)
            images = torch.autograd.Variable(images)
            labels = torch.autograd.Variable(labels)

            if self.tencrop:
                # Due to ten-cropping, input batch is a 5D Tensor
                batch_size, number_of_crops, number_of_channels, height, width = \
                    images.size()

                # Fuse batch size and crops
                images = images.view(-1, number_of_channels, height, width)

                # Compute model output
                output_batch_crops = self.model(images)

                # Average predictions for each set of crops
                output_batch = output_batch_crops.view(batch_size,
                                                       number_of_crops,
                                                       -1).mean(1)
                label_repeated = labels.repeat(10, 1).transpose(
                    1, 0).contiguous().view(-1, 1).squeeze()
                loss = self.criterion(output_batch_crops, label_repeated)
            else:
                output_batch = self.model(images)
                loss = self.criterion(output_batch, labels)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output_batch.data, labels, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            acc.update(prec1.item(), images.size(0))
            top1.update(prec1.item(), images.size(0))
            top5.update(prec5.item(), images.size(0))
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % self.print_freq == 0:
                print('TrainVal: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i,
                          len(self.val_loader),
                          batch_time=batch_time,
                          loss=losses,
                          top1=top1,
                          top5=top5))

        print(' * Accuracy {acc.avg:.3f}  Loss {loss.avg:.3f}'.format(
            acc=acc, loss=losses))
        logger.info(' * Accuracy {acc.avg:.3f}  Loss {loss.avg:.3f}'.format(
            acc=acc, loss=losses))

        return losses, acc
Example No. 12
def train(train_loader, model, criterion, optimizer, epoch, args):
    """
    Train a proposed ResNet for classification
    :param train_loader: default data_loader in pytorch
    :param model: The proposed model, ResNet for our serup
    :param criterion: criterion(loss) for optimization purpose
    :param optimizer: to optimize model, adam or sgd is recommended
    :param epoch: How many turns to train whole training set around
    :param args: arguments for user input
    :return:
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % args.print_freq == 0:
            progress.display(i)
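Example 12 follows the official PyTorch ImageNet recipe, where AverageMeter additionally takes a name and a format spec and implements __str__, and ProgressMeter joins all meters into one log line. A compatible ProgressMeter sketch:

class ProgressMeter:
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    @staticmethod
    def _get_batch_fmtstr(num_batches):
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'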
Example No. 13
def main(args):
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    log_file = os.path.join(args.save_dir, 'log.txt')
    logging(str(args), log_file)

    # Prepare data
    train_sents = load_sent(args.train)
    logging(
        '# train sents {}, tokens {}'.format(len(train_sents),
                                             sum(len(s) for s in train_sents)),
        log_file)
    valid_sents = load_sent(args.valid)
    logging(
        '# valid sents {}, tokens {}'.format(len(valid_sents),
                                             sum(len(s) for s in valid_sents)),
        log_file)
    vocab_file = os.path.join(args.save_dir, 'vocab.txt')

    # if not os.path.isfile(vocab_file):
    #     Vocab.build(train_sents, vocab_file, args.vocab_size)

    Vocab.build(train_sents, vocab_file, args.vocab_size)

    vocab = Vocab(vocab_file)
    logging('# vocab size {}'.format(vocab.size), log_file)

    set_seed(args.seed)
    cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if cuda else 'cpu')
    model = {
        'dae': DAE,
        'vae': VAE,
        'aae': AAE
    }[args.model_type](vocab, args).to(device)
    if args.load_model:
        ckpt = torch.load(args.load_model)
        model.load_state_dict(ckpt['model'])
        model.flatten()
    logging(
        '# model parameters: {}'.format(
            sum(x.data.nelement() for x in model.parameters())), log_file)

    train_batches, _ = get_batches(train_sents, vocab, args.batch_size, device)
    valid_batches, _ = get_batches(valid_sents, vocab, args.batch_size, device)
    best_val_loss = None
    for epoch in range(args.epochs):
        start_time = time.time()
        logging('-' * 80, log_file)
        model.train()
        meters = collections.defaultdict(lambda: AverageMeter())
        indices = list(range(len(train_batches)))
        random.shuffle(indices)
        for i, idx in enumerate(indices):
            inputs, targets = train_batches[idx]
            losses = model.autoenc(inputs, targets, is_train=True)
            losses['loss'] = model.loss(losses)
            model.step(losses)
            for k, v in losses.items():
                meters[k].update(v.item())

            if (i + 1) % args.log_interval == 0:
                log_output = '| epoch {:3d} | {:5d}/{:5d} batches |'.format(
                    epoch + 1, i + 1, len(indices))
                for k, meter in meters.items():
                    log_output += ' {} {:.2f},'.format(k, meter.avg)
                    meter.clear()
                logging(log_output, log_file)

        valid_meters = evaluate(model, valid_batches)
        logging('-' * 80, log_file)
        log_output = '| end of epoch {:3d} | time {:5.0f}s | valid'.format(
            epoch + 1,
            time.time() - start_time)
        for k, meter in valid_meters.items():
            log_output += ' {} {:.2f},'.format(k, meter.avg)
        if not best_val_loss or valid_meters['loss'].avg < best_val_loss:
            log_output += ' | saving model'
            ckpt = {'args': args, 'model': model.state_dict()}
            torch.save(ckpt, os.path.join(args.save_dir, 'model.pt'))
            best_val_loss = valid_meters['loss'].avg
        logging(log_output, log_file)
    logging('Done training', log_file)
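The logging(s, log_file) helper used throughout Example 13 is not shown; in codebases like this one it is usually a two-liner that echoes to stdout and appends to the log file. A sketch under that assumption:

def logging(s, log_file, print_=True):
    if print_:
        print(s)
    with open(log_file, 'a') as f:
        f.write(s + '\n')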
Example No. 14
                        batch_size=opt.batch_size,
                        shuffle=True,
                        pin_memory=opt.cuda,
                        drop_last=True,
                        num_workers=opt.num_workers)

    # setup tracker
    tracker = TrackerSiamFC(name=opt.name,
                            weight=opt.weight,
                            device=opt.device)

    # training loop
    itr = 0
    num_itrs = int((opt.num_epochs * len(loader)) / opt.print_freq) + 1
    loss_logger = Logger(os.path.join(opt.log_dir, 'loss.csv'), num_itrs)
    loss_meter = AverageMeter()
    for epoch in range(opt.num_epochs):
        for step, batch in enumerate(loader):
            loss = tracker.step(batch, backward=True, update_lr=(step == 0))

            itr += 1
            loss_meter.update(loss)
            if itr % opt.print_freq == 0:
                print('Epoch [{}/{}] itr [{}]: Loss: {:.5f}'.format(
                    epoch + 1, opt.num_epochs, itr, loss_meter.avg))
                sys.stdout.flush()

                loss_logger.set(itr / opt.print_freq, loss_meter.avg)
                loss_meter = AverageMeter()

        # save checkpoint
Example No. 15
def train_walk(walk_file,
               w,
               data_batches,
               valid_batches,
               model,
               num_epochs,
               verbose=False):
    # for param in model.parameters():
    #     param.requires_grad = False # freeze the model
    print("START TRAINING:", walk_file)
    opt = optim.SGD([w], lr=0.01)
    # opt = optim.Adam([w], lr=0.01, momentum=0.9)
    start_time = time.perf_counter()
    meter = AverageMeter()

    loss_hist_before = []
    loss_hist_during = []
    for e in range(num_epochs):
        avg_loss_before = average_loss(w, data_batches, model, verbose)
        model.train()
        total_loss = 0
        nsents = 0
        meter.clear()
        indices = list(range(len(data_batches)))
        random.shuffle(indices)
        for i, idx in enumerate(indices):
            opt.zero_grad()
            x, x_edit = data_batches[idx]

            # encode the input x
            mu, logvar = model.encode(x)
            z = reparameterize(mu, logvar)
            # add w to compute new latent
            new_latent = z + alpha * w
            # decode the new latent
            logits, hidden = model.decode(new_latent, x)
            # compute the loss wrt to the edit
            loss = model.loss_rec(logits, x_edit).mean()
            #print("LOSS", idx, ":", loss)

            loss.backward()
            opt.step()
            total_loss += loss.item() * x.shape[1]  # .item() so autograd graphs are not kept all epoch
            nsents += x.shape[1]
            meter.update(loss.item(), x.shape[1])

        print("---------------------------")
        avg_loss_after = average_loss(w, data_batches, model)
        print("FINISHED EPOCH", e)
        print("avg loss before:", avg_loss_before)
        print("avg train loss: ", total_loss / nsents)
        # print("meter loss", meter.avg)
        loss_hist_before.append((e, avg_loss_before.item()))
        loss_hist_during.append((e, meter.avg))  # meter.avg is already a float here
        if verbose:
            print("loss", loss)
            print("nsents", nsents)
        val_loss = average_loss(w, valid_batches, model, False)
        print("avg valid loss: ", val_loss)
        epoch_time = time.perf_counter()
        print("time: ", epoch_time - start_time)
        print("=" * 60)
        #print(torch.cuda.memory_summary(device=None, abbreviated=False))

    print("FINISHED TRAINING")
    best_before_loss = min(loss_hist_before, key=lambda x: x[1])
    best_during_loss = min(loss_hist_during, key=lambda x: x[1])
    print("best_before_loss:", best_before_loss,
          loss_hist_during[best_before_loss[0]])
    print("best_during_loss:", best_during_loss,
          loss_hist_before[best_during_loss[0]])
    plot_series([loss_hist_before, loss_hist_during], walk_file)
    print(w)
    torch.save(w, results_dir + walk_file)
    return w
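reparameterize here is the standard VAE reparameterization trick (alpha and results_dir are globals defined elsewhere in this script); the usual implementation is:

import torch

def reparameterize(mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I); keeps sampling differentiable.
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std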
Example No. 16
def average_loss(w, data_batches, model, verbose=False):
    meter = AverageMeter()
    model.eval()
    with torch.no_grad():
        total_loss = 0
        B = len(data_batches)
        nsents = 0
        for idx in range(len(data_batches)):
            x, x_edit = data_batches[idx]

            mu, logvar = model.encode(x)
            z = reparameterize(mu, logvar)
            new_latent = z + alpha * w
            logits, hidden = model.decode(new_latent, x)
            loss = model.loss_rec(logits, x_edit).mean()

            if verbose:
                # losses = model.autoenc(x, x_edit)
                # print("autoenc", idx, ":", losses['rec'], "shapes", x.shape, x_edit.shape)
                print("my loss", idx, ":", loss)
                print("x", x.shape, "| x_edit", x_edit.shape)
                sents = []
                edited_sents = []
                walk_sents = []
                batch_len = x.shape[1]

                max_len = 35
                dec = 'greedy'
                outputs = model.generate(new_latent, max_len, dec).t()
                for i in range(batch_len):
                    x_i = x[:, i]
                    sents.append([vocab.idx2word[id] for id in x_i])
                    xe_i = x_edit[:, i]
                    edited_sents.append([vocab.idx2word[id] for id in xe_i])
                    output_i = outputs[i]
                    walk_sents.append([vocab.idx2word[id] for id in output_i])

                for i in range(batch_len):
                    x_i = torch.unsqueeze(x[:, i], dim=1)
                    xe_i = torch.unsqueeze(x_edit[:, i], dim=1)
                    loss_i = compute_loss(w, x_i, xe_i, model)
                    print("batch", idx, ":", loss, "| sentence", i, ":",
                          loss_i)
                    print("--SENT:", sents[i])
                    print(x[:, i])
                    print("--EDIT:", edited_sents[i])
                    print(x_edit[:, i])
                    print("--WALK:", walk_sents[i])
                    print(outputs[i])

                if print_outputs_flag:

                    if idx == 4:
                        print("batch", idx, "length", x.shape[1])
                        edited_sents = []
                        walked_sents = []
                        sents = []

                        max_len = 35
                        dec = 'greedy'
                        outputs = model.generate(new_latent, max_len, dec).t()
                        print("outputs", outputs.shape)
                        print("x", x.shape)
                        print("x_edit", x_edit.shape)
                        print("z", z.shape)

                        for i in range(batch_len):
                            output_i = outputs[i]
                            walked_sents.append(
                                [vocab.idx2word[id] for id in output_i])
                            x_i = x[:, i]
                            sents.append([vocab.idx2word[id] for id in x_i])
                            xe_i = x_edit[:, i]
                            edited_sents.append(
                                [vocab.idx2word[id] for id in xe_i])

                        walked_sents = strip_eos(walked_sents)
                        edited_sents = strip_eos(edited_sents)
                        sents = strip_eos(sents)

                        for i in range(batch_len):
                            print(i)
                            print("--SENT:", sents[i])
                            print("--EDIT:", edited_sents[i])
                            print("--WALK:", walked_sents[i])

            total_loss += loss * x.shape[1]
            nsents += x.shape[1]
            #breakpoint()
            meter.update(loss.item(), x.shape[1])

        avg_loss = total_loss / nsents
        if verbose:
            print("avg_loss meter loss vs avg_loss", meter.avg, avg_loss)
        #print("average loss", avg_loss)
        #print("=" * 60)
    return avg_loss
Example No. 17
File: main.py Project: momiji7/vae
    sch_encoder.load_state_dict(checkpoint['sch_encoder'])
    sch_decoder.load_state_dict(checkpoint['sch_decoder'])
    logger.log("=> load-ok checkpoint '{:}' (epoch {:}) done".format(
        logger.last_info(), checkpoint['epoch']))
else:
    logger.log("=> do not find the last-info file : {:}".format(last_info))
    start_epoch = 0

for ep in range(start_epoch, args.epoch):
    sch_encoder.step()
    sch_decoder.step()

    vae_encoder.train()
    vae_decoder.train()

    encoder_losses = AverageMeter()
    decoder_losses = AverageMeter()
    # train
    for ibatch, (img, noise) in enumerate(train_loader):

        img = img.cuda()
        noise = noise.cuda()

        z_mu, z_sig = vae_encoder(img)  # z_mu , z_sig : N*1*Dz
        zl = torch.exp(0.5 * z_sig) * noise + z_mu  # zl           : N*L*Dz
        x_ber = vae_decoder(zl)  # N*L*Dx

        #encoder_loss, decoder_loss = vae_loss(img, x_mu, x_sig, z_mu, z_sig)

        encoder_loss, decoder_loss = vae_loss(img, x_ber, z_mu, z_sig)
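vae_loss is defined elsewhere in this project. Since z_sig is a log-variance here (note the exp(0.5 * z_sig) above) and x_ber holds Bernoulli decoder logits, one plausible split for the two optimizers is reconstruction + KL for the encoder and reconstruction alone for the decoder. A hedged sketch, assuming x and x_ber share a shape; the real vae_loss may weight or split the terms differently:

import torch
import torch.nn.functional as F

def vae_loss(x, x_ber, z_mu, z_logvar):
    # Reconstruction: Bernoulli negative log-likelihood (x_ber holds logits).
    rec = F.binary_cross_entropy_with_logits(x_ber, x, reduction='sum') / x.size(0)
    # KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior.
    kl = -0.5 * torch.sum(1 + z_logvar - z_mu.pow(2) - z_logvar.exp()) / x.size(0)
    return rec + kl, rec  # encoder loss, decoder loss (an assumed split)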
Example No. 18
def validate(val_loader, model, criterion, args):
    """
    Call vaidate() to validate your result. Load validate data accordingly
    :param val_loader: default data loader to load validation data
    :param model: ResNet by default
    :param criterion: the loss to compute
    :param args: User defined input
    :return:
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(val_loader), [batch_time, losses, top1, top5],
                             prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % args.print_freq == 0:
                progress.display(i)

        print(
            ' Validation finished! Avg stats: Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
            .format(top1=top1, top5=top5))

    return top1.avg, top5.avg
Example No. 19
    def valid(cur_train_epoch, phase="valid", extract_features=False):
        '''
        Run evaluation on the "valid" or "test" split.
        :param cur_train_epoch: index of the training epoch just completed
        :param phase: "valid" or "test"
        :return:
        '''
        assert phase in ["valid", "test"]

        results = []
        valid_detail_meters = {
            "loss": SumMeter(),
            "model_loss": SumMeter(),
            "tp": SumMeter(),
            "fn": SumMeter(),
            "fp": SumMeter(),
            "tn": SumMeter(),
            "batch_time": AverageMeter(),
            "io_time": AverageMeter(),
        }

        if phase == "valid":
            logging.info("Valid data.")
            dataset = valid_dataset
        else:
            logging.info("Test data.")
            dataset = test_dataset

        model.eval()
        logging.info("Set network to eval model")

        if extract_features:
            features = np.zeros(shape=(dataset.original_len,
                                       model.n_output_feat),
                                dtype=np.float32)
            features_ctr = 0

        batch_idx = 0

        # chunked here
        chunk_size = 200
        n_chunk = (dataset.original_len +
                   (cfg[phase]["batch_size"] * chunk_size) -
                   1) // (cfg[phase]["batch_size"] * chunk_size)
        n_batch = (dataset.original_len + cfg[phase]["batch_size"] -
                   1) // cfg[phase]["batch_size"]

        for chunk_idx in range(n_chunk):
            s = chunk_idx * cfg[phase]["batch_size"] * chunk_size
            e = (chunk_idx + 1) * cfg[phase]["batch_size"] * chunk_size

            dataloader = DataLoader(
                dataset.slice(s, e),
                batch_size=cfg[phase]["batch_size"],
                shuffle=False,
                num_workers=cfg[phase]["n_worker"],
                collate_fn=dataset.get_collate_func(),
                pin_memory=True,
                drop_last=False,
            )

            batch_time_s = time.time()
            for samples in dataloader:
                batch_idx = batch_idx + 1
                cur_batch = batch_idx
                valid_detail_meters["io_time"].update(time.time() -
                                                      batch_time_s)

                # move to gpu
                samples = to_gpu_variable(samples, volatile=True)

                # forward
                loss, output, model_loss, reg_loss, d = model(samples)

                if phase == "valid":

                    # evaluate metrics
                    valid_detail_meters["loss"].update(
                        loss.data[0] * samples["size"], samples["size"])
                    valid_detail_meters["model_loss"].update(
                        model_loss.data[0] * samples["size"], samples["size"])
                    tp, fp, fn, tn, scores = evaluate(
                        output.data, samples["labels"].data,
                        samples["label_weights"].data)
                    #print(tp,fn,fp,tn)
                    valid_detail_meters["tp"].update(tp, samples["size"])
                    valid_detail_meters["fp"].update(fp, samples["size"])
                    valid_detail_meters["fn"].update(fn, samples["size"])
                    valid_detail_meters["tn"].update(tn, samples["size"])
                    # the larger, the better
                    tp_rate = valid_detail_meters["tp"].sum / (
                        valid_detail_meters["tp"].sum +
                        valid_detail_meters["fn"].sum + 1e-20)
                    # the smaller the better
                    fp_rate = valid_detail_meters["fp"].sum / (
                        valid_detail_meters["fp"].sum +
                        valid_detail_meters["tn"].sum + 1e-20)
                    valid_detail_meters["batch_time"].update(time.time() -
                                                             batch_time_s)
                    batch_time_s = time.time()
                else:
                    scores = torch.sigmoid(output.data)
                    valid_detail_meters["batch_time"].update(time.time() -
                                                             batch_time_s)
                    batch_time_s = time.time()

                # collect results
                uids = samples["uids"]
                aids = samples["aids"]
                results.extend(zip(aids, uids, scores))

                # collect features
                if extract_features:
                    bs = samples["size"]
                    features[features_ctr:features_ctr +
                             bs, :] = d.data.cpu().numpy()
                    features_ctr += bs

                # log results
                if phase == "valid":
                    if cur_batch % cfg["valid"]["logging_freq"] == 0:
                        logging.info(
                            "Valid Batch [{cur_batch}/{ed_batch}] "
                            "loss: {loss} "
                            "model_loss: {model_loss} "
                            "tp: {tp} fn: {fn} fp: {fp} tn: {tn} "
                            "tp_rate: {tp_rate} fp_rate: {fp_rate} "
                            "io time: {io_time}s batch time {batch_time}s".
                            format(
                                cur_batch=cur_batch,
                                ed_batch=n_batch,
                                loss=valid_detail_meters["loss"].mean,
                                model_loss=valid_detail_meters["model_loss"].
                                mean,
                                tp=valid_detail_meters["tp"].sum,
                                fn=valid_detail_meters["fn"].sum,
                                fp=valid_detail_meters["fp"].sum,
                                tn=valid_detail_meters["tn"].sum,
                                tp_rate=tp_rate,
                                fp_rate=fp_rate,
                                io_time=valid_detail_meters["io_time"].mean,
                                batch_time=valid_detail_meters["batch_time"].
                                mean,
                            ))
                else:
                    if cur_batch % cfg["test"]["logging_freq"] == 0:
                        logging.info(
                            "Test Batch [{cur_batch}/{ed_batch}] "
                            "io time: {io_time}s batch time {batch_time}s".
                            format(
                                cur_batch=cur_batch,
                                ed_batch=n_batch,
                                io_time=valid_detail_meters["io_time"].mean,
                                batch_time=valid_detail_meters["batch_time"].
                                mean,
                            ))

        if phase == "valid":
            logging.info("{phase} for {cur_train_epoch} train epoch "
                         "loss: {loss} "
                         "model_loss: {model_loss} "
                         "tp_rate: {tp_rate} fp_rate: {fp_rate} "
                         "io time: {io_time}s batch time {batch_time}s".format(
                             phase=phase,
                             cur_train_epoch=cur_train_epoch,
                             loss=valid_detail_meters["loss"].mean,
                             model_loss=valid_detail_meters["model_loss"].mean,
                             tp_rate=tp_rate,
                             fp_rate=fp_rate,
                             io_time=valid_detail_meters["io_time"].mean,
                             batch_time=valid_detail_meters["batch_time"].mean,
                         ))

            # write results to file
            res_fn = "{}_{}".format(cfg["valid_res_fp"], cur_train_epoch)
            with open(res_fn, 'w') as f:
                f.write("aid,uid,score\n")
                for res in results:
                    f.write("{},{},{:.8f}\n".format(res[0], res[1], res[2]))
            # evaluate results
            avg_auc, aucs = cal_avg_auc(res_fn, cfg["valid_fp"])
            logging.info("Valid for {cur_train_epoch} train epoch "
                         "average auc {avg_auc}".format(
                             cur_train_epoch=cur_train_epoch,
                             avg_auc=avg_auc,
                         ))
            logging.info("aucs: ")
            logging.info(pprint.pformat(aucs))

        else:
            logging.info(
                "Test for {} train epoch ends.".format(cur_train_epoch))

            res_fn = "{}_{}".format(cfg["test_res_fp"], cur_train_epoch)
            with open(res_fn, 'w') as f:
                f.write("aid,uid,score\n")
                for res in results:
                    f.write("{},{},{:.8f}\n".format(res[0], res[1], res[2]))

        # extract features
        if extract_features:
            import pickle as pkl
            with open(cfg["extracted_features_fp"], "wb") as f:
                pkl.dump(features, f, protocol=pkl.HIGHEST_PROTOCOL)
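SumMeter is this project's counterpart to AverageMeter; from its usage above (update(val, n), then reading .sum and .mean) a compatible sketch is:

class SumMeter:
    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.sum += val   # the caller pre-weights val (e.g. loss * batch size)
        self.count += n

    @property
    def mean(self):
        return self.sum / max(self.count, 1)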
Example No. 20
def main():
    # load users, ads and their information of features
    users, ads, u_feat_infos, a_feat_infos = load_users_and_ads(
        cfg["data"]["user_fn"],
        cfg["data"]["ad_fn"],
        cfg["data"]["user_fi_fn"],
        cfg["data"]["ad_fi_fn"],
    )

    r_feat_infos = load_feature_infos(cfg["data"]["r_fi_fp"])

    logging.info("There are {} users.".format(len(users)))
    logging.info("There are {} ads.".format(len(ads)))

    # load data list and history features
    if not args.test:
        train_list = load_data_list(cfg["train_fp"])
        #print("train list len:",len(train_list))
        train_rfeats = load_rfeats(cfg["data"]["train_rfeat_fp"])
        valid_list = load_data_list(cfg["valid_fp"])
        valid_rfeats = load_rfeats(cfg["data"]["valid_rfeat_fp"])
    else:
        test_list = load_data_list(cfg["test_fp"])
        test_rfeats = load_rfeats(cfg["data"]["test_rfeat_fp"])

    filter = cfg["feat"]["filter"]

    # construct mapping and filter
    for fi in u_feat_infos:
        fi.construct_mapping()
    for fi in a_feat_infos:
        fi.construct_mapping()
    for fi in r_feat_infos:
        fi.construct_mapping()

    # filter out low-frequency features.
    for fi in u_feat_infos:
        if fi.name in filter:
            fi.construct_filter(l_freq=filter[fi.name])
        else:
            fi.construct_filter(l_freq=0)
    logging.warning("Users Filtering!!!")

    for fi in a_feat_infos:
        if fi.name in filter:
            fi.construct_filter(l_freq=filter[fi.name])
        else:
            fi.construct_filter(l_freq=0)
    logging.warning("Ads Filtering!!!")

    reg = cfg["reg"]

    if not args.test:
        train_dataset = DatasetYouth(users,
                                     u_feat_infos,
                                     ads,
                                     a_feat_infos,
                                     train_rfeats,
                                     r_feat_infos,
                                     train_list,
                                     cfg["feat"]["u_enc"],
                                     cfg["feat"]["a_enc"],
                                     cfg["feat"]["r_enc"],
                                     reg=reg,
                                     pos_weight=cfg["train"]["pos_weight"],
                                     has_label=True)

        #print("train num: ",train_dataset.original_len)

        if cfg["train"]["use_radio_sampler"]:
            radio_sampler = RadioSampler(train_dataset,
                                         p2n_radio=cfg["train"]["p2n_radio"])
            logging.info("Using radio sampler with p:n={}".format(
                cfg["train"]["p2n_radio"]))

        valid_dataset = DatasetYouth(users,
                                     u_feat_infos,
                                     ads,
                                     a_feat_infos,
                                     valid_rfeats,
                                     r_feat_infos,
                                     valid_list,
                                     cfg["feat"]["u_enc"],
                                     cfg["feat"]["a_enc"],
                                     cfg["feat"]["r_enc"],
                                     reg=reg,
                                     pos_weight=cfg["train"]["pos_weight"],
                                     has_label=True)

        dataset = train_dataset

    else:
        test_dataset = DatasetYouth(users,
                                    u_feat_infos,
                                    ads,
                                    a_feat_infos,
                                    test_rfeats,
                                    r_feat_infos,
                                    test_list,
                                    cfg["feat"]["u_enc"],
                                    cfg["feat"]["a_enc"],
                                    cfg["feat"]["r_enc"],
                                    reg=reg,
                                    has_label=False)

        dataset = test_dataset

    logging.info("shuffle: {}".format(
        False if cfg["train"]["use_radio_sampler"] else True))

    # set up model
    embedding_cfgs = {}
    embedding_cfgs.update(cfg["feat"]["u_embed_cfg"])
    embedding_cfgs.update(cfg["feat"]["a_embed_cfg"])

    loss_cfg = cfg["loss"]

    # create model
    model = eval(cfg["model_name"])(
        n_out=1,
        u_embedding_feat_infos=dataset.embedding_u_feat_infos,
        u_one_hot_feat_infos=dataset.one_hot_u_feat_infos,
        a_embedding_feat_infos=dataset.embedding_a_feat_infos,
        a_one_hot_feat_infos=dataset.one_hot_a_feat_infos,
        r_embedding_feat_infos=dataset.embedding_r_feat_infos,
        embedding_cfgs=embedding_cfgs,
        loss_cfg=loss_cfg,
    )

    # model = DataParallel(model,device_ids=cfg["gpus"])
    # logging.info("Using model {}.".format(cfg["model_name"]))

    # optimizers
    # TODO: tune lr and weight decay
    optimizer = Adam(model.get_train_policy(),
                     lr=cfg["optim"]["lr"],
                     weight_decay=cfg["optim"]["weight_decay"],
                     amsgrad=True)
    #optimizer = optim.SGD(model.parameters(), lr = 0.005, momentum=0.9,weight_decay=cfg["optim"]["weight_decay"])
    logging.info("Using optimizer {}.".format(optimizer))

    if cfg["train"]["resume"] or args.test:
        checkpoint_file = cfg["resume_fp"]
        state = load_checkpoint(checkpoint_file)
        logging.info("Load checkpoint file {}.".format(checkpoint_file))
        st_epoch = state["cur_epoch"] + 1
        logging.info("Start from {}th epoch.".format(st_epoch))
        model.load_state_dict(state["model_state"])
        optimizer.load_state_dict(state["optimizer_state"])
    else:
        st_epoch = 1
    ed_epoch = cfg["train"]["ed_epoch"]

    # move tensor to gpu and wrap tensor with Variable
    to_gpu_variable = dataset.get_to_gpu_variable_func()

    if args.extract_weight:
        model = model.module
        path = os.path.join(cfg["output_path"], "weight")
        os.makedirs(path, exist_ok=True)
        u_embedder = model.u_embedder
        u_embedder.save_weight(path)
        a_embedder = model.a_embedder
        a_embedder.save_weight(path)
        exit(0)

    def evaluate(output, label, label_weights):
        '''
        Compute confusion-matrix counts at a 0.1 score threshold.
        Note: the inputs should be detached tensors (.data), not graph nodes.
        :param output: raw model logits
        :param label: binary ground-truth labels
        :param label_weights: per-sample weights (unused here)
        :return: tp, fp, fn, tn and the sigmoid scores on CPU
        '''
        output = output.view(-1)
        scores = torch.sigmoid(output)
        pred = (scores > 0.1).float()          # hard predictions at a 0.1 threshold
        label = label.view(-1).float()
        tp = (pred * label).sum()              # predicted 1, label 1
        fp = (pred * (1 - label)).sum()        # predicted 1, label 0
        fn = ((1 - pred) * label).sum()        # predicted 0, label 1
        tn = ((1 - pred) * (1 - label)).sum()  # predicted 0, label 0
        return tp, fp, fn, tn, scores.cpu()

    def valid(cur_train_epoch, phase="valid", extract_features=False):
        '''
        Run evaluation on the "valid" or "test" split.
        :param cur_train_epoch: index of the training epoch just completed
        :param phase: "valid" or "test"
        :return:
        '''
        assert phase in ["valid", "test"]

        results = []
        valid_detail_meters = {
            "loss": SumMeter(),
            "model_loss": SumMeter(),
            "tp": SumMeter(),
            "fn": SumMeter(),
            "fp": SumMeter(),
            "tn": SumMeter(),
            "batch_time": AverageMeter(),
            "io_time": AverageMeter(),
        }

        if phase == "valid":
            logging.info("Valid data.")
            dataset = valid_dataset
        else:
            logging.info("Test data.")
            dataset = test_dataset

        model.eval()
        logging.info("Set network to eval model")

        if extract_features:
            features = np.zeros(shape=(dataset.original_len,
                                       model.n_output_feat),
                                dtype=np.float32)
            features_ctr = 0

        batch_idx = 0

        # process in chunks: a fresh DataLoader is created every chunk_size
        # batches to limit memory use
        chunk_size = 200
        n_chunk = (dataset.original_len +
                   (cfg[phase]["batch_size"] * chunk_size) -
                   1) // (cfg[phase]["batch_size"] * chunk_size)
        n_batch = (dataset.original_len + cfg[phase]["batch_size"] -
                   1) // cfg[phase]["batch_size"]

        for chunk_idx in range(n_chunk):
            s = chunk_idx * cfg[phase]["batch_size"] * chunk_size
            e = (chunk_idx + 1) * cfg[phase]["batch_size"] * chunk_size

            dataloader = DataLoader(
                dataset.slice(s, e),
                batch_size=cfg[phase]["batch_size"],
                shuffle=False,
                num_workers=cfg[phase]["n_worker"],
                collate_fn=dataset.get_collate_func(),
                pin_memory=True,
                drop_last=False,
            )

            batch_time_s = time.time()
            for samples in dataloader:
                batch_idx = batch_idx + 1
                cur_batch = batch_idx
                valid_detail_meters["io_time"].update(time.time() -
                                                      batch_time_s)

                # move to gpu
                samples = to_gpu_variable(samples, volatile=True)

                # forward
                loss, output, model_loss, reg_loss, d = model(samples)

                if phase == "valid":

                    # evaluate metrics
                    valid_detail_meters["loss"].update(
                        loss.data[0] * samples["size"], samples["size"])
                    valid_detail_meters["model_loss"].update(
                        model_loss.data[0] * samples["size"], samples["size"])
                    tp, fp, fn, tn, scores = evaluate(
                        output.data, samples["labels"].data,
                        samples["label_weights"].data)
                    #print(tp,fn,fp,tn)
                    valid_detail_meters["tp"].update(tp, samples["size"])
                    valid_detail_meters["fp"].update(fp, samples["size"])
                    valid_detail_meters["fn"].update(fn, samples["size"])
                    valid_detail_meters["tn"].update(tn, samples["size"])
                    # the larger, the better
                    tp_rate = valid_detail_meters["tp"].sum / (
                        valid_detail_meters["tp"].sum +
                        valid_detail_meters["fn"].sum + 1e-20)
                    # the smaller the better
                    fp_rate = valid_detail_meters["fp"].sum / (
                        valid_detail_meters["fp"].sum +
                        valid_detail_meters["tn"].sum + 1e-20)
                    valid_detail_meters["batch_time"].update(time.time() -
                                                             batch_time_s)
                    batch_time_s = time.time()
                else:
                    scores = torch.sigmoid(output.data).cpu()  # match the valid branch
                    valid_detail_meters["batch_time"].update(time.time() -
                                                             batch_time_s)
                    batch_time_s = time.time()

                # collect results
                uids = samples["uids"]
                aids = samples["aids"]
                results.extend(zip(aids, uids, scores))

                # collect features
                if extract_features:
                    bs = samples["size"]
                    features[features_ctr:features_ctr +
                             bs, :] = d.data.cpu().numpy()
                    features_ctr += bs

                # log results
                if phase == "valid":
                    if cur_batch % cfg["valid"]["logging_freq"] == 0:
                        logging.info(
                            "Valid Batch [{cur_batch}/{ed_batch}] "
                            "loss: {loss} "
                            "model_loss: {model_loss} "
                            "tp: {tp} fn: {fn} fp: {fp} tn: {tn} "
                            "tp_rate: {tp_rate} fp_rate: {fp_rate} "
                            "io time: {io_time}s batch time {batch_time}s".
                            format(
                                cur_batch=cur_batch,
                                ed_batch=n_batch,
                                loss=valid_detail_meters["loss"].mean,
                                model_loss=valid_detail_meters["model_loss"].
                                mean,
                                tp=valid_detail_meters["tp"].sum,
                                fn=valid_detail_meters["fn"].sum,
                                fp=valid_detail_meters["fp"].sum,
                                tn=valid_detail_meters["tn"].sum,
                                tp_rate=tp_rate,
                                fp_rate=fp_rate,
                                io_time=valid_detail_meters["io_time"].mean,
                                batch_time=valid_detail_meters["batch_time"].
                                mean,
                            ))
                else:
                    if cur_batch % cfg["test"]["logging_freq"] == 0:
                        logging.info(
                            "Test Batch [{cur_batch}/{ed_batch}] "
                            "io time: {io_time}s batch time {batch_time}s".
                            format(
                                cur_batch=cur_batch,
                                ed_batch=n_batch,
                                io_time=valid_detail_meters["io_time"].mean,
                                batch_time=valid_detail_meters["batch_time"].
                                mean,
                            ))

        if phase == "valid":
            logging.info("{phase} for {cur_train_epoch} train epoch "
                         "loss: {loss} "
                         "model_loss: {model_loss} "
                         "tp_rate: {tp_rate} fp_rate: {fp_rate} "
                         "io time: {io_time}s batch time {batch_time}s".format(
                             phase=phase,
                             cur_train_epoch=cur_train_epoch,
                             loss=valid_detail_meters["loss"].mean,
                             model_loss=valid_detail_meters["model_loss"].mean,
                             tp_rate=tp_rate,
                             fp_rate=fp_rate,
                             io_time=valid_detail_meters["io_time"].mean,
                             batch_time=valid_detail_meters["batch_time"].mean,
                         ))

            # write results to file
            res_fn = "{}_{}".format(cfg["valid_res_fp"], cur_train_epoch)
            with open(res_fn, 'w') as f:
                f.write("aid,uid,score\n")
                for res in results:
                    f.write("{},{},{:.8f}\n".format(res[0], res[1], res[2]))
            # evaluate results
            avg_auc, aucs = cal_avg_auc(res_fn, cfg["valid_fp"])
            logging.info("Valid for {cur_train_epoch} train epoch "
                         "average auc {avg_auc}".format(
                             cur_train_epoch=cur_train_epoch,
                             avg_auc=avg_auc,
                         ))
            logging.info("aucs: ")
            logging.info(pprint.pformat(aucs))

        else:
            logging.info(
                "Test for {} train epoch ends.".format(cur_train_epoch))

            res_fn = "{}_{}".format(cfg["test_res_fp"], cur_train_epoch)
            with open(res_fn, 'w') as f:
                f.write("aid,uid,score\n")
                for res in results:
                    f.write("{},{},{:.8f}\n".format(res[0], res[1], res[2]))

        # extract features
        if extract_features:
            import pickle as pkl
            with open(cfg["extracted_features_fp"], "wb") as f:
                pkl.dump(features, f, protocol=pkl.HIGHEST_PROTOCOL)

    model.cuda()
    logging.info("Move network to gpu.")

    if args.test:
        valid(st_epoch - 1,
              phase="test",
              extract_features=args.extract_features)
        exit(0)
    elif cfg["valid"]["init_valid"]:
        valid(st_epoch - 1)
        model.train()
        logging.info("Set network to train model.")

    # train: main loop

    model.train()
    logging.info("Set network to train model.")

    # original_lambda = cfg["reg"]["lambda"]
    total_n_batch = 0
    warnings.warn("total_n_batch always start at 0...")

    for cur_epoch in range(st_epoch, ed_epoch + 1):

        # meters
        k = cfg["train"]["logging_freq"]
        detail_meters = {
            "loss": RunningValue(k),
            "epoch_loss": SumMeter(),
            "model_loss": RunningValue(k),
            "epoch_model_loss": SumMeter(),
            "tp": RunningValue(k),
            "fn": RunningValue(k),
            "fp": RunningValue(k),
            "tn": RunningValue(k),
            "auc": AUCMeter(),
            "batch_time": AverageMeter(),
            "io_time": AverageMeter(),
        }

        # adjust lr
        adjust_learning_rate(cfg["optim"]["lr"],
                             optimizer,
                             cur_epoch,
                             cfg["train"]["lr_steps"],
                             lr_decay=cfg["train"]["lr_decay"])

        # dynamic adjust cfg["reg"]["lambda"]
        # decay = 1/(0.5 ** (sum(cur_epoch > np.array(cfg["train"]["lr_steps"]))))
        # cfg["reg"]["lambda"] = original_lambda * decay
        # print("using dynamic regularizer, {}".format(cfg["reg"]["lambda"]))

        train_dataset.shuffle()

        batch_idx = -1

        # chunked because of memory issues: we create a new DataLoader every
        # chunk_size batches
        chunk_size = 200
        n_chunk = (train_dataset.original_len +
                   (cfg["train"]["batch_size"] * chunk_size) -
                   1) // (cfg["train"]["batch_size"] * chunk_size)
        n_batch = (train_dataset.original_len + cfg["train"]["batch_size"] -
                   1) // cfg["train"]["batch_size"]

        for chunk_idx in range(n_chunk):
            s = chunk_idx * cfg["train"]["batch_size"] * chunk_size
            e = (chunk_idx + 1) * cfg["train"]["batch_size"] * chunk_size

            train_dataloader = DataLoader(
                train_dataset.slice(s, e),
                batch_size=cfg["train"]["batch_size"],
                shuffle=not cfg["train"]["use_radio_sampler"],
                num_workers=cfg["train"]["n_worker"],
                collate_fn=train_dataset.get_collate_func(),
                sampler=radio_sampler
                if cfg["train"]["use_radio_sampler"] else None,
                pin_memory=True,
                drop_last=True,
            )

            batch_time_s = time.time()
            for samples in train_dataloader:
                total_n_batch += 1
                batch_idx = batch_idx + 1
                detail_meters["io_time"].update(time.time() - batch_time_s)

                # move to gpu
                samples = to_gpu_variable(samples)

                # forward
                loss, output, model_loss, reg_loss, d = model(samples)

                #print("reg_loss",reg_loss)

                # clear grads
                optimizer.zero_grad()

                # backward
                loss.backward()

                # gradient clipping helps a little here
                warnings.warn("Using gradient clipping")
                clip_grad_norm(model.parameters(), max_norm=5)

                # update weights
                optimizer.step()

                # evaluate metrics
                detail_meters["loss"].update(loss.data[0])
                detail_meters["epoch_loss"].update(loss.data[0])
                detail_meters["model_loss"].update(model_loss.data[0])
                detail_meters["epoch_model_loss"].update(model_loss.data[0])
                tp, fp, fn, tn, scores = evaluate(
                    output.data, samples["labels"].data,
                    samples["label_weights"].data)
                #print(tp,fn,fp,tn)
                detail_meters["tp"].update(tp)
                detail_meters["fp"].update(fp)
                detail_meters["fn"].update(fn)
                detail_meters["tn"].update(tn)
                # the larger, the better
                tp_rate = detail_meters["tp"].sum / (
                    detail_meters["tp"].sum + detail_meters["fn"].sum + 1e-20)
                # the smaller the better
                fp_rate = detail_meters["fp"].sum / (
                    detail_meters["fp"].sum + detail_meters["tn"].sum + 1e-20)
                detail_meters["batch_time"].update(time.time() - batch_time_s)

                # collect results
                uids = samples["uids"]
                aids = samples["aids"]
                preds = zip(aids, uids, scores)
                gts = zip(aids, uids, samples["labels"].cpu().data)
                detail_meters["auc"].update(preds, gts)

                batch_time_s = time.time()

                # log results
                if (batch_idx + 1) % cfg["train"]["logging_freq"] == 0:
                    logging.info(
                        "Train Batch [{cur_batch}/{ed_batch}] "
                        "loss: {loss} "
                        "model_loss: {model_loss} "
                        "auc: {auc} "
                        "tp: {tp} fn: {fn} fp: {fp} tn: {tn} "
                        "tp_rate: {tp_rate} fp_rate: {fp_rate} "
                        "io time: {io_time}s batch time {batch_time}s".format(
                            cur_batch=batch_idx + 1,
                            ed_batch=n_batch,
                            loss=detail_meters["loss"].mean,
                            model_loss=detail_meters["model_loss"].mean,
                            tp=detail_meters["tp"].sum,
                            fn=detail_meters["fn"].sum,
                            fp=detail_meters["fp"].sum,
                            tn=detail_meters["tn"].sum,
                            auc=detail_meters["auc"].auc,
                            tp_rate=tp_rate,
                            fp_rate=fp_rate,
                            io_time=detail_meters["io_time"].mean,
                            batch_time=detail_meters["batch_time"].mean,
                        ))
                    detail_meters["auc"].reset()

                if total_n_batch % cfg["train"][
                        "backup_freq_batch"] == 0 and total_n_batch >= cfg[
                            "train"]["start_backup_batch"]:
                    state_to_save = {
                        "cur_epoch": cur_epoch,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                    }
                    checkpoint_file = os.path.join(
                        cfg["output_path"],
                        "epoch_{}_tbatch_{}.checkpoint".format(
                            cur_epoch, total_n_batch))
                    save_checkpoint(state_to_save, checkpoint_file)
                    logging.info(
                        "Save checkpoint to {}.".format(checkpoint_file))

                if total_n_batch % cfg["train"][
                        "valid_freq_batch"] == 0 and total_n_batch >= cfg[
                            "train"]["start_valid_batch"]:
                    valid(cur_epoch)
                    model.train()
                    logging.info("Set network to train model.")

        logging.info("Train Epoch [{cur_epoch}] "
                     "loss: {loss} "
                     "model_loss: {model_loss} ".format(
                         cur_epoch=cur_epoch,
                         loss=detail_meters["epoch_loss"].mean,
                         model_loss=detail_meters["epoch_model_loss"].mean,
                     ))

        # back up
        if cur_epoch % cfg["train"][
                "backup_freq_epoch"] == 0 and cur_epoch >= cfg["train"][
                    "start_backup_epoch"]:
            state_to_save = {
                "cur_epoch": cur_epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            }
            checkpoint_file = os.path.join(
                cfg["output_path"], "epoch_{}.checkpoint".format(cur_epoch))
            save_checkpoint(state_to_save, checkpoint_file)
            logging.info("Save checkpoint to {}.".format(checkpoint_file))

        # valid on valid dataset
        if cur_epoch % cfg["train"][
                "valid_freq_epoch"] == 0 and cur_epoch >= cfg["train"][
                    "start_valid_epoch"]:
            valid(cur_epoch)
            model.train()
            logging.info("Set network to train model.")
Exemplo n.º 21
0
class DataGenerator(object):
    def __init__(self,
                 MovementModule='default',
                 EnvModule='default',
                 TextModule='default',
                 WindowHeight=720,
                 WindowWidth=1080,
                 MaxTextNum=15,
                 DataStoragePath='../../../GeneratedData/DataFraction_1',
                 camera_anchor_filepath='./camera_anchors/urbancity.txt',
                 EnvName='',
                 anchor_freq=10,
                 max_emissive=5,
                 FontSize=[8, 16],
                 use_real_img=0.1,
                 is_debug=True,
                 languages=["Latin"],
                 HighResFactor=2.0,
                 UnrealProjectName="./",
                 **kwargs):
        self.client = WrappedClient(UnrealCVClient, DataStoragePath,
                                    HighResFactor, UnrealProjectName)
        self.DataStoragePath = DataStoragePath
        self.UnrealProjectName = UnrealProjectName
        self.WindowHeight = WindowHeight
        self.WindowWidth = WindowWidth
        self.MaxTextNum = MaxTextNum
        self.is_debug = is_debug
        self.camera_anchor_filepath = camera_anchor_filepath
        self.anchor_freq = anchor_freq
        self.HighResFactor = HighResFactor

        self.RootPath = opa(DataStoragePath)
        while os.path.isdir(self.RootPath):
            # bump the trailing counter; rsplit keeps extra underscores in the path intact
            root_path, count = self.RootPath.rsplit('_', 1)
            self.RootPath = root_path + '_' + str(int(count) + 1)
        print(f"Data will be saved to: {self.RootPath}")
        self.LabelPath = osp.join(self.RootPath, 'Label.json')
        self.DataLabel = None
        self.ImgFolder = osp.join(self.RootPath, 'imgs')
        self.LabelFolder = osp.join(self.RootPath, 'labels')
        self.WordFolder = osp.join(self.RootPath, 'WordCrops')
        self.DataCount = 0
        self.isConnected = False
        self.SaveFreq = 100

        # step 1
        self._InitializeDataStorage()
        # step 2
        if len(EnvName) > 0:
            StartEngine(EnvName)
        self._ConnectToGame()
        # step 3 set resolution & rotation
        self.client.setres(self.WindowWidth, self.WindowHeight)
        # self.client.setCameraRotation(0, 0, 0)
        self.EnvDepth = kwargs.get('EnvDepth', 100)
        # load modules
        self.Wanderer = CameraSet[MovementModule](
            client=self.client,
            camera_anchor_filepath=self.camera_anchor_filepath,
            anchor_freq=self.anchor_freq)
        self.EnvRenderer = EnvSet[EnvModule](client=self.client)
        self.TextPlacer = TextPlacement[TextModule](
            client=self.client,
            MaxTextCount=self.MaxTextNum,
            ContentPath=self.WordFolder,
            max_emissive=max_emissive,
            FontSize=FontSize,
            is_debug=is_debug,
            use_real_img=use_real_img,
            languages=languages,
            HighResFactor=HighResFactor)

        # initializer meters
        self.camera_meter = AverageMeter()
        self.env_meter = AverageMeter()
        self.text_meter = AverageMeter()
        self.retrieve_label_meter = AverageMeter()
        self.save_label_meter = AverageMeter()
        self.save_meter = AverageMeter()

        self._cleanup()

    def __del__(self):
        if self.client.isconnected():
            self.client.QuitGame()
            self.client.disconnect()
            self._cleanup()
            # os.system('~/cache_light.png')

    def _cleanup(self):
        os.system(
            f'rm ../../../PackagedEnvironment/{self.UnrealProjectName}/Demo/Saved/Screenshots/LinuxNoEditor/*png'
        )
        #os.system(f'rm ../../../PackagedEnvironment/{self.UnrealProjectName}/Demo/Saved/Logs/*')

    def _InitializeDataStorage(self):
        os.makedirs(self.ImgFolder, exist_ok=True)
        os.makedirs(self.LabelFolder, exist_ok=True)
        os.makedirs(self.WordFolder, exist_ok=True)
        self.DataCount = 0
        self.DataLabel = []
        os.system(f'cp vis.py {self.RootPath}/')

    def _ConnectToGame(self):
        # connect with exponential backoff (1 s, 2 s, 4 s, ...); give up once
        # the next wait would exceed 120 s
        sleepTime = 1.0
        while True:
            self.client.connect()
            self.isConnected = self.client.isconnected()
            if self.isConnected:
                break
            else:
                if sleepTime > 120:
                    break
                time.sleep(sleepTime)
                sleepTime *= 2
        if not self.isConnected:
            print('Failed to connect to UnrealCV server.')
            sys.exit(-1)

    def _GenerateOneImageInstance(self,
                                  step_count,
                                  force_change_camera_anchor=False):
        # step 1: move around
        time_stamp = time.time()
        if not self.is_debug:
            self.Wanderer.step(
                height=self.WindowHeight,
                width=self.WindowWidth,
                step=step_count,
                force_change_camera_anchor=force_change_camera_anchor)
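        # each meter.update(...) below is assumed to return the current time,
        # so time_stamp is refreshed as a side effect of recording the duration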
        time_stamp = self.camera_meter.update(time.time() - time_stamp)

        # step 2: render env
        self.EnvRenderer.step()
        time_stamp = self.env_meter.update(time.time() - time_stamp)

        # step 3: place text
        self.TextPlacer.PutTextStep()
        time_stamp = self.text_meter.update(time.time() - time_stamp)
        if self.is_debug:
            print('Text Loaded, ready to retrieve data')

        # step 4: retrieve data
        img_path, Texts, WordBoxes, CharBoxes, TextNum = self.TextPlacer.RetrieveDataStep(
            osp.join(self.ImgFolder, f'{self.DataCount}.jpg'))
        time_stamp = self.retrieve_label_meter.update(time.time() - time_stamp)

        force_change_camera_anchor = TextNum == 0
        DataLabel = {
            'imgfile': f'imgs/{self.DataCount}.jpg',
            'bbox': WordBoxes,
            'cbox': CharBoxes,
            'text': Texts,
            'is_difficult': [0 for _ in range(len(WordBoxes))]
        }
        json.dump(
            DataLabel,
            open(osp.join(self.LabelFolder,
                          str(self.DataCount) + '.json'), 'w'))
        time_stamp = self.save_label_meter.update(time.time() - time_stamp)
        if self.is_debug:
            print('Finished waiting, ready to save img')
        self.client.SaveImg(img_path)
        time_stamp = self.save_meter.update(time.time() - time_stamp)
        self.DataCount += 1

        if self.is_debug:
            # step 5: visualize
            ShowImgAndAnnotation(img_path, Texts, WordBoxes, CharBoxes)
        time_stamp = time.time()
        return {'force_change_camera_anchor': force_change_camera_anchor}

    def StartGeneration(self, IterationNum=10000, sleep_time=0, sleep_freq=1):
        status = {'force_change_camera_anchor': False}
        for Count in range(IterationNum):
            status = self._GenerateOneImageInstance(Count, **status)
            if Count % self.anchor_freq == 0:
                print(f"{Count} images created. Timing:")
                print(f' ----- camera:               {self.camera_meter}')
                print(f' ----- env:                  {self.env_meter}')
                print(f' ----- text:                 {self.text_meter}')
                print(
                    f' ----- retrieve label:       {self.retrieve_label_meter}'
                )
                print(f' ----- save label:           {self.save_label_meter}')
                print(f' ----- save image:           {self.save_meter}')
Exemplo n.º 22
0
def train(args):

    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    tfboard_writer = SummaryWriter()
    logname = '{}'.format(datetime.datetime.now().strftime('%Y-%m-%d-%H:%M'))
    logger = Logger(args.save_path, logname)
    logger.log('Arguments : -------------------------------')
    for name, value in args._get_kwargs():
        logger.log('{:16} : {:}'.format(name, value))

    # Data Augmentation
    mean_fill = tuple([int(x * 255) for x in [0.485, 0.456, 0.406]])
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_transform = [
        transforms.AugTransBbox(args.transbbox_prob, args.transbbox_percent)
    ]
    train_transform += [transforms.PreCrop(args.pre_crop_expand)]
    train_transform += [
        transforms.TrainScale2WH((args.crop_width, args.crop_height))
    ]
    #train_transform += [transforms.AugHorizontalFlip(args.flip_prob)]
    #train_transform += [transforms.AugScale(args.scale_prob, args.scale_min, args.scale_max)]
    #train_transform += [transforms.AugCrop(args.crop_width, args.crop_height, args.crop_perturb_max, mean_fill)]
    if args.rotate_max:
        train_transform += [transforms.AugRotate(args.rotate_max)]
    train_transform += [
        transforms.AugGaussianBlur(args.gaussianblur_prob,
                                   args.gaussianblur_kernel_size,
                                   args.gaussianblur_sigma)
    ]
    train_transform += [transforms.ToTensor(), normalize]
    train_transform = transforms.Compose(train_transform)

    eval_transform = transforms.Compose([
        transforms.PreCrop(args.pre_crop_expand),
        transforms.TrainScale2WH((args.crop_width, args.crop_height)),
        transforms.ToTensor(), normalize
    ])

    # Training datasets
    train_data = GeneralDataset(args.num_pts, train_transform,
                                args.train_lists)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    # Evaluation Dataloader
    eval_loaders = []

    for eval_ilist in args.eval_lists:
        eval_idata = GeneralDataset(args.num_pts, eval_transform, eval_ilist)
        eval_iloader = torch.utils.data.DataLoader(eval_idata,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   num_workers=args.workers,
                                                   pin_memory=True)
        eval_loaders.append(eval_iloader)

    net = Model(args.num_pts)

    logger.log("=> network :\n {}".format(net))
    logger.log('arguments : {:}'.format(args))

    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                       net.parameters()),
                                lr=args.LR,
                                momentum=args.momentum,
                                weight_decay=args.decay,
                                nesterov=args.nesterov)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.schedule,
                                                     gamma=args.gamma)

    criterion = wing_loss(args)
    # criterion = torch.nn.MSELoss(reduce=True)

    net = net.cuda()
    criterion = criterion.cuda()
    net = torch.nn.DataParallel(net)

    last_info = logger.last_info()
    if last_info.exists():
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch'] + 1
        checkpoint = torch.load(last_info['last_checkpoint'])
        assert last_info['epoch'] == checkpoint[
            'epoch'], 'Last-Info is not right {:} vs {:}'.format(
                last_info, checkpoint['epoch'])
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        logger.log("=> load-ok checkpoint '{:}' (epoch {:}) done".format(
            logger.last_info(), checkpoint['epoch']))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch = 0

    for epoch in range(start_epoch, args.epochs):
        scheduler.step()
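        # note: since PyTorch 1.1, scheduler.step() belongs after the epoch's
        # optimizer steps; stepping first matches the older API used here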

        net.train()

        # train
        img_prediction = []
        img_target = []
        train_losses = AverageMeter()
        for i, (inputs, target) in enumerate(train_loader):

            target = target.squeeze(1)
            inputs = inputs.cuda()
            target = target.cuda()

            prediction = net(inputs)

            loss = criterion(prediction, target)
            train_losses.update(loss.item(), inputs.size(0))

            prediction = prediction.detach().to(torch.device('cpu')).numpy()
            target = target.detach().to(torch.device('cpu')).numpy()

            for idx in range(inputs.size()[0]):
                img_prediction.append(prediction[idx, :])
                img_target.append(target[idx, :])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % args.print_freq == 0 or i + 1 == len(train_loader):
                logger.log(
                    '[train Info]: [epoch-{}-{}][{:04d}/{:04d}][Loss:{:.2f}]'.
                    format(epoch, args.epochs, i, len(train_loader),
                           loss.item()))

        train_nme = compute_nme(args.num_pts, img_prediction, img_target)
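        # compute_nme presumably averages per-landmark error normalized by a
        # reference distance (e.g., inter-ocular); its definition is not shown here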
        logger.log('epoch {:02d} completed!'.format(epoch))
        logger.log(
            '[train Info]: [epoch-{}-{}][Avg Loss:{:.6f}][NME:{:.2f}]'.format(
                epoch, args.epochs, train_losses.avg, train_nme * 100))
        tfboard_writer.add_scalar('Average Loss', train_losses.avg, epoch)
        tfboard_writer.add_scalar('NME', train_nme * 100,
                                  epoch)  # training-data NME

        # save checkpoint
        filename = 'epoch-{}-{}.pth'.format(epoch, args.epochs)
        save_path = logger.path('model') / filename
        torch.save(
            {
                'epoch': epoch,
                'args': deepcopy(args),
                'state_dict': net.state_dict(),
                'scheduler': scheduler.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, save_path)
        logger.log('save checkpoint into {}'.format(filename))
        # torch.save returns None, so its result is not assigned
        torch.save({
            'epoch': epoch,
            'last_checkpoint': save_path
        }, logger.last_info())

        # eval
        logger.log('Basic-Eval-All evaluates {} dataset'.format(
            len(eval_loaders)))

        for i, loader in enumerate(eval_loaders):

            eval_losses = AverageMeter()
            eval_prediction = []
            eval_target = []
            with torch.no_grad():
                net.eval()
                for i_batch, (inputs, target) in enumerate(loader):

                    target = target.squeeze(1)
                    inputs = inputs.cuda()
                    target = target.cuda()
                    prediction = net(inputs)
                    loss = criterion(prediction, target)
                    eval_losses.update(loss.item(), inputs.size(0))

                    prediction = prediction.detach().to(
                        torch.device('cpu')).numpy()
                    target = target.detach().to(torch.device('cpu')).numpy()

                    for idx in range(inputs.size()[0]):
                        eval_prediction.append(prediction[idx, :])
                        eval_target.append(target[idx, :])
                    if i_batch % args.print_freq == 0 or i_batch + 1 == len(loader):
                        logger.log(
                            '[Eval Info]: [epoch-{}-{}][{:04d}/{:04d}][Loss:{:.2f}]'
                            .format(epoch, args.epochs, i_batch, len(loader),
                                    loss.item()))

            eval_nme = compute_nme(args.num_pts, eval_prediction, eval_target)
            logger.log(
                '[Eval Info]: [evaluate the {}/{}-th dataset][epoch-{}-{}][Avg Loss:{:.6f}][NME:{:.2f}]'
                .format(i, len(eval_loaders), epoch, args.epochs,
                        eval_losses.avg, eval_nme * 100))
            tfboard_writer.add_scalar('eval_nme/{}'.format(i), eval_nme * 100,
                                      epoch)

    logger.close()
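
# The wing_loss(args) criterion above is defined elsewhere. A minimal sketch,
# assuming the common formulation from Feng et al. (CVPR 2018, "Wing Loss for
# Robust Facial Landmark Localisation"); the actual parameterization and
# reduction may differ.
import math

import torch


def wing_loss_sketch(pred, target, w=10.0, eps=2.0):
    x = (pred - target).abs()
    c = w - w * math.log(1.0 + w / eps)   # makes the two pieces meet at |x| == w
    small = w * torch.log(1.0 + x / eps)  # logarithmic regime for small errors
    large = x - c                         # linear (L1) regime for large errors
    return torch.where(x < w, small, large).mean()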
Exemplo n.º 23
0
    def process(self):
        acc = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        losses = AverageMeter()
        log_file = os.path.join(self.data_folder, 'test.log')
        logger = Logger('test', log_file)
        # switch to evaluate mode
        self.model.eval()

        start_time = time.perf_counter()  # time.clock() was removed in Python 3.8
        print("Begin testing")
        for i, (images, labels) in enumerate(self.test_loader):
            if check_gpu() > 0:
                images = images.cuda(non_blocking=True)  # `async=` is a keyword clash on Python 3.7+
                labels = labels.cuda(non_blocking=True)

            image_var = torch.autograd.Variable(images)
            label_var = torch.autograd.Variable(labels)

            # compute y_pred
            y_pred = self.model(image_var)
            loss = self.criterion(y_pred, label_var)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(y_pred.data, labels, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            acc.update(prec1.item(), images.size(0))
            top1.update(prec1.item(), images.size(0))
            top5.update(prec5.item(), images.size(0))

            if i % self.print_freq == 0:
                print('TestVal: [{0}/{1}]\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                    i, len(self.test_loader), loss=losses, top1=top1, top5=top5))

        print(
            ' * Accuracy {acc.avg:.3f}  Acc@5 {top5.avg:.3f} Loss {loss.avg:.3f}'.format(acc=acc, top5=top5,
                                                                                         loss=losses))

        end_time = time.perf_counter()
        print("Total testing time %.2gs" % (end_time - start_time))
        logger.info("Total testing time %.2gs" % (end_time - start_time))
        logger.info(
            ' * Accuracy {acc.avg:.3f}  Acc@5 {top5.avg:.3f} Loss {loss.avg:.3f}'.format(acc=acc, top5=top5,
                                                                                         loss=losses))
Exemplo n.º 24
0
def train(trn_loader, model, criterion, optimizer, scheduler, epoch):
    net_meter = NetworkMeter()
    timer = TimeMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()
    for i, (input, target) in enumerate(trn_loader):
        if args.short_epoch and (i > 10): break
        batch_num = i + 1
        timer.batch_start()
        scheduler.update_lr(epoch, i + 1, len(trn_loader))

        # compute output
        output = model(input)
        loss = criterion(output, target)

        should_print = (batch_num % args.print_freq
                        == 0) or (batch_num == len(trn_loader))

        # compute gradient and do SGD step
        if args.fp16:
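            # static loss scaling: scaling up before backward() keeps small
            # fp16 gradients from underflowing; dividing back below restores
            # the true loss value for logging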
            loss = loss * args.loss_scale
            # zero_grad() and converting fp16/fp32 is handled in optimizer
            loss.backward()
            optimizer.step(wait_for_finish=should_print)
            loss = loss / args.loss_scale
        else:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Train batch done. Logging results
        timer.batch_end()

        if args.local_rank == 0 and should_print:
            corr1, corr5 = correct(output.data, target, topk=(1, 5))
            reduced_loss, batch_total = to_python_float(
                loss.data), to_python_float(input.size(0))
            if args.distributed:  # Must keep track of global batch size, since not all machines are guaranteed equal batches at the end of an epoch
                validate_tensor[0] = batch_total
                validate_tensor[1] = reduced_loss
                validate_tensor[2] = corr1
                validate_tensor[3] = corr5
                batch_total, reduced_loss, corr1, corr5 = bps.push_pull(
                    validate_tensor, average=False, name="validation_tensor")
                batch_total = batch_total.cpu().numpy()
                reduced_loss = reduced_loss.cpu().numpy()
                corr1 = corr1.cpu().numpy()
                corr5 = corr5.cpu().numpy()
                reduced_loss = reduced_loss / bps.size()

            top1acc = to_python_float(corr1) * (100.0 / batch_total)
            top5acc = to_python_float(corr5) * (100.0 / batch_total)

            losses.update(reduced_loss, batch_total)
            top1.update(top1acc, batch_total)
            top5.update(top5acc, batch_total)
            tb.log_memory()
            tb.log_trn_times(timer.batch_time.val, timer.data_time.val,
                             input.size(0))
            tb.log_trn_loss(losses.val, top1.val, top5.val)

            recv_gbit, transmit_gbit = net_meter.update_bandwidth()
            tb.log("sizes/batch_total", batch_total)
            tb.log('net/recv_gbit', recv_gbit)
            tb.log('net/transmit_gbit', transmit_gbit)

            output = (
                f'Epoch: [{epoch}][{batch_num}/{len(trn_loader)}]\t'
                f'Time {timer.batch_time.val:.3f} ({timer.batch_time.avg:.3f})\t'
                f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                f'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                f'Acc@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                f'Data {timer.data_time.val:.3f} ({timer.data_time.avg:.3f})\t'
                f'BW {recv_gbit:.3f} {transmit_gbit:.3f}')
            log.verbose(output)

            tb.update_step_count(batch_total)
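
# A hedged sketch of the to_python_float helper assumed above; the original
# (e.g., from fp16 utility code) may handle more cases.
import torch


def to_python_float_sketch(t):
    """Convert a 0-d tensor or a plain number to a Python float."""
    return t.item() if torch.is_tensor(t) else float(t)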