Exemplo n.º 1
0
def train(data_loader,
          model_pos,
          criterion,
          optimizer,
          device,
          lr_init,
          lr_now,
          step,
          decay,
          gamma,
          max_norm=True):
    """Train `model_pos` for one epoch over `data_loader`.

    Args:
        data_loader: iterable yielding (targets_3d, inputs_2d, _) batches.
        model_pos: network mapping 2D keypoint inputs to 3D pose outputs.
        criterion: loss comparing predicted and target 3D poses.
        optimizer: optimizer updating `model_pos` parameters.
        device: torch device that model and tensors are moved to.
        lr_init: initial learning rate fed to the `lr_decay` schedule.
        lr_now: current learning rate; refreshed whenever decay triggers.
        step: global step counter, incremented once per batch.
        decay: step interval at which the learning rate is decayed.
        gamma: multiplicative decay factor passed to `lr_decay`.
        max_norm: if truthy, clip the gradient norm before the optimizer
            step. `True` clips to 1.0 (original hard-coded behaviour);
            a numeric value is used directly as the clipping threshold.

    Returns:
        Tuple of (mean 3D position loss over the epoch, lr_now, step).
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    epoch_loss_3d_pos = AverageMeter()

    # Switch to train mode.
    torch.set_grad_enabled(True)
    model_pos.train()
    end = time.time()

    bar = Bar('Train', max=len(data_loader))
    for i, (targets_3d, inputs_2d, _) in enumerate(data_loader):
        # Measure data loading time.
        data_time.update(time.time() - end)
        num_poses = targets_3d.size(0)

        step += 1
        # Decay the learning rate at the configured interval
        # (and once at the very first step).
        if step % decay == 0 or step == 1:
            lr_now = lr_decay(optimizer, step, lr_init, decay, gamma)

        targets_3d, inputs_2d = targets_3d.to(device), inputs_2d.to(device)
        outputs_3d = model_pos(inputs_2d)

        optimizer.zero_grad()
        loss_3d_pos = criterion(outputs_3d, targets_3d)
        loss_3d_pos.backward()
        if max_norm:
            # `True` keeps the original hard-coded clip of 1.0; a numeric
            # max_norm is used directly as the clipping threshold.
            clip_value = 1.0 if max_norm is True else float(max_norm)
            nn.utils.clip_grad_norm_(model_pos.parameters(), max_norm=clip_value)
        optimizer.step()

        epoch_loss_3d_pos.update(loss_3d_pos.item(), num_poses)

        # Measure elapsed time.
        batch_time.update(time.time() - end)
        end = time.time()

        bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {ttl:} | ETA: {eta:} ' \
                     '| Loss: {loss: .4f}' \
            .format(batch=i + 1, size=len(data_loader), data=data_time.val, bt=batch_time.avg,
                    ttl=bar.elapsed_td, eta=bar.eta_td, loss=epoch_loss_3d_pos.avg)
        bar.next()

    bar.finish()
    return epoch_loss_3d_pos.avg, lr_now, step
Exemplo n.º 2
0
def train(loss_last, data_loader, model_pos, criterion, optimizer, device, lr_init, lr_now, step, decay, gamma,
          max_norm=True, loss_3d=False, earlyend=True):
    """Train `model_pos` for one epoch, switching to a 3D loss once the
    previous epoch's loss (`loss_last`) drops below 0.3.

    Args:
        loss_last: loss from the previous epoch; values < 0.3 switch
            training to the 3D (triangulation) loss for the whole epoch.
        data_loader: iterable yielding (targets_score, inputs_2d, data_dict).
        model_pos: network producing per-joint scores from 2D inputs.
        criterion: score-space loss used before the 3D switch.
        optimizer, device, lr_init, lr_now, step, decay, gamma: optimizer
            and lr-decay schedule state, as in the sibling `train` above.
        max_norm: if truthy, clip gradient norm to 1 before each step.
        loss_3d: if True, force the 3D loss regardless of `loss_last`.
        earlyend: if True, caps the poses processed per epoch at 1
            (see NOTE below — this ends the epoch almost immediately).

    Returns:
        Tuple of (mean loss over the processed batches, lr_now, step).
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    epoch_loss = AverageMeter()

    # Switch to train mode
    torch.set_grad_enabled(True)
    model_pos.train()
    end = time.time()
    # Per-epoch cap on the number of poses; effectively unlimited unless
    # earlyend is set.
    dataperepoch_limit = 1e10
    dataperepoch_count = 0
    if earlyend:
        dataperepoch_limit = 1

    bar = Bar('Train', max=len(data_loader))
    for i, (targets_score, inputs_2d, data_dict) in enumerate(data_loader):
        # Measure data loading time
        data_time.update(time.time() - end)
        num_poses = targets_score.size(0)
        dataperepoch_count += num_poses
        # NOTE(review): this check runs BEFORE the forward/backward pass,
        # so with earlyend=True (limit=1) the loop breaks on the very first
        # batch without performing any training step — confirm this
        # short-circuit is intentional (e.g. a debug/dry-run mode).
        if dataperepoch_count >= dataperepoch_limit:
            break
        step += 1
        # Decay the learning rate at the configured interval (and at step 1).
        if step % decay == 0 or step == 1:
            lr_now = lr_decay(optimizer, step, lr_init, decay, gamma)

        targets_score, inputs_2d = targets_score.to(device), inputs_2d.to(device)
        outputs_score = model_pos(inputs_2d)
        optimizer.zero_grad()

        ### 3d loss
        # Once loss_last < 0.3 (or loss_3d was passed as True), loss_3d is
        # latched True so every remaining batch in this epoch also uses the
        # 3D triangulation loss.
        if loss_last < 0.3 or loss_3d:
            loss_3d = True
            # Move any tensors in the auxiliary data dict to the device
            # before triangulation.
            for key in data_dict.keys():
                if isinstance(data_dict[key], torch.Tensor):
                    data_dict[key] = data_dict[key].to(device)
            # presumably triangulates 3D poses from the predicted scores and
            # returns (predictions, targets) — verify against the helper.
            output_3d, targets_3d = triangulation_acc(outputs_score, data_dict=data_dict, all_metric=False)
            loss_ = mpjpe(output_3d['ltr_after'], targets_3d)
        else:
            # Plain score-space supervision before the 3D switch.
            loss_ = criterion(outputs_score, targets_score)

        loss_.backward()

        if max_norm:
            nn.utils.clip_grad_norm_(model_pos.parameters(), max_norm=1)
        optimizer.step()

        epoch_loss.update(loss_.item(), num_poses)

        # Measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {ttl:} | ETA: {eta:} ' \
                     '| Loss: {loss: .4f}' \
            .format(batch=i + 1, size=len(data_loader), data=data_time.val, bt=batch_time.avg,
                    ttl=bar.elapsed_td, eta=bar.eta_td, loss=epoch_loss.avg)
        bar.next()

    bar.finish()
    return epoch_loss.avg, lr_now, step