Example #1
def validate(val_loader, model_gen, device, out_dir_path):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    os.makedirs(out_dir_path, exist_ok=True)

    # switch to evaluate mode
    torch.set_grad_enabled(False)
    model_gen.eval()
    end = time.time()

    bar = Bar('Eval ', max=len(val_loader))
    for i, (img_src, norm_src, img_dst, norm_dst) in enumerate(val_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        x_src, x_dst, n_src, n_dst = img_src.to(device), img_dst.to(device), norm_src.to(device), norm_dst.to(device)
        x_fake, _ = model_gen(x_src, n_src, n_dst)
        num_rows, x_out = cat_triplet(x_src, x_fake, x_dst)
        save_image(x_out, os.path.join(out_dir_path, 'eval_batch_{:04d}.jpg'.format(i + 1)), normalize=True, nrow=num_rows)

        # measure elapsed time (batch_time.val is shown in the progress bar below)
        batch_time.update(time.time() - end)
        end = time.time()

        bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:}'.format(
            batch=i + 1,
            size=len(val_loader),
            data=data_time.val,
            bt=batch_time.val,
            total=bar.elapsed_td,
            eta=bar.eta_td
        )

        bar.next()

    bar.finish()
    return
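
The examples on this page assume an AverageMeter helper for running statistics and, in the image examples, a cat_triplet helper. A minimal AverageMeter sketch, following the common pattern from the PyTorch ImageNet example (Examples #5 and #10 use a variant that also takes a name and a format string):

class AverageMeter(object):
    """Tracks the latest value, running sum, count, and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # n = number of samples that val was averaged over
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

cat_triplet is not shown in the source; one plausible reading (an assumption) is that it stacks the source, generated, and target batches along the batch axis so that save_image(..., nrow=batch_size) renders them as three rows:

import torch

def cat_triplet(x_a, x_b, x_c):
    # Assumed helper: concatenate three equally-sized image batches so the
    # saved grid shows them as three rows (e.g. source / fake / target).
    assert x_a.size(0) == x_b.size(0) == x_c.size(0)
    return x_a.size(0), torch.cat([x_a, x_b, x_c], dim=0)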
Example #2
def train(data_loader, model, criterion, optimizer, scheduler, epoch):
    running_loss = AverageMeter()

    # Switch to train mode
    torch.set_grad_enabled(True)
    model.train()

    for i, (state, target) in enumerate(data_loader):

        state = state.float()
        target = target.float()
        num_states = state.shape[0]

        pred, M, c, g = model(state)  # pred.shape: (batch_size, 2)

        loss = criterion(pred, target) / 1000000

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss.update(loss.item(), num_states)

    print('[{}] loss: {:.3f}, lr: {:.5f}'.format(epoch + 1, running_loss.avg,
                                                 scheduler.get_last_lr()[0]))

    scheduler.step()
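
Since scheduler.step() runs once per call, train() is meant to be invoked once per epoch. A hypothetical driver loop (all names assumed):

# Sketch: model, criterion, optimizer, scheduler and data_loader are
# assumed to be constructed elsewhere.
num_epochs = 50  # assumed
for epoch in range(num_epochs):
    train(data_loader, model, criterion, optimizer, scheduler, epoch)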
Example #3
def test(test_loader, model_gen, device, out_dir_path):
    batch_time = AverageMeter()
    data_time = AverageMeter()

    if not os.path.exists(out_dir_path):
        os.makedirs(out_dir_path)
        print('Make output dir: {}'.format(out_dir_path))

    # switch to evaluate mode
    torch.set_grad_enabled(False)
    model_gen.eval()
    end = time.time()

    bar = Bar('Test ', max=len(test_loader))
    for i, (img_src, norm_src, img_t, norm_t) in enumerate(test_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        output = []
        for j in range(img_t.size(2)):
            x_src, n_src = img_src.to(device), norm_src.to(device)
            x_dst, n_dst = img_t[:, :, j, :, :].to(
                device), norm_t[:, :, j, :, :].to(device)
            x_fake, _ = model_gen(x_src, n_src, n_dst)

            _, x_out = cat_triplet(n_dst, x_fake, x_dst)
            output.append(torch_to_pil_image(x_out))

        # save videos
        dump_gif(
            output,
            os.path.join(out_dir_path, 'test_batch_{:04d}.gif'.format(i + 1)))

        # measure elapsed time (batch_time.val is shown in the progress bar below)
        batch_time.update(time.time() - end)
        end = time.time()

        bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:}'.format(
            batch=i + 1,
            size=len(test_loader),
            data=data_time.val,
            bt=batch_time.val,
            total=bar.elapsed_td,
            eta=bar.eta_td)

        bar.next()

    bar.finish()
    return
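
dump_gif and torch_to_pil_image are external helpers. A plausible dump_gif built on Pillow's animated-GIF support (the frame timing is an assumption):

def dump_gif(pil_images, path, duration_ms=100):
    """Write a list of PIL images to an animated GIF (sketch)."""
    head, *tail = pil_images
    head.save(path, save_all=True, append_images=tail,
              duration=duration_ms, loop=0)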
Example #4
def evaluate(data_loader, model, criterion):
    running_loss = AverageMeter()

    # Switch to eval mode (no gradients are needed during evaluation)
    torch.set_grad_enabled(False)
    model.eval()

    # visualizer
    viz = Visualizer()

    for i, (state, target) in enumerate(data_loader):

        state = state.float()
        target = target.float()  # state.shape: (batch_size, 2, 3)
        num_states = state.shape[0]

        pred, M, c, g = model(state)  # pred.shape: (batch_size, 2)

        loss = criterion(pred, target) / 1000000

        running_loss.update(loss.item(), num_states)

        # test
        q, qdot, qddot = split_states(state.numpy().squeeze())
        M_gt, c_gt, g_gt = generate_eom(q, qdot)

        M_pred = M.detach().numpy()
        c_pred = c.detach().numpy()
        g_pred = g.detach().numpy()

        u_pred, u_gt = pred.detach().numpy(), target.detach().numpy()

        viz.add_data(q, qdot, qddot, (u_pred, u_gt),
                     (M_pred @ qddot.reshape(2, ), M_gt @ qddot.reshape(2, )),
                     (c_pred, c_gt), (g_pred, g_gt))

    print('evaluate loss: {:.3f}'.format(running_loss.avg))

    viz.save_plot()
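
split_states and generate_eom come from the surrounding project. Given the shape comment above (state is (batch_size, 2, 3) and qddot reshapes to (2,)), one consistent reading, purely an assumption, is that the last axis stacks position, velocity, and acceleration per joint:

def split_states(state):
    # Assumed layout: after squeezing, state is (2, 3) with the last axis
    # holding (q, qdot, qddot) for each of the two joints.
    q, qdot, qddot = state[:, 0], state[:, 1], state[:, 2]
    return q, qdot, qddot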
Example #5
def validate(model, criterion, valset, iteration, batch_size, n_gpus,
             collate_fn, logger, text_logger, distributed_run, rank):
    """Handles all the validation scoring and printing"""
    model.eval()
    with torch.no_grad():
        val_sampler = DistributedSampler(valset) if distributed_run else None
        val_loader = DataLoader(valset,
                                sampler=val_sampler,
                                num_workers=1,
                                shuffle=val_sampler is None,  # sampler and shuffle are mutually exclusive
                                batch_size=batch_size,
                                pin_memory=False,
                                collate_fn=collate_fn)

        losses = AverageMeter('Loss', ':.4e')
        grad_norms = AverageMeter('GradNorm', ':.4e')
        progress = ProgressMeter(len(val_loader),
                                 losses,
                                 grad_norms,
                                 prefix="Test: ",
                                 logger=text_logger)

        for i, batch in enumerate(val_loader):
            x, y = model.parse_batch(batch)
            y_pred = model(x)
            loss = criterion(y_pred, y)
            if distributed_run:
                reduced_val_loss = reduce_tensor(loss.data, n_gpus).item()
            else:
                reduced_val_loss = loss.item()
            losses.update(reduced_val_loss, x[0].size(0))

    model.train()
    if rank == 0:
        progress.print(0)
        logger.log_validation(losses.avg, model, y, y_pred, iteration)
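
reduce_tensor averages the loss across distributed workers; a sketch of the usual all-reduce pattern (as in NVIDIA's Tacotron 2 training code):

import torch.distributed as dist

def reduce_tensor(tensor, n_gpus):
    # Sum the tensor across all workers, then divide by the worker count.
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n_gpus
    return rt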
Example #6
def valid(args, model, device, valid_loader, criterion):
    model.eval()
    val_losses = AverageMeter()
    val_top1 = AverageMeter()

    with torch.no_grad():
        for data, target in valid_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            val_loss = criterion(output, target)
            val_losses.update(val_loss.item(), data.size(0))
            prec1 = accuracy(output, target, topk=(1,))
            val_top1.update(prec1[0], data.size(0))

    print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f} %\n'.format(
        val_losses.avg, val_top1.avg.item()))

    return val_losses, val_top1
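
accuracy is the standard top-k classification metric; a self-contained sketch of the usual recipe:

import torch

def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracies (in percent) for the given k values."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        # Indices of the maxk largest logits per sample, transposed to (maxk, batch)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res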
Example #7
def train(args, model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    losses = AverageMeter()
    top1 = AverageMeter()

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        losses.update(loss.item(), data.size(0))
        prec1 = accuracy(output, target, topk=(1,))
        top1.update(prec1[0], data.size(0))
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == (args.log_interval - 1):
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), losses.avg))

    return losses, top1
Example #8
def evaluate(data_loader, model_pos, device):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    epoch_loss_3d_pos = AverageMeter()
    epoch_loss_3d_pos_procrustes = AverageMeter()

    # Switch to evaluate mode
    torch.set_grad_enabled(False)
    model_pos.eval()
    end = time.time()

    bar = Bar('Eval ', max=len(data_loader))
    for i, (targets_3d, inputs_2d, _) in enumerate(data_loader):
        # Measure data loading time
        data_time.update(time.time() - end)
        num_poses = targets_3d.size(0)

        inputs_2d = inputs_2d.to(device)
        outputs_3d = model_pos(inputs_2d).cpu()
        outputs_3d[:, :, :] -= outputs_3d[:, :1, :]  # Zero-centre the root (hip) joint

        epoch_loss_3d_pos.update(
            mpjpe(outputs_3d, targets_3d).item() * 1000.0, num_poses)
        epoch_loss_3d_pos_procrustes.update(
            p_mpjpe(outputs_3d.numpy(), targets_3d.numpy()).item() * 1000.0,
            num_poses)

        # Measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {ttl:} | ETA: {eta:} ' \
                     '| MPJPE: {e1: .4f} | P-MPJPE: {e2: .4f}' \
            .format(batch=i + 1, size=len(data_loader), data=data_time.val, bt=batch_time.avg,
                    ttl=bar.elapsed_td, eta=bar.eta_td, e1=epoch_loss_3d_pos.avg, e2=epoch_loss_3d_pos_procrustes.avg)
        bar.next()

    bar.finish()
    return epoch_loss_3d_pos.avg, epoch_loss_3d_pos_procrustes.avg
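
mpjpe is the mean per-joint position error (multiplied by 1000 above to report millimetres), and p_mpjpe is its Procrustes-aligned variant. A sketch of the standard mpjpe definition:

import torch

def mpjpe(predicted, target):
    # Euclidean distance between predicted and ground-truth joint
    # positions, averaged over all joints and poses.
    assert predicted.shape == target.shape
    return torch.mean(torch.norm(predicted - target, dim=-1))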
Example #9
def train(data_loader,
          model_pos,
          criterion,
          optimizer,
          device,
          lr_init,
          lr_now,
          step,
          decay,
          gamma,
          max_norm=True):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    epoch_loss_3d_pos = AverageMeter()

    # Switch to train mode
    torch.set_grad_enabled(True)
    model_pos.train()
    end = time.time()

    bar = Bar('Train', max=len(data_loader))
    for i, (targets_3d, inputs_2d) in enumerate(data_loader):
        # Measure data loading time
        data_time.update(time.time() - end)
        num_poses = targets_3d.size(0)

        step += 1
        if step % decay == 0 or step == 1:
            lr_now = lr_decay(optimizer, step, lr_init, decay, gamma)

        targets_3d, inputs_2d = targets_3d.to(device), inputs_2d.to(device)
        outputs_3d = model_pos(inputs_2d)

        optimizer.zero_grad()
        loss_3d_pos = criterion(outputs_3d, targets_3d)
        loss_3d_pos.backward()
        if max_norm:
            nn.utils.clip_grad_norm_(model_pos.parameters(), max_norm=1)
        optimizer.step()

        epoch_loss_3d_pos.update(loss_3d_pos.item(), num_poses)

        # Measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {ttl:} | ETA: {eta:} ' \
                     '| Loss: {loss: .4f}' \
            .format(batch=i + 1, size=len(data_loader), data=data_time.val, bt=batch_time.avg,
                    ttl=bar.elapsed_td, eta=bar.eta_td, loss=epoch_loss_3d_pos.avg)
        bar.next()

    bar.finish()
    return epoch_loss_3d_pos.avg, lr_now, step
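
lr_decay is external; the schedule commonly paired with this kind of loop is exponential in the step count. A sketch (the exact implementation used here may differ):

def lr_decay(optimizer, step, lr_init, decay_step, gamma):
    # lr = lr_init * gamma ** (step / decay_step), applied to every group
    lr = lr_init * gamma ** (step / decay_step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr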
Example #10
def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
          rank, group_name, hparams):
    """Training and validation logging results to tensorboard and stdout

    Params
    ------
    output_directory (string): directory to save checkpoints
    log_directory (string) directory to save tensorboard logs
    checkpoint_path(string): checkpoint path
    n_gpus (int): number of gpus
    rank (int): rank of current gpu
    hparams (object): comma separated list of "name=value" pairs.
    """
    if hparams.distributed_run:
        init_distributed(hparams, n_gpus, rank, group_name)

    torch.manual_seed(hparams.seed)
    torch.cuda.manual_seed(hparams.seed)

    model = load_model(hparams)
    learning_rate = hparams.learning_rate
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate,
                                 weight_decay=hparams.weight_decay)
    if hparams.fp16_run:
        optimizer = FP16_Optimizer(
            optimizer, dynamic_loss_scale=hparams.dynamic_loss_scaling)

    criterion = Tacotron2Loss(hparams.mel_weight, hparams.gate_weight)

    logger, text_logger = prepare_directories_and_logger(
        output_directory, log_directory, rank, hparams)

    text_logger.info(hparams.__dict__)

    train_loader, valset, collate_fn = prepare_dataloaders(hparams)

    # Load checkpoint if one exists
    iteration = 0
    epoch_offset = 0
    if checkpoint_path:
        if warm_start:
            model = warm_start_model(checkpoint_path, model)
        else:
            model, optimizer, _learning_rate, iteration = load_checkpoint(
                checkpoint_path, model, optimizer)
            if hparams.use_saved_learning_rate:
                learning_rate = _learning_rate
            iteration += 1  # next iteration is iteration + 1
            epoch_offset = max(0, int(iteration / len(train_loader)))

    model.train()
    # ================ MAIN TRAINING LOOP ===================
    for epoch in range(epoch_offset, hparams.epochs):

        batch_time = AverageMeter('Time', ':6.3f')
        data_time = AverageMeter('Data', ':6.3f')
        losses = AverageMeter('Loss', ':.4e')
        grad_norms = AverageMeter('GradNorm', ':.4e')
        progress = ProgressMeter(len(train_loader),
                                 batch_time,
                                 data_time,
                                 losses,
                                 grad_norms,
                                 prefix="Epoch: [{}]".format(epoch),
                                 logger=text_logger)

        end = time.time()

        for i, batch in enumerate(train_loader):
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate

            # measure data loading time
            data_time.update(time.time() - end)

            model.zero_grad()
            x, y = model.parse_batch(batch)
            y_pred = model(x)

            loss = criterion(y_pred, y)
            if hparams.distributed_run:
                reduced_loss = reduce_tensor(loss.data, n_gpus).item()
            else:
                reduced_loss = loss.item()

            losses.update(reduced_loss, x[0].size(0))

            if hparams.fp16_run:
                optimizer.backward(loss)
                grad_norm = optimizer.clip_fp32_grads(hparams.grad_clip_thresh)
            else:
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hparams.grad_clip_thresh)

            grad_norms.update(grad_norm, x[0].size(0))

            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            logger.log_training(reduced_loss, grad_norm, learning_rate,
                                batch_time.val, iteration)

            if (i % hparams.display_freq == 0) and rank == 0:
                progress.print(i)

            iteration += 1

        if epoch % hparams.epochs_per_checkpoint == 0:
            validate(model, criterion, valset, iteration, hparams.batch_size,
                     n_gpus, collate_fn, logger, text_logger,
                     hparams.distributed_run, rank)
            if rank == 0:
                checkpoint_path = os.path.join(
                    output_directory, "checkpoint_{}".format(iteration))
                save_checkpoint(model, optimizer, learning_rate, iteration,
                                checkpoint_path)
Example #11
def train(loss_last, data_loader, model_pos, criterion, optimizer, device, lr_init, lr_now, step, decay, gamma,
          max_norm=True, loss_3d=False, earlyend=True):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    epoch_loss = AverageMeter()

    # Switch to train mode
    torch.set_grad_enabled(True)
    model_pos.train()
    end = time.time()
    # Optionally cap the number of poses processed per epoch; with
    # earlyend=True the cap is one pose, so the loop exits on the first batch.
    dataperepoch_limit = 1e10
    dataperepoch_count = 0
    if earlyend:
        dataperepoch_limit = 1

    bar = Bar('Train', max=len(data_loader))
    for i, (targets_score, inputs_2d, data_dict) in enumerate(data_loader):
        # Measure data loading time
        data_time.update(time.time() - end)
        num_poses = targets_score.size(0)
        dataperepoch_count += num_poses
        if dataperepoch_count >= dataperepoch_limit:
            break
        step += 1
        if step % decay == 0 or step == 1:
            lr_now = lr_decay(optimizer, step, lr_init, decay, gamma)

        targets_score, inputs_2d = targets_score.to(device), inputs_2d.to(device)
        outputs_score = model_pos(inputs_2d)
        optimizer.zero_grad()

        ### 3d loss
        if loss_last < 0.3 or loss_3d:
            loss_3d = True
            for key in data_dict.keys():
                if isinstance(data_dict[key], torch.Tensor):
                    data_dict[key] = data_dict[key].to(device)
            output_3d, targets_3d = triangulation_acc(outputs_score, data_dict=data_dict, all_metric=False)
            loss_ = mpjpe(output_3d['ltr_after'], targets_3d)
        else:
            loss_ = criterion(outputs_score, targets_score)

        loss_.backward()

        if max_norm:
            nn.utils.clip_grad_norm_(model_pos.parameters(), max_norm=1)
        optimizer.step()

        epoch_loss.update(loss_.item(), num_poses)

        # Measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {ttl:} | ETA: {eta:} ' \
                     '| Loss: {loss: .4f}' \
            .format(batch=i + 1, size=len(data_loader), data=data_time.val, bt=batch_time.avg,
                    ttl=bar.elapsed_td, eta=bar.eta_td, loss=epoch_loss.avg)
        bar.next()

    bar.finish()
    return epoch_loss.avg, lr_now, step
Example #12
def train(train_loader, model_gen, model_dis, optim_gen, optim_dis, device):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    gen_losses = AverageMeter()
    dis_losses = AverageMeter()

    # switch to train mode
    torch.set_grad_enabled(True)
    model_gen.train()
    model_dis.train()
    end = time.time()

    bar = Bar('Train', max=len(train_loader))
    for i, (img_src, norm_src, img_dst, norm_dst) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        x_src, x_dst, n_src, n_dst = img_src.to(device), img_dst.to(device), norm_src.to(device), norm_dst.to(device)
        batch_size = x_src.size(0)

        ######################
        # (1) Update D network
        ######################

        x_fake, w = model_gen(x_src, n_src, n_dst)

        eps = torch.rand(batch_size, 1).to(device)
        eps = eps.expand(-1, int(x_src.numel() / batch_size)).view_as(x_src)

        x_rand = eps * x_dst.detach() + (1 - eps) * x_fake.detach()
        x_rand.requires_grad_()
        x_rand = torch.cat([x_rand, n_dst], dim=1)
        loss_rand_x = model_dis(x_rand)

        grad_outputs = torch.ones(loss_rand_x.size())
        grads = autograd.grad(loss_rand_x, x_rand, grad_outputs=grad_outputs.to(device), create_graph=True)[0]
        loss_gp = torch.mean((grads.view(batch_size, -1).pow(2).sum(1).sqrt() - 1).pow(2))

        loss_real_x = model_dis(torch.cat([x_dst, n_dst], dim=1))
        loss_fake_x = model_dis(torch.cat([x_fake.detach(), n_dst], dim=1))
        loss_dis = loss_fake_x.mean() - loss_real_x.mean() + 10.0 * loss_gp

        # compute gradient and bp
        optim_dis.zero_grad()
        loss_dis.backward()
        optim_dis.step()

        dis_losses.update(float(loss_dis.item()))

        ######################
        # (2) Update G network
        ######################

        loss_fake_x = model_dis(torch.cat([x_fake, n_dst], dim=1))
        loss_gen = -loss_fake_x.mean() + 3.0 * loss_l1(x_fake, x_dst) + 0.05 * loss_norm_l1(w)

        # compute gradient and bp
        optim_gen.zero_grad()
        loss_gen.backward()
        optim_gen.step()

        gen_losses.update(float(loss_gen.item()))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss G: {loss_g:.4f} | Loss D: {loss_d: .4f}'.format(
            batch=i + 1,
            size=len(train_loader),
            data=data_time.val,
            bt=batch_time.val,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss_g=gen_losses.avg,
            loss_d=dis_losses.avg
        )
        bar.next()

    bar.finish()
    return gen_losses.avg, dis_losses.avg
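
The discriminator update above is the WGAN-GP objective with a gradient-penalty weight of 10. loss_l1 and loss_norm_l1 are external; plausible definitions (assumptions) are a plain L1 reconstruction term and an L1 penalty on the generator's auxiliary output w:

import torch
import torch.nn.functional as F

def loss_l1(x, y):
    # Assumed: plain L1 reconstruction loss.
    return F.l1_loss(x, y)

def loss_norm_l1(w):
    # Assumed: L1 sparsity penalty on the auxiliary tensor w.
    return torch.mean(torch.abs(w))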
Example #13
    def fit(self,
            train_set,
            valid_set,
            run=None,
            max_steps=50,
            early_stopping_rounds=10,
            verbose=100):

        best_score = np.inf
        stop_steps = 0
        best_params = copy.deepcopy(self.state_dict())
        for step in range(max_steps):
            
            pprint('Step:', step)
            if stop_steps >= early_stopping_rounds:
                if verbose:
                    pprint('\tearly stop')
                break
            stop_steps += 1
            # training
            self.train()
            train_loss = AverageMeter()
            train_eval = AverageMeter()
            train_hids = dict()
            for i, (idx, data, label) in enumerate(train_set):
                data = torch.tensor(data, dtype=torch.float)
                label = torch.tensor(label, dtype=torch.float)
                if torch.cuda.is_available():
                    data, label = data.cuda(), label.cuda()
                train_hid, pred = self(data)
                loss = self.loss_fn(pred, label)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                loss_ = loss.item()
                eval_ = self.metric_fn(pred, label).item()
                train_loss.update(loss_, len(data))
                train_eval.update(eval_)
                train_hids[idx] = train_hid.cpu().detach()
                if verbose and i % verbose == 0:
                    pprint('iter %s: train_loss %.6f, train_eval %.6f' %
                           (i, train_loss.avg, train_eval.avg))
            # evaluation
            self.eval()
            valid_loss = AverageMeter()
            valid_eval = AverageMeter()
            valid_hids = dict()
            for i, (idx, data, label) in enumerate(valid_set):
                data = torch.tensor(data, dtype=torch.float)
                label = torch.tensor(label, dtype=torch.float)
                if torch.cuda.is_available():
                    data, label = data.cuda(), label.cuda()
                with torch.no_grad():
                    valid_hid, pred = self(data)
                
                loss = self.loss_fn(pred, label)
                valid_loss_ = loss.item()
                valid_eval_ = self.metric_fn(pred, label).item()
                valid_loss.update(valid_loss_, len(data))
                valid_eval.update(valid_eval_)
                valid_hids[idx] = valid_hid.cpu().detach()
            if run is not None:
                run.add_scalar('Train/Loss', train_loss.avg, step)
                run.add_scalar('Train/Eval', train_eval.avg, step)
                run.add_scalar('Valid/Loss', valid_loss.avg, step)
                run.add_scalar('Valid/Eval', valid_eval.avg, step)
            if verbose:
                pprint("current step: train_loss {:.6f}, valid_loss {:.6f}, "
                       "train_eval {:.6f}, valid_eval {:.6f}".format(
                           train_loss.avg, valid_loss.avg, train_eval.avg,
                           valid_eval.avg))
            if valid_eval.avg < best_score:
                if verbose:
                    pprint(
                        '\tvalid update from {:.6f} to {:.6f}, save checkpoint.'
                        .format(best_score, valid_eval.avg))
                best_score = valid_eval.avg
                stop_steps = 0
                best_params = copy.deepcopy(self.state_dict())
        # restore
        self.load_state_dict(best_params)
        return train_hids, valid_hids # train_hid: [batch, input_day, hid_size]
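
Examples #13 and #14 are designed to chain: the hidden states returned by fit() here supervise the minute-level model in Example #14 below. A hypothetical wiring (all names assumed):

# Sketch: daily_model and min_model implement the fit() methods of
# Examples #13 and #14; writer is a tensorboard SummaryWriter.
train_hids, valid_hids = daily_model.fit(train_set, valid_set, run=writer)
min_model.fit(min_train_set, min_valid_set, train_hids, valid_hids, run=writer)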
Example #14
    def fit(self,
            train_set,
            valid_set,
            train_hids, # daily RNN hids
            valid_hids,
            run=None,
            min_max_steps=50,
            min_early_stopping_rounds=5,
            verbose=100,
            output_path = "/home/amax/Documents/HM_CNN_RNN/out",
            itera=0):

        best_score = np.inf
        stop_steps = 0
        best_params = copy.deepcopy(self.state_dict())

        # for step in range(max_steps):
        for step in range(min_max_steps): 
            # self.min_ratio_teacher = 1 / (step+1)
            # if self.min_ratio_teacher <= 0.1:
            #     self.min_ratio_teacher = 0 
            pprint('Step:', step)
            if stop_steps >= min_early_stopping_rounds:
                if verbose:
                    pprint('\tearly stop')
                break
            stop_steps += 1
            # training
            self.train()
            train_loss = AverageMeter()
            train_loss_a = AverageMeter()
            train_loss_b = AverageMeter()
            train_eval = AverageMeter()
            train_eval_a = AverageMeter()
            train_eval_b = AverageMeter()
            train_day_reps = dict() # min -> day representation
            for i, (idx, data, label) in enumerate(train_set):
                self.optimizer.zero_grad()
                data = torch.tensor(data, dtype=torch.float)
                label = torch.tensor(label, dtype=torch.float)
                train_hid = train_hids[idx] # train_hid:[batch, input_day, input_size]
                if torch.cuda.is_available():
                    data, train_hid, label  = data.cuda(), train_hid.cuda(), label.cuda()
                day_rep, pred = self(data)

                #[batch, input_day, hid_size ]
                train_day_rep = day_rep[:,:,:self.hid_size]
                # pprint(f'train_day_rep: {train_day_rep.shape}')
                # pprint(f'train_hid: {train_hid.shape}')
                loss_a = self.loss_fn(train_day_rep, train_hid, dim=[0,1,2]) # learn from teacher
                
                loss_b = self.loss_fn(pred, label, dim=0) # learn from label
                # loss = self.min_ratio_teacher * loss_a + (1.0-self.min_ratio_teacher) * loss_b
                loss = self.min_ratio_teacher * loss_a + loss_b
                
                loss.backward()
                self.optimizer.step()
                train_day_reps[idx] = day_rep.cpu().detach()
                len_data = len(data)
                train_loss.update(loss.item(), len_data)
                if loss.item() > 1000:
                    pprint(idx)  # flag batches with an abnormally large loss
                train_loss_a.update(loss_a.item(), len_data)
                train_loss_b.update(loss_b.item(), len_data)
                eval_a = self.metric_fn(train_day_rep, train_hid, dim=[0,1,2]).item()
                eval_b = self.metric_fn(pred, label,dim=0).item()
                # eval_ = self.min_ratio_teacher * eval_a + (1.0-self.min_ratio_teacher) * eval_b
                eval_ = eval_b
                train_eval.update(eval_)
                train_eval_a.update(eval_a)
                train_eval_b.update(eval_b)
                if verbose and i % verbose == 0:
                    pprint('iter %s: train_loss %.6f, train_eval %.6f' %
                           (i, train_loss.avg, train_eval.avg))
            # evaluation
            self.eval()
            valid_loss = AverageMeter()
            valid_loss_a = AverageMeter()
            valid_loss_b = AverageMeter()
            valid_eval = AverageMeter()
            valid_eval_a = AverageMeter()
            valid_eval_b = AverageMeter()
            valid_day_reps = dict()
            for i, (idx, data, label) in enumerate(valid_set):
                data = torch.tensor(data, dtype=torch.float)
                label = torch.tensor(label, dtype=torch.float)
                valid_hid = valid_hids[idx]
                if torch.cuda.is_available():
                    data, valid_hid, label = data.cuda(), valid_hid.cuda(), label.cuda()
                with torch.no_grad():
                    day_rep, pred = self(data)
                valid_day_rep = day_rep[:,:,:self.hid_size]
                loss_a = self.loss_fn(valid_day_rep, valid_hid, dim=[0,1,2])
                loss_b = self.loss_fn(pred, label, dim=0)
                # loss = self.min_ratio_teacher * loss_a + (1.0-self.min_ratio_teacher) * loss_b
                loss = self.min_ratio_teacher * loss_a + loss_b
                valid_day_reps[idx] = day_rep.cpu().detach()
                len_data = len(data)
                valid_loss.update(loss.item(), len_data)
                valid_loss_a.update(loss_a.item(), len_data)
                valid_loss_b.update(loss_b.item(), len_data)
                eval_a = self.metric_fn(valid_day_rep, valid_hid,dim=[0,1,2]).item()
                eval_b = self.metric_fn(pred, label,dim=0).item()
                # eval_ = self.min_ratio_teacher * eval_a + (1.0-self.min_ratio_teacher) * eval_b
                eval_ = eval_b
                valid_eval.update(eval_)
                valid_eval_a.update(eval_a)
                valid_eval_b.update(eval_b)
            if run is not None:
                run.add_scalar('Train/Loss_total', train_loss.avg, step)
                run.add_scalar('Train/Loss_from_teacher', train_loss_a.avg, step)
                run.add_scalar('Train/Loss_from_label', train_loss_b.avg, step)
                run.add_scalar('Train/Eval_total', train_eval.avg, step)
                run.add_scalar('Train/Eval_from_teacher', train_eval_a.avg, step)
                run.add_scalar('Train/Eval_from_label', train_eval_b.avg, step)

                run.add_scalar('Valid/Loss_total', valid_loss.avg, step)
                run.add_scalar('Valid/Loss_from_teacher', valid_loss_a.avg, step)
                run.add_scalar('Valid/Loss_from_label', valid_loss_b.avg, step)
                run.add_scalar('Valid/Eval_total', valid_eval.avg, step)
                run.add_scalar('Valid/Eval_from_teacher', valid_eval_a.avg, step)
                run.add_scalar('Valid/Eval_from_label', valid_eval_b.avg, step)
            if verbose:
                pprint("current step: train_loss {:.6f}, valid_loss {:.6f}, "
                       "train_eval {:.6f}, valid_eval {:.6f}".format(
                           train_loss.avg, valid_loss.avg, train_eval.avg,
                           valid_eval.avg))
            if valid_eval.avg < best_score:
                if verbose:
                    pprint(
                        '\tvalid update from {:.6f} to {:.6f}, save checkpoint.'
                        .format(best_score, valid_eval.avg))
                best_train_day_reps = copy.deepcopy(train_day_reps)
                best_valid_day_reps = copy.deepcopy(valid_day_reps)
                best_score = valid_eval.avg
                stop_steps = 0
                best_params = copy.deepcopy(self.state_dict())

                self.save(output_path+'/best_min_model_%d.bin'%itera)
        # restore
        self.load_state_dict(best_params)
        return best_train_day_reps, best_valid_day_reps
Example #15
    def prevalid(self):
        from common.triangulation import triangulation_acc
        from common.utils import AverageMeter
        from common.loss import mpjpe
        acc_pre = AverageMeter()
        acc_best = AverageMeter()
        acc_svd = AverageMeter()
        for i in range(len(self)):
            target_score, input, data_dict = self.__getitem__(i)
            target_score = target_score.unsqueeze(0)
            for item in data_dict.keys():
                data_dict[item] = data_dict[item].unsqueeze(0)
            output_3d_dict, targets_3d = triangulation_acc(target_score,
                                                           data_dict,
                                                           all_metric=True)
            acc_pre.update(mpjpe(output_3d_dict['ltr_before'], targets_3d).item())
            acc_best.update(mpjpe(output_3d_dict['ltr_best'], targets_3d).item())
            acc_svd.update(mpjpe(output_3d_dict['ltr_svd'], targets_3d).item())

        acc_str = "before: {:.5f}, best: {:.5f}, " \
                  "svd: {:.5f}".format(acc_pre.avg, acc_best.avg, acc_svd.avg)
        print("Pre validation: ")
        print(acc_str)