Exemplo n.º 1
0
def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch over ``train_loader``.

    Args:
        train_loader: iterable yielding ``(input, target)`` batches.
        model: network to optimise; switched to train mode here.
        criterion: loss called as ``criterion(output, target)``.
        optimizer: optimiser stepped once per batch.
        epoch: epoch index, used only in the progress message.

    Returns:
        ``(loss_avg, top1_avg, top5_avg)``: batch-size-weighted averages of
        the loss and top-1 / top-5 precision over the epoch.
    """
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    for i, (inputs, target) in enumerate(train_loader):

        inputs, target = inputs.to(device), target.to(device)
        batch_size = inputs.size(0)

        model.zero_grad()

        # compute output and loss
        output = model(inputs)
        loss = criterion(output, target)

        # measure accuracy and record loss; both tensors are moved to the CPU
        # so ``accuracy`` never compares tensors on different devices
        prec1, prec5 = accuracy(output.detach().cpu(), target.detach().cpu(), topk=(1, 5))
        losses.update(loss.item(), batch_size)
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % config["print_freq"] == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader), loss=losses, top1=top1, top5=top5))
    return losses.avg, top1.avg, top5.avg
Exemplo n.º 2
0
def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch and report timing, loss and precision.

    Args:
        train_loader: iterable yielding ``(input, target)`` batches
            (moved to the GPU with ``.cuda()`` here).
        model: network to optimise; switched to train mode here.
        criterion: loss called as ``criterion(output, target)``.
        optimizer: optimiser stepped once per batch.
        epoch: epoch index, used only in the progress message.

    Returns:
        ``(loss_avg, top1_avg, top5_avg)``: batch-size-weighted averages
        over the epoch.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (inputs, target) in enumerate(train_loader):

        # measure data loading time
        data_time.update(time.time() - end)

        # ``async`` is a reserved keyword since Python 3.7 (the old spelling
        # is a SyntaxError); ``non_blocking`` is its replacement. Tensors
        # carry autograd state themselves, so no ``Variable`` wrapper.
        input_var = inputs.cuda()
        target_var = target.cuda(non_blocking=True)
        batch_size = inputs.size(0)

        model.zero_grad()

        # compute output and loss
        output = model(input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss. ``.item()`` replaces the removed
        # ``tensor.data[0]`` / ``tensor[0]`` scalar extraction; the output is
        # brought to the CPU so it lives on the same device as ``target``.
        prec1, prec5 = accuracy(output.detach().cpu(), target, topk=(1, 5))
        losses.update(loss.item(), batch_size)
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % config["print_freq"] == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch,
                      i,
                      len(train_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      top1=top1,
                      top5=top5))
    return losses.avg, top1.avg, top5.avg
Exemplo n.º 3
0
def validate(val_loader, model, criterion, class_to_idx=None):
    """Evaluate ``model`` on ``val_loader`` and optionally save predictions.

    Args:
        val_loader: iterable yielding ``(input, target)`` batches.
        model: network evaluated in eval mode with autograd disabled.
        criterion: loss called as ``criterion(output, target)``.
        class_to_idx: forwarded to ``save_results`` when ``args.eval_only``.

    Returns:
        ``(loss_avg, top1_avg, top5_avg)``: batch-size-weighted averages.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    logits_matrix = []
    targets_list = []

    end = time.time()
    # ``volatile=True`` Variables no longer disable autograd in modern
    # PyTorch (the flag is ignored); ``no_grad`` is the supported way.
    with torch.no_grad():
        for i, (inputs, target) in enumerate(val_loader):

            input_var = inputs.cuda()
            target_var = target.cuda(non_blocking=True)
            batch_size = inputs.size(0)

            # compute output and loss
            output = model(input_var)
            loss = criterion(output, target_var)

            if args.eval_only:
                logits_matrix.append(output.cpu().numpy())
                targets_list.append(target.cpu().numpy())

            # measure accuracy and record loss (``.item()`` replaces the
            # removed ``tensor.data[0]`` scalar indexing)
            prec1, prec5 = accuracy(output.detach().cpu(), target, topk=(1, 5))
            losses.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)
            top5.update(prec5.item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % config["print_freq"] == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i,
                          len(val_loader),
                          batch_time=batch_time,
                          loss=losses,
                          top1=top1,
                          top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(top1=top1,
                                                                  top5=top5))

    if args.eval_only:
        logits_matrix = np.concatenate(logits_matrix)
        targets_list = np.concatenate(targets_list)
        # This loader yields no per-item ids, so only logits/targets are
        # reported and saved.
        print(logits_matrix.shape, targets_list.shape)
        save_results(logits_matrix, targets_list, class_to_idx, config)
    return losses.avg, top1.avg, top5.avg
Exemplo n.º 4
0
def train(train_loader, model, criterion, optimizer, epoch):
    """Train the two-target multi-head model for one epoch.

    Each label is split into ``target // 8`` and ``target % 8``. ``output``
    is indexed so that entries 0-49 are scored against the first sub-target
    and entries 50-99 against the second. Assumes ``model`` returns a
    sequence of at least 100 per-head logits (TODO confirm); precision is
    measured on ``output[49]`` and ``output[-1]``.

    Args:
        train_loader: iterable yielding ``(input, target)`` batches.
        model: network called with a list of clip tensors.
        criterion: per-head loss ``criterion(logits, target)``.
        optimizer: optimiser stepped once per batch.
        epoch: epoch index, used only in the progress message.

    Returns:
        ``(loss_avg, top1_avg, top5_avg)``: batch-size-weighted averages.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (inputs, target) in enumerate(train_loader):

        # measure data loading time
        data_time.update(time.time() - end)

        if config['nclips_train'] > 1:
            # split along dim 2 into clips of ``clip_size`` and move each one
            input_var = [clip.to(device) for clip in inputs.split(config['clip_size'], 2)]
        else:
            input_var = [inputs.to(device)]

        target1 = (target // 8).to(device)
        target2 = (target % 8).to(device)

        model.zero_grad()

        # compute output and loss
        output = model(input_var)

        # Comprehension variables do not leak into the function scope, so the
        # batch index ``i`` used for ``print_freq`` below is preserved.
        loss_list = [criterion(output[head], target1) for head in range(50)]
        loss_list += [criterion(output[50 + head], target2) for head in range(50)]
        loss = torch.stack(loss_list).sum()

        # measure accuracy and record loss
        prec1_1, prec5_1 = accuracy(output[49].detach().cpu(), target1.detach().cpu(), topk=(1, 5))
        prec1_2, prec5_2 = accuracy(output[-1].detach().cpu(), target2.detach().cpu(), topk=(1, 5))
        prec1 = (prec1_1 + prec1_2) / 2.0
        prec5 = (prec5_1 + prec5_2) / 2.0
        batch_size = inputs.size(0)
        losses.update(loss.item(), batch_size)
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % config["print_freq"] == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))
    return losses.avg, top1.avg, top5.avg
Exemplo n.º 5
0
def validate(val_loader, model, criterion, class_to_idx=None):
    """Evaluate the two-target multi-head model on ``val_loader``.

    Labels are split into ``target // 8`` and ``target % 8``. ``output``
    entries 0-49 are scored against the first sub-target and 50-99 against
    the second. Assumes ``model`` returns at least 100 per-head logits
    (TODO confirm).

    Args:
        val_loader: iterable yielding ``(input, target, item_id)`` batches.
        model: network evaluated under ``torch.no_grad()``.
        criterion: per-head loss ``criterion(logits, target)``.
        class_to_idx: accepted for interface compatibility; unused here.

    Returns:
        ``(loss_avg, top1_avg, top5_avg)``: batch-size-weighted averages.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (inputs, target, _item_id) in enumerate(val_loader):

            if config['nclips_val'] > 1:
                input_var = [clip.to(device) for clip in inputs.split(config['clip_size'], 2)]
            else:
                input_var = [inputs.to(device)]

            target1 = (target // 8).to(device)
            target2 = (target % 8).to(device)

            # compute output and loss
            output = model(input_var)

            # Comprehension variables do not leak into the function scope, so
            # the batch index ``i`` used for ``print_freq`` is preserved.
            loss_list = [criterion(output[head], target1) for head in range(50)]
            loss_list += [criterion(output[50 + head], target2) for head in range(50)]
            loss = torch.stack(loss_list).sum()

            # measure accuracy and record loss
            prec1_1, prec5_1 = accuracy(output[49].detach().cpu(), target1.detach().cpu(), topk=(1, 5))
            prec1_2, prec5_2 = accuracy(output[-1].detach().cpu(), target2.detach().cpu(), topk=(1, 5))
            prec1 = (prec1_1 + prec1_2) / 2.0
            prec5 = (prec5_1 + prec5_2) / 2.0
            batch_size = inputs.size(0)
            losses.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)
            top5.update(prec5.item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % config["print_freq"] == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time, loss=losses,
                          top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return losses.avg, top1.avg, top5.avg
Exemplo n.º 6
0
def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch, logging per-batch metrics to Weights & Biases.

    Args:
        train_loader: iterable yielding ``(input, target, item_id)`` batches.
        model: network called with a list of clip tensors.
        criterion: loss called as ``criterion(output, target)``.
        optimizer: optimiser stepped once per batch.
        epoch: epoch index, used only in the progress message.

    Returns:
        ``(loss_avg, top1_avg, top5_avg)``: batch-size-weighted averages.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (inputs, target, _item_id) in enumerate(train_loader):

        # measure data loading time
        data_time.update(time.time() - end)

        if config['nclips_train_val'] > 1:
            # split along dim 2 into clips of ``clip_size`` and move each one
            input_var = [clip.to(device) for clip in inputs.split(config['clip_size'], 2)]
        else:
            input_var = [inputs.to(device)]

        target = target.to(device)

        model.zero_grad()

        # compute output and loss
        output = model(input_var)
        loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = utils.accuracy(output.detach().cpu(),
                                      target.detach().cpu(),
                                      topk=(1, 5))
        batch_size = inputs.size(0)
        losses.update(loss.item(), batch_size)
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)

        # A single ``wandb.log`` call per batch: each call advances wandb's
        # step counter by default, so separate calls would spread one batch's
        # metrics over five different steps.
        wandb.log({
            'train_loss': loss.item(),
            'train_prec1': prec1.item(),
            'train_prec5': prec5.item(),
            'train_top1_avg': top1.avg,
            'train_top5_avg': top5.avg,
        })

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % config["print_freq"] == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch,
                      i,
                      len(train_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      top1=top1,
                      top5=top5))
    return losses.avg, top1.avg, top5.avg
Exemplo n.º 7
0
def validate(val_loader, model, criterion, class_to_idx=None):
    """Evaluate ``model`` on ``val_loader`` behind a tqdm progress bar.

    When ``args.eval_only`` is set, the logits, features, targets and item
    ids of every batch are gathered, then written through ``save_results``
    and ``get_submission``.

    Args:
        val_loader: iterable yielding ``(input, target, item_id)`` batches.
        model: called as ``model(clips, config['save_features'])`` and
            expected to return ``(output, features)``.
        criterion: loss called as ``criterion(output, target)``.
        class_to_idx: forwarded to ``save_results`` / ``get_submission``.

    Returns:
        ``(loss_avg, top1_avg, top5_avg)``: batch-size-weighted averages.
    """
    batch_time, losses = AverageMeter(), AverageMeter()
    top1, top5 = AverageMeter(), AverageMeter()

    model.eval()

    logit_chunks, feature_chunks, target_chunks, id_chunks = [], [], [], []

    tick = time.time()
    with torch.no_grad():
        progress = tqdm(val_loader)
        for step, (clip, label, item_id) in enumerate(progress):

            # multi-clip inputs are split along dim 2, one tensor per clip
            if config['nclips_val'] > 1:
                clips = [part.to(device) for part in clip.split(config['clip_size'], 2)]
            else:
                clips = [clip.to(device)]

            label = label.to(device)

            output, features = model(clips, config['save_features'])
            loss = criterion(output, label)

            if args.eval_only:
                logit_chunks.append(output.cpu().data.numpy())
                feature_chunks.append(features.cpu().data.numpy())
                target_chunks.append(label.cpu().numpy())
                id_chunks.append(item_id)

            prec1, prec5 = accuracy(output.detach().cpu(),
                                    label.detach().cpu(),
                                    topk=(1, 5))
            count = clip.size(0)
            losses.update(loss.item(), count)
            top1.update(prec1.item(), count)
            top5.update(prec5.item(), count)

            batch_time.update(time.time() - tick)
            tick = time.time()

            if step % config["print_freq"] != 0:
                continue
            progress.set_description(
                'Test: [{0}/{1}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                    step,
                    len(val_loader),
                    batch_time=batch_time,
                    loss=losses,
                    top1=top1,
                    top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(
        top1=top1, top5=top5))
    sys.stdout.flush()

    if args.eval_only:
        all_logits = np.concatenate(logit_chunks)
        all_features = np.concatenate(feature_chunks)
        all_targets = np.concatenate(target_chunks)
        all_ids = np.concatenate(id_chunks)
        print(all_logits.shape, all_targets.shape, all_ids.shape)
        save_results(all_logits, all_features, all_targets,
                     all_ids, class_to_idx, config)
        get_submission(all_logits, all_ids, class_to_idx, config)
    return losses.avg, top1.avg, top5.avg
Exemplo n.º 8
0
def train(train_loader, model, optimizer, epoch):
    """Train the sequence model for one epoch with a positive/negative loss.

    Args:
        train_loader: iterable yielding ``(seq, pos, neg)`` batches.
        model: called as ``model(seq, pos, neg)`` and returning three
            embeddings.
        optimizer: optimiser stepped once per batch.
        epoch: epoch index, used only in the progress message.

    Returns:
        The batch-size-weighted average training loss.
    """
    batch_time, data_time, losses = AverageMeter(), AverageMeter(), AverageMeter()

    model.train()

    tick = time.time()
    for step, (seq, pos, neg) in enumerate(train_loader):
        data_time.update(time.time() - tick)

        model.zero_grad()

        # the loss is explicitly moved onto ``device`` before backprop
        seq_emb, pos_emb, neg_emb = model(seq, pos, neg)
        loss = multiple_binary_cross_entropy(
            seq_emb, pos_emb, pos, neg_emb).to(device)
        losses.update(loss.item(), seq.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % print_freq != 0:
            continue
        message = ('Epoch: [{0}][{1}/{2}]\t'
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                   'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t').format(
                       epoch, step, len(train_loader),
                       batch_time=batch_time, data_time=data_time, loss=losses)
        print(message)

    return losses.avg
Exemplo n.º 9
0
def evaluate(data_eval, model, epoch, eval):
    """Compute HitRate@1, HitRate@10 and nDCG@10 over ``data_eval``.

    For each batch, sequence embeddings are scored against the candidate
    item embeddings. The scores are reshaped to ``(batch, seq_len, 101)``,
    and the negated last-position scores of the first sequence in the batch
    are passed to ``accuracy``. Each batch is weighted as a single sample.

    Args:
        data_eval: iterable yielding ``(seq, item_idx)`` batches.
        model: called as ``model(seq, item_idx, predict=True)`` and
            returning ``(seq_emb, test_emb)``.
        epoch: epoch index, used only in the printed summary.
        eval: label for the printed summary (e.g. a split name).

    Returns:
        ``(hit_at_1_avg, hit_at_10_avg, ndcg_at_10_avg)``.
    """
    top1, top10, nDCG10 = AverageMeter(), AverageMeter(), AverageMeter()

    model.eval()

    with torch.no_grad():
        for seq, item_idx in data_eval:
            seq_emb, test_emb = model(seq, item_idx, predict=True)
            scores = torch.matmul(seq_emb, test_emb.t())
            scores = scores.view(seq.size()[0], seq.size()[1], 101)
            # only the last position of the first sequence is ranked;
            # negation flips the score order expected by ``accuracy``
            test_logits = -scores[:, -1, :][0]

            prec1, prec10, nDCG = accuracy(test_logits)
            for meter, value in zip((top1, top10, nDCG10), (prec1, prec10, nDCG)):
                meter.update(value, 1)

    summary = ('-- {eval} Results Epoch [{epoch}] -- \t'
               '* HitRate@1 {top1.avg:.3f} - HitRate@10 {top10.avg:.3f} * nDCG@10 {nDCG10.avg:.3f}')
    print(summary.format(eval=eval, epoch=epoch, top1=top1, top10=top10,
                         nDCG10=nDCG10))
    return top1.avg, top10.avg, nDCG10.avg
Exemplo n.º 10
0
def train(train_loader, model, optimizer, epoch, criterion, tb_logger=None):
    """Train the box/object video model for one epoch.

    Progress is printed every ``args.print_freq`` iterations. When
    ``tb_logger`` is given, iteration-level loss/accuracy scalars are logged
    every ``args.log_freq`` iterations. Gradients are clipped to
    ``args.clip_gradient`` when that is not ``None``.

    Args:
        train_loader: iterable yielding ``(global_img_tensors, box_tensors,
            box_categories, video_label)``; its ``dataset.classes`` sets the
            logit width.
        model: called as ``model(global_imgs, box_categories, box_tensors,
            video_label)``.
        optimizer: optimiser stepped once per batch.
        epoch: epoch index, used for messages and the logged iteration count.
        criterion: loss called as ``criterion(output, labels)``.
        tb_logger: optional logger exposing ``log_scalar`` and ``flush``.

    Returns:
        None.
    """
    global args
    batch_time, data_time = AverageMeter(), AverageMeter()
    losses = AverageMeter()
    acc_top1, acc_top5 = AverageMeter(), AverageMeter()

    model.train()

    tick = time.time()
    for step, batch in enumerate(train_loader):
        global_imgs, boxes, box_cats, labels = batch
        model.zero_grad()
        data_time.update(time.time() - tick)

        # global_imgs: presumably (b, nr_frames, 3, h, w) -- TODO confirm
        logits = model(global_imgs, box_cats, boxes, labels)
        logits = logits.view((-1, len(train_loader.dataset.classes)))
        loss = criterion(logits, labels.long().cuda())

        acc1, acc5 = accuracy(logits.cpu(), labels, topk=(1, 5))

        count = global_imgs.size(0)
        losses.update(loss.item(), count)
        acc_top1.update(acc1.item(), count)
        acc_top5.update(acc5.item(), count)

        optimizer.zero_grad()
        loss.backward()
        if args.clip_gradient is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_gradient)
        optimizer.step()

        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            template = ('Epoch: [{0}][{1}/{2}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Acc1 {acc_top1.val:.1f} ({acc_top1.avg:.1f})\t'
                        'Acc5 {acc_top5.val:.1f} ({acc_top5.avg:.1f})')
            print(template.format(epoch, step, len(train_loader),
                                  batch_time=batch_time,
                                  data_time=data_time,
                                  loss=losses,
                                  acc_top1=acc_top1,
                                  acc_top5=acc_top5))

        if tb_logger is None or step % args.log_freq != 0:
            continue
        scalars = OrderedDict([
            ('Train_IterLoss', losses.val),
            ('Train_Acc@1', acc_top1.val),
            ('Train_Acc@5', acc_top5.val),
        ])
        # global iteration index across epochs
        iter_count = epoch * len(train_loader) + step
        for name, value in scalars.items():
            tb_logger.log_scalar(value, name, iter_count)
        tb_logger.flush()
Exemplo n.º 11
0
def validate(val_loader, model, criterion, epoch=None, tb_logger=None,
             class_to_idx=None):
    """Evaluate the box/object video model on ``val_loader``.

    When ``args.evaluate`` is set, logits and labels are gathered and passed
    to ``save_results``. When both ``epoch`` and ``tb_logger`` are given,
    epoch-level scalars are logged at step ``epoch + 1``.

    Args:
        val_loader: iterable yielding ``(global_img_tensors, box_tensors,
            box_categories, video_label)``; its ``dataset.classes`` sets the
            logit width.
        model: called with ``is_inference=True`` under ``torch.no_grad()``.
        criterion: loss called as ``criterion(output, labels)``.
        epoch: optional epoch index for tensorboard logging.
        tb_logger: optional logger exposing ``log_scalar`` and ``flush``.
        class_to_idx: forwarded to ``save_results``.

    Returns:
        The batch-size-weighted average validation loss.
    """
    batch_time, losses = AverageMeter(), AverageMeter()
    acc_top1, acc_top5 = AverageMeter(), AverageMeter()
    logit_chunks, label_chunks = [], []

    model.eval()

    tick = time.time()
    for step, batch in enumerate(val_loader):
        global_imgs, boxes, box_cats, labels = batch
        with torch.no_grad():
            logits = model(global_imgs, box_cats, boxes, labels,
                           is_inference=True)
            logits = logits.view((-1, len(val_loader.dataset.classes)))
            loss = criterion(logits, labels.long().cuda())

            acc1, acc5 = accuracy(logits.cpu(), labels, topk=(1, 5))
            if args.evaluate:
                logit_chunks.append(logits.cpu().data.numpy())
                label_chunks.append(labels.cpu().numpy())

        count = global_imgs.size(0)
        losses.update(loss.item(), count)
        acc_top1.update(acc1.item(), count)
        acc_top5.update(acc5.item(), count)

        batch_time.update(time.time() - tick)
        tick = time.time()

        # also report the final batch so the last averages are always shown
        if step % args.print_freq == 0 or step + 1 == len(val_loader):
            template = ('Test: [{0}/{1}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Acc1 {acc_top1.val:.1f} ({acc_top1.avg:.1f})\t'
                        'Acc5 {acc_top5.val:.1f} ({acc_top5.avg:.1f})')
            print(template.format(step, len(val_loader),
                                  batch_time=batch_time,
                                  loss=losses,
                                  acc_top1=acc_top1,
                                  acc_top5=acc_top5))

    if args.evaluate:
        save_results(np.concatenate(logit_chunks),
                     np.concatenate(label_chunks),
                     class_to_idx, args)

    if epoch is not None and tb_logger is not None:
        scalars = OrderedDict([
            ('Val_EpochLoss', losses.avg),
            ('Val_EpochAcc@1', acc_top1.avg),
            ('Val_EpochAcc@5', acc_top5.avg),
        ])
        for name, value in scalars.items():
            tb_logger.log_scalar(value, name, epoch + 1)
        tb_logger.flush()

    return losses.avg
def validate_ensemble(val_loader,
                      classifier,
                      model0,
                      model1,
                      model2,
                      model3,
                      model4,
                      list_id_files,
                      criterion,
                      class_to_idx=None):
    """Evaluate a stacked ensemble: five base models feed one classifier.

    Each base model is called as ``output, feature = model(input)``. The five
    outputs are concatenated along dim 1 and passed to ``classifier``. When
    ``args.eval_only`` is set, predictions are also written to
    ``<output_dir>/<model_name>/test_results.csv`` (``;``-delimited) and
    passed to ``save_results``.

    Args:
        val_loader: iterable yielding ``(input, target)`` batches.
        classifier: module mapping the concatenated base outputs to logits.
        model0..model4: base models, each returning ``(output, feature)``.
        list_id_files: per-sample ids written to the CSV. Presumably in the
            same order as ``val_loader`` yields samples -- TODO confirm.
        criterion: loss called as ``criterion(logits, target)``.
        class_to_idx: indexed by predicted label id when writing the CSV.
            NOTE(review): despite the name it is used as id -> class here;
            verify against the caller.

    Returns:
        ``(loss_avg, top1_avg, top5_avg)``: batch-size-weighted averages.
    """
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    base_models = (model0, model1, model2, model3, model4)

    # switch to evaluate mode
    for base_model in base_models:
        base_model.eval()
    classifier.eval()

    logits_matrix = []
    targets_list = []
    label_list = []
    correct = 0
    total = 0
    with torch.no_grad():
        for i, (inputs, target) in enumerate(val_loader):

            inputs, target = inputs.to(device), target.to(device)

            # only the base outputs are stacked; features are discarded
            base_outputs = []
            for base_model in base_models:
                base_output, _feature = base_model(inputs)
                base_outputs.append(base_output)
            sav = torch.cat(base_outputs, 1)
            class_video = classifier(sav)
            loss = criterion(class_video, target)
            if args.eval_only:
                logits_matrix.append(class_video.detach().cpu().numpy())
                targets_list.append(target.detach().cpu().numpy())
                _, predicted = torch.max(class_video.data, 1)
                label_list.append(predicted.detach().cpu().numpy())
                total += target.size(0)
                # accumulate a Python int rather than a device tensor
                correct += (predicted == target).sum().item()

            # measure accuracy and record loss; both tensors on the CPU so
            # ``accuracy`` never compares tensors on different devices
            prec1, prec5 = accuracy(class_video.detach().cpu(),
                                    target.detach().cpu(),
                                    topk=(1, 5))
            batch_size = inputs.size(0)
            losses.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)
            top5.update(prec5.item(), batch_size)

            if i % config["print_freq"] == 0:
                print('Test: [{0}/{1}]\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i,
                          len(val_loader),
                          loss=losses,
                          top1=top1,
                          top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(
        top1=top1, top5=top5))

    if args.eval_only:
        logits_matrix = np.concatenate(logits_matrix)
        targets_list = np.concatenate(targets_list)
        label_list = np.concatenate(label_list)
        print('Accuracy of the model: %d %%' % (100 * correct / total))
        print('Accuracy2 of the model: %d %%' %
              (100 * ((label_list == targets_list).sum()) / total))

        path_to_save2 = os.path.join(config['output_dir'],
                                     config['model_name'],
                                     "test_results.csv")
        # The csv module requires newline='' so that rows are not separated
        # by blank lines on platforms that translate '\n'.
        with open(path_to_save2, mode='w', newline='') as csv_file:
            my_csv_writer = csv.writer(csv_file,
                                       delimiter=';',
                                       quotechar='"',
                                       quoting=csv.QUOTE_MINIMAL)
            for idx, file_id in enumerate(list_id_files):
                my_csv_writer.writerow(
                    [file_id, class_to_idx[label_list[idx]]])

        print(logits_matrix.shape, targets_list.shape)
        print(class_to_idx)
        save_results(logits_matrix, targets_list, class_to_idx, config)

    return losses.avg, top1.avg, top5.avg
def _load_ensemble_checkpoint(model, path):
    """Load ``checkpoint['state_dict']`` from ``path`` into ``model`` if present.

    Args:
        model: module whose weights are replaced; its state-dict keys must
            match the checkpoint's (e.g. both wrapped in DataParallel).
        path: checkpoint file; a dict with at least ``'state_dict'`` and
            ``'epoch'`` keys is expected.

    Returns:
        The loaded checkpoint dict, or ``None`` when ``path`` is not a file.
    """
    if not os.path.isfile(path):
        print("=> no checkpoint found at '{}'".format(path))
        return None
    print("=> loading checkpoint '{}'".format(path))
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})".format(path,
                                                        checkpoint['epoch']))
    return checkpoint


def trainEnsemble():
    """Train a ``Classifier`` on the concatenated outputs of five frozen models.

    The base models (ConvColumn6/7/8/9/5) and the classifier are restored from
    hard-coded checkpoint paths under ``trainings/jpeg_model/`` when those
    files exist. With ``args.test_only`` or ``args.eval_only`` the matching
    ensemble routine runs once and the function returns. Otherwise only the
    classifier is trained, for ``config["num_epochs"]`` epochs (``-1`` means
    effectively unlimited). Each epoch is followed by validation, tensorboard
    logging to ``logs`` and ``save_checkpoint``.

    Reads and updates the module-level globals ``args`` and ``best_prec1``.
    """
    global args, best_prec1

    # set run output folder
    model_name = "classifier"
    output_dir = config["output_dir"]

    save_dir = os.path.join(output_dir, model_name)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        os.makedirs(os.path.join(save_dir, 'plots'))

    def signal_handler(signum, frame):
        """Ctrl+C handler: remove ``save_dir`` if it has no entries, then exit.

        Note that the ``plots`` sub-folder created above counts as an entry.
        """
        num_files = len(glob.glob(save_dir + "/*"))
        if num_files < 1:
            shutil.rmtree(save_dir)
        print('You pressed Ctrl+C!')
        sys.exit(0)

    # assign Ctrl+C signal handler
    signal.signal(signal.SIGINT, signal_handler)

    num_classes = config['num_classes']

    # Base models: only their forward outputs are used, they are not trained.
    model0 = torch.nn.DataParallel(ConvColumn6(num_classes),
                                   device_ids=gpus).to(device)
    _load_ensemble_checkpoint(
        model0, "trainings/jpeg_model/jester_conv6/checkpoint.pth.tar")

    model1 = torch.nn.DataParallel(ConvColumn7(num_classes),
                                   device_ids=gpus).to(device)
    _load_ensemble_checkpoint(
        model1, "trainings/jpeg_model/jester_conv7/model_best.pth.tar")

    classifier = torch.nn.DataParallel(Classifier(num_classes),
                                       device_ids=gpus).to(device)
    classifier_ckpt = _load_ensemble_checkpoint(
        classifier, "trainings/jpeg_model/classifier/model_best.pth.tar")
    if classifier_ckpt is not None:
        # Only the classifier is trained here, so its checkpoint (not a base
        # model's) defines the resume epoch and the best accuracy to beat.
        args.start_epoch = classifier_ckpt['epoch']
        best_prec1 = classifier_ckpt['best_prec1']

    model3 = torch.nn.DataParallel(ConvColumn9(num_classes),
                                   device_ids=gpus).to(device)
    _load_ensemble_checkpoint(
        model3, "trainings/jpeg_model/jester_conv9/model_best.pth.tar")

    model2 = torch.nn.DataParallel(ConvColumn8(num_classes),
                                   device_ids=gpus).to(device)
    _load_ensemble_checkpoint(
        model2, "trainings/jpeg_model/jester_conv8/model_best.pth.tar")

    model4 = torch.nn.DataParallel(ConvColumn5(num_classes),
                                   device_ids=gpus).to(device)
    _load_ensemble_checkpoint(
        model4, "trainings/jpeg_model/ConvColumn5/model_best.pth.tar")

    transform_train = Compose([
        RandomAffine(degrees=[-10, 10],
                     translate=[0.15, 0.15],
                     scale=[0.9, 1.1],
                     shear=[-5, 5]),
        CenterCrop(84),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    transform_valid = Compose([
        CenterCrop(84),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    train_data = VideoFolder(
        root=config['train_data_folder'],
        csv_file_input=config['train_data_csv'],
        csv_file_labels=config['labels_csv'],
        clip_size=config['clip_size'],
        nclips=1,
        step_size=config['step_size'],
        is_val=False,
        transform=transform_train,
    )

    print(" > Using {} processes for data loader.".format(
        config["num_workers"]))
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True,
        drop_last=True)

    val_data = VideoFolder(
        root=config['val_data_folder'],
        csv_file_input=config['val_data_csv'],
        csv_file_labels=config['labels_csv'],
        clip_size=config['clip_size'],
        nclips=1,
        step_size=config['step_size'],
        is_val=True,
        transform=transform_valid,
    )

    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=config['batch_size'],
                                             shuffle=False,
                                             num_workers=config['num_workers'],
                                             pin_memory=True,
                                             drop_last=False)

    # Video ids: each csv entry's path with its first 16 characters removed.
    list_id_files = [item.path[16:] for item in val_data.csv_data]
    print(len(list_id_files))

    assert len(train_data.classes) == config["num_classes"]

    # define loss function (criterion) and optimizer (classifier only)
    criterion = nn.CrossEntropyLoss().to(device)
    lr = config["lr"]
    last_lr = config["last_lr"]
    optimizer = torch.optim.Adam(classifier.parameters(), lr=lr, amsgrad=True)

    # set callbacks
    # NOTE(review): plotter is constructed but never fed in this function.
    plotter = PlotLearning(os.path.join(save_dir, "plots"),
                           config["num_classes"])
    lr_decayer = MonitorLRDecay(0.6, 3)
    val_loss = 9999999

    # set end condition by num epochs
    num_epochs = int(config["num_epochs"])
    if num_epochs == -1:
        num_epochs = 999999

    if args.test_only:
        print("test")
        test_data = VideoFolder_test(
            root=config['val_data_folder'],
            csv_file_input=config['test_data_csv'],
            clip_size=config['clip_size'],
            nclips=1,
            step_size=config['step_size'],
            is_val=True,
            transform=transform_valid,
        )

        test_loader = torch.utils.data.DataLoader(
            test_data,
            batch_size=config['batch_size'],
            shuffle=False,
            num_workers=config['num_workers'],
            pin_memory=True,
            drop_last=False)

        list_id_files_test = [item.path[16:] for item in test_data.csv_data]
        print(len(list_id_files_test))
        # NOTE(review): only classifier + model1..model3 are passed here while
        # training feeds five base models; confirm against test_ensemble.
        test_ensemble(test_loader, classifier, model1, model2, model3,
                      list_id_files_test, criterion, train_data.classes_dict)
        return

    if args.eval_only:
        # NOTE(review): argument list differs from the per-epoch
        # validate_ensemble call below; confirm which signature is current.
        val_loss, val_top1, val_top5 = validate_ensemble(
            val_loader, classifier, model1, model2, model3, list_id_files,
            criterion, train_data.classes_dict)
        return

    base_models = (model0, model1, model2, model3, model4)
    train_writer = tensorboardX.SummaryWriter("logs")
    try:
        for epoch in range(0, num_epochs):
            # Base models stay in eval mode; only the classifier trains.
            # Modes are re-applied every epoch because validation runs between
            # epochs and may leave the classifier in eval mode.
            for base_model in base_models:
                base_model.eval()
            classifier.train()

            losses = AverageMeter()
            top1 = AverageMeter()
            top5 = AverageMeter()
            lr = lr_decayer(val_loss, lr)
            # Apply the monitored LR; otherwise the decay is printed but the
            # optimizer keeps training with its initial rate.
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print(" > Current LR : {}".format(lr))

            if lr < last_lr and last_lr > 0:
                print(" > Training is done by reaching the last learning rate {}".
                      format(last_lr))
                sys.exit(1)

            for i, (input, target) in enumerate(train_loader):
                input, target = input.to(device), target.to(device)

                # Frozen base models: no autograd graph. Their first outputs
                # are concatenated along dim 1 as the classifier's input.
                with torch.no_grad():
                    sav = torch.cat([m(input)[0] for m in base_models], 1)

                class_video = classifier(sav)
                loss = criterion(class_video, target)

                # measure accuracy and record loss
                prec1, prec5 = accuracy(class_video.detach(),
                                        target.detach().cpu(),
                                        topk=(1, 5))
                losses.update(loss.item(), input.size(0))
                top1.update(prec1.item(), input.size(0))
                top5.update(prec5.item(), input.size(0))

                # compute gradient and do an optimizer step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if i % config["print_freq"] == 0:
                    print('Epoch: [{0}][{1}/{2}]\t'
                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                          'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                          'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                              epoch,
                              i,
                              len(train_loader),
                              loss=losses,
                              top1=top1,
                              top5=top5))

            val_loss, val_top1, val_top5 = validate_ensemble(
                val_loader, classifier, model0, model1, model2, model3, model4,
                list_id_files, criterion)

            # add_scalar(tag, value, global_step)
            train_writer.add_scalar('loss', losses.avg, epoch + 1)
            train_writer.add_scalar('top1', top1.avg, epoch + 1)
            train_writer.add_scalar('top5', top5.avg, epoch + 1)

            train_writer.add_scalar('val_loss', val_loss, epoch + 1)
            train_writer.add_scalar('val_top1', val_top1, epoch + 1)
            train_writer.add_scalar('val_top5', val_top5, epoch + 1)

            # remember best prec@1 and save checkpoint
            is_best = val_top1 > best_prec1
            best_prec1 = max(val_top1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': "Classifier",
                    'state_dict': classifier.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best, config)
    finally:
        # Flush tensorboard events even when sys.exit() ends training early.
        train_writer.close()
Exemplo n.º 14
0
def predict(model, result_file_path=args.result_path):
    """Run ``model`` over the evaluation set and write per-video predictions.

    Weights are loaded from ``args.model_state_dict_path``. For every batch, a
    dict with video names, frame indices, labels, predicted classes and the
    model's ``batch_attention_weight`` is pickled and appended to a temporary
    ``<result_file_path without extension>.p`` file. That file is then passed
    to ``aggre_batch_result`` to produce ``result_file_path``.

    Args:
        model: network called as ``model(global_img_tensors, box_categories,
            box_tensors, word2vec_features, video_label)`` that exposes a
            ``batch_attention_weight`` attribute after each forward pass.
        result_file_path: destination of the aggregated result. The default
            is read from ``args`` when this function is defined.
    """
    acc_top1 = AverageMeter()
    acc_top5 = AverageMeter()

    # Batches are appended to this pickle, so start from a clean file.
    temp_pickle = os.path.splitext(result_file_path)[0] + '.p'
    if os.path.exists(temp_pickle):
        os.remove(temp_pickle)

    # Read model weights; the stored epoch is only used in progress messages.
    checkpoint = torch.load(args.model_state_dict_path)
    model.load_state_dict(checkpoint['state_dict'])
    epoch = checkpoint['epoch']

    model.eval()

    # NOTE(review): if_augment=True during prediction; confirm VideoFolder
    # disables augmentation when is_val=True.
    dataset = VideoFolder(
                        root=args.root_frames,
                        num_boxes=args.num_boxes,
                        file_input=args.json_data_list,
                        file_labels=args.json_file_labels,
                        word2vec_weights=args.word2vec_weights_path,
                        frames_duration=args.num_frames,
                        video_root=args.video_root,
                        args=args,
                        is_val=True,
                        if_augment=True,
                        model=args.model
    )

    # Prediction must cover every video exactly once in a reproducible order:
    # no shuffling, and keep the final partial batch.
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, drop_last=False,
        pin_memory=True
    )
    num_classes = len(data_loader.dataset.classes)

    with open(temp_pickle, 'ab+') as fp, torch.no_grad():
        for i, (vid_names, frame_list, global_img_tensors, box_tensors,
                box_categories, word2vec_features,
                video_label) in enumerate(data_loader):

            # Move data to GPU
            global_img_tensors = global_img_tensors.to(cuda_device)
            box_categories = box_categories.to(cuda_device)
            box_tensors = box_tensors.to(cuda_device)
            video_label = video_label.to(cuda_device)
            # word2vec_features is an empty-list placeholder when not provided
            if not isinstance(word2vec_features, list):
                word2vec_features = word2vec_features.to(cuda_device)

            output = model(global_img_tensors, box_categories, box_tensors,
                           word2vec_features, video_label)
            attention = model.batch_attention_weight
            print("Atten shape", attention.size())
            output = output.view((-1, num_classes))

            [acc1, acc5], pred = accuracy(output, video_label, topk=(1, 5),
                                          return_predict=True)
            acc_top1.update(acc1.item(), global_img_tensors.size(0))
            acc_top5.update(acc5.item(), global_img_tensors.size(0))

            pickle.dump({
                'vid_names': vid_names,
                # frame indices converted to int and then to a plain list
                'frame_list': frame_list.int().tolist(),
                'video_label': video_label.int().cpu().numpy(),
                'prediction': pred.int().cpu().numpy(),
                'attention': attention.cpu().numpy()
            }, fp)

            batch_result = 'Epoch: [{0}][{1}/{2}]\t' \
                    'Acc1 {acc_top1.val:.1f} ({acc_top1.avg:.1f})\t' \
                    'Acc5 {acc_top5.val:.1f} ({acc_top5.avg:.1f})'.format(
                    epoch, i, len(data_loader), acc_top1=acc_top1, acc_top5=acc_top5)

            print(batch_result)

    aggre_batch_result(temp_pickle, result_file_path=result_file_path)
Exemplo n.º 15
0
def validate(val_loader, model, criterion, class_to_idx=None):
    """Evaluate ``model`` on ``val_loader`` with gradient tracking disabled.

    Running loss / Prec@1 / Prec@5 are printed every ``config["print_freq"]``
    batches, followed by a one-line summary. When ``args.eval_only`` is set,
    the raw logits and targets of all batches are concatenated and handed to
    ``save_results`` together with ``class_to_idx``.

    Returns:
        tuple: (average loss, average Prec@1, average Prec@5).
    """
    loss_meter = AverageMeter()
    prec1_meter = AverageMeter()
    prec5_meter = AverageMeter()

    model.eval()

    keep_raw = args.eval_only
    batch_logits = []
    batch_targets = []

    with torch.no_grad():
        for step, (clips, labels) in enumerate(val_loader):
            clips = clips.to(device)
            labels = labels.to(device)

            logits = model(clips)
            batch_loss = criterion(logits, labels)

            if keep_raw:
                batch_logits.append(logits.detach().cpu().numpy())
                batch_targets.append(labels.detach().cpu().numpy())

            prec1, prec5 = accuracy(logits.detach(), labels.detach().cpu(),
                                    topk=(1, 5))
            count = clips.size(0)
            loss_meter.update(batch_loss.item(), count)
            prec1_meter.update(prec1.item(), count)
            prec5_meter.update(prec5.item(), count)

            if step % config["print_freq"] != 0:
                continue
            report = ('Test: [{0}/{1}]\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})')
            print(report.format(step, len(val_loader), loss=loss_meter,
                                top1=prec1_meter, top5=prec5_meter))

        summary = ' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
        print(summary.format(top1=prec1_meter, top5=prec5_meter))

        if keep_raw:
            all_logits = np.concatenate(batch_logits)
            all_targets = np.concatenate(batch_targets)
            print(all_logits.shape, all_targets.shape)
            save_results(all_logits, all_targets, class_to_idx, config)
        return loss_meter.avg, prec1_meter.avg, prec5_meter.avg
Exemplo n.º 16
0
def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch over ``train_loader`` behind a tqdm progress bar.

    When ``config['nclips_train'] > 1`` each input batch is split along dim 2
    into chunks of ``config['clip_size']``. Otherwise the whole batch is used
    as a single clip. Either way the model receives a list of device tensors.
    Timing, loss and Prec@1/Prec@5 are shown in the progress-bar description
    every ``config["print_freq"]`` batches.

    Returns:
        tuple: (average loss, average Prec@1, average Prec@5) for the epoch.
    """
    t_batch = AverageMeter()
    t_data = AverageMeter()
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    model.train()

    tick = time.time()
    progress = tqdm(train_loader)
    for step, (clip, label) in enumerate(progress):
        # time spent waiting for this batch
        t_data.update(time.time() - tick)

        if config['nclips_train'] > 1:
            clips = [chunk.to(device)
                     for chunk in clip.split(config['clip_size'], 2)]
        else:
            clips = [clip.to(device)]

        label = label.to(device)

        model.zero_grad()

        logits = model(clips)
        loss = criterion(logits, label)

        prec1, prec5 = accuracy(logits.detach().cpu(),
                                label.detach().cpu(),
                                topk=(1, 5))
        batch_size = clip.size(0)
        loss_meter.update(loss.item(), batch_size)
        top1_meter.update(prec1.item(), batch_size)
        top5_meter.update(prec5.item(), batch_size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # full iteration time, then restart the clock for the next batch
        t_batch.update(time.time() - tick)
        tick = time.time()

        if step % config["print_freq"] == 0:
            template = ('Epoch: [{0}][{1}/{2}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                        'Prec@5 {top5.val:.3f} ({top5.avg:.3f})')
            progress.set_description(
                template.format(epoch, step, len(train_loader),
                                batch_time=t_batch, data_time=t_data,
                                loss=loss_meter, top1=top1_meter,
                                top5=top5_meter))
    return loss_meter.avg, top1_meter.avg, top5_meter.avg
Exemplo n.º 17
0
        is_val=True,
        transform_pre=transform_eval_pre,
        transform_post=transform_post,
        get_item_id=True,
    )

    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=config['batch_size'],
                                             shuffle=False,
                                             num_workers=config['num_workers'],
                                             pin_memory=True,
                                             drop_last=False)

    criterion = nn.CrossEntropyLoss().to(device)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    logits_matrix = []
    features_matrix = []
    targets_list = []
    item_id_list = []

    end = time.time()

    with torch.no_grad():
Exemplo n.º 18
0
def train(train_loader, model, optimizer, epoch, criterion, tb_logger=None):
    """Train ``model`` for one epoch on object-box video batches.

    Each batch is ``(vid_names, frame_list, global_img_tensors, box_tensors,
    box_categories, word2vec_features, video_label)``. Tensors are moved to
    ``cuda_device``. ``word2vec_features`` is moved only when it is not the
    empty-list placeholder. Gradients are clipped to ``args.clip_gradient``
    unless that is ``None``. Progress is printed every ``args.print_freq``
    iterations. If ``tb_logger`` is given, the latest loss / Acc@1 / Acc@5
    are logged every ``args.log_freq`` iterations, keyed by global iteration.

    Returns:
        tuple: (average loss, average Acc@1, average Acc@5) for the epoch.
    """
    t_batch = AverageMeter()
    t_data = AverageMeter()
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    model.train()

    tick = time.time()
    for step, batch in enumerate(train_loader):
        (vid_names, frame_list, global_img_tensors, box_tensors,
         box_categories, word2vec_features, video_label) = batch

        model.zero_grad()
        t_data.update(time.time() - tick)

        global_img_tensors = global_img_tensors.to(cuda_device)
        box_categories = box_categories.to(cuda_device)
        box_tensors = box_tensors.to(cuda_device)
        video_label = video_label.to(cuda_device)

        # An empty list stands in for word2vec features when none are given.
        if not isinstance(word2vec_features, list):
            word2vec_features = word2vec_features.to(cuda_device)

        logits = model(global_img_tensors, box_categories, box_tensors,
                       word2vec_features, video_label)
        logits = logits.view((-1, len(train_loader.dataset.classes)))

        loss = criterion(logits, video_label)

        acc1, acc5 = accuracy(logits.cpu(), video_label.cpu(), topk=(1, 5))

        batch_size = global_img_tensors.size(0)
        loss_meter.update(loss.item(), batch_size)
        top1_meter.update(acc1.item(), batch_size)
        top5_meter.update(acc5.item(), batch_size)

        optimizer.zero_grad()
        loss.backward()
        if args.clip_gradient is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.clip_gradient)
        optimizer.step()

        t_batch.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            template = ('Epoch: [{0}][{1}/{2}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Acc1 {acc_top1.val:.1f} ({acc_top1.avg:.1f})\t'
                        'Acc5 {acc_top5.val:.1f} ({acc_top5.avg:.1f})')
            print(template.format(epoch, step, len(train_loader),
                                  batch_time=t_batch, data_time=t_data,
                                  loss=loss_meter, acc_top1=top1_meter,
                                  acc_top5=top5_meter))

        if tb_logger is not None and step % args.log_freq == 0:
            # global iteration index across epochs
            iter_count = epoch * len(train_loader) + step
            scalars = (('Train_IterLoss', loss_meter.val),
                       ('Train_Acc@1', top1_meter.val),
                       ('Train_Acc@5', top5_meter.val))
            for key, value in scalars:
                tb_logger.log_scalar(value, key, iter_count)

            tb_logger.flush()

    return loss_meter.avg, top1_meter.avg, top5_meter.avg