コード例 #1
0
ファイル: train.py プロジェクト: lee-zq/3DUNet-Pytorch
def train(model, train_loader, optimizer, loss_func, n_labels, alpha):
    """Run one training epoch for a model that returns four outputs per batch.

    ``output[3]`` provides the main loss; the losses on ``output[0]`` ..
    ``output[2]`` are added to it with weight ``alpha`` before back-propagation.
    Only the main loss and ``output[3]`` are fed to the running meters.

    ``epoch`` and ``device`` are not parameters; they are resolved from the
    enclosing scope.

    Returns:
        OrderedDict with ``Train_Loss`` and ``Train_dice_liver`` (index 1 of
        the Dice average) plus ``Train_dice_tumor`` (index 2) when
        ``n_labels == 3``.
    """
    current_lr = optimizer.state_dict()['param_groups'][0]['lr']
    print("=======Epoch:{}=======lr:{}".format(epoch, current_lr))

    model.train()
    train_loss = metrics.LossAverage()
    train_dice = metrics.DiceAverage(n_labels)

    progress = tqdm(enumerate(train_loader), total=len(train_loader))
    for _, (data, target) in progress:
        data = data.float()
        target = common.to_one_hot_3d(target.long(), n_labels)
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()

        outputs = model(data)
        # One loss per output head, evaluated in index order 0..3.
        aux0, aux1, aux2, main_loss = (loss_func(outputs[k], target)
                                       for k in range(4))

        total_loss = main_loss + alpha * (aux0 + aux1 + aux2)
        total_loss.backward()
        optimizer.step()

        # Meters track the main head only, not the weighted total.
        train_loss.update(main_loss.item(), data.size(0))
        train_dice.update(outputs[3], target)

    summary = OrderedDict()
    summary['Train_Loss'] = train_loss.avg
    summary['Train_dice_liver'] = train_dice.avg[1]
    if n_labels == 3:
        summary['Train_dice_tumor'] = train_dice.avg[2]
    return summary
コード例 #2
0
def train(model, optimizer, criterion, epoch, print_freq, data_loader,
          data_pth):
    """Optimise the model on fixed-size label batches for a fixed step count.

    Labels 1..40 are visited in consecutive groups of 8, and the whole sweep
    is repeated 40 times.  For each group, ``_get_input_samples`` is called
    with the model in eval mode; the model is then switched back to train
    mode before the loss is computed and back-propagated.

    ``data_loader`` is accepted but not used by this function.

    Returns:
        True if the most recent loss value dropped below ``1e-5`` at any
        step, otherwise False.
    """
    losses = AverageMeter()
    is_add_margin = False
    labels_per_batch = 8
    total_labels = 40
    sweeps = 40
    batches_per_sweep = int(total_labels / labels_per_batch)
    total_steps = sweeps * batches_per_sweep

    start = time.time()
    for sweep in range(sweeps):
        for batch in range(batches_per_sweep):
            first_label = batch * labels_per_batch + 1
            model.eval()
            label_ids = list(range(first_label,
                                   first_label + labels_per_batch))
            feat1, feat2, feat3 = _get_input_samples(model, label_ids,
                                                     data_pth)

            model.train()
            loss = criterion(torch.stack(feat1), torch.stack(feat2),
                             torch.stack(feat3))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.update(loss.item())

            # 1-based global step across all sweeps.
            step = sweep * batches_per_sweep + batch + 1
            if step % print_freq == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'Loss {:.6f} ({:.6f})\t'.format(epoch, step, total_steps,
                                                      losses.val, losses.mean))
            if losses.val < 1e-5:
                is_add_margin = True

    first_group = optimizer.param_groups[0]
    elapsed = time.time() - start
    print('Epoch: [{}]\tEpoch Time {:.1f} s\tLoss {:.6f}\t'
          'Lr {:.2e}'.format(epoch, elapsed, losses.mean, first_group['lr']))
    return is_add_margin
コード例 #3
0
def train(train_loader,
          model,
          criterion,
          optimizer,
          epoch,
          scheduler,
          mixup=False):
    """Train for one epoch, stepping the scheduler with a fractional epoch.

    For every batch the scheduler is stepped with
    ``epoch + i / len(train_loader)`` before the forward pass.  When ``mixup``
    compares equal to ``False`` the model and criterion are called directly;
    otherwise ``tools.cutmix`` produces both the output and the loss.

    Returns:
        Tuple ``(top1.avg.data.cpu().numpy(), losses.avg)``.
    """
    batch_time = tools.AverageMeter('Time', ':6.3f')
    losses = tools.AverageMeter('Loss', ':.4e')
    top1 = tools.AverageMeter('Acc@1', ':2.2f')
    top5 = tools.AverageMeter('Acc@5', ':2.2f')
    num_batches = len(train_loader)
    progress = tools.ProgressMeter(num_batches, batch_time, losses, top1,
                                   top5)
    model.train()

    tick = time.time()
    for step, (images, labels) in enumerate(train_loader):
        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)

        scheduler.step(epoch + step / num_batches)

        optimizer.zero_grad()
        # Equality (not truthiness) is kept on purpose: e.g. ``mixup=None``
        # must still take the cutmix path, as before.
        if mixup == False:  # noqa: E712
            output = model(images)
            loss = criterion(output, labels)
        else:
            output, loss = tools.cutmix(images, labels, model, criterion)

        batch_size = images.size(0)
        acc1, acc5 = tools.accuracy(output, labels, topk=(1, 5))
        losses.update(loss.item(), batch_size)
        top1.update(acc1[0], batch_size)
        top5.update(acc5[0], batch_size)

        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % 100 == 0:
            progress.pr2int(step)

    final_top1 = top1.avg.data.cpu().numpy()
    return final_top1, losses.avg
コード例 #4
0
ファイル: solver.py プロジェクト: upzheng/Keratoconus
    def train_iter(self, step, dataloader):
        """Run one optimisation step and periodically report metrics.

        Pulls a single ``(imgs, target)`` pair via ``dataloader.next()``,
        moves it to the GPU, runs forward/backward, steps ``self.optim`` and
        clears the model gradients.  Every ``self.args.display_freq`` steps
        the loss and classification metrics are printed and written through
        ``write_scalars``; with ``self.args.debug`` set, labels and predicted
        class indices are printed as well.

        Args:
            step: Current global step; used for the display cadence and as
                the step value passed to ``write_scalars``.
            dataloader: Object exposing ``next()`` that returns
                ``(imgs, target)``.
        """
        imgs, target = dataloader.next()

        # Train mode
        self.model.train()

        imgs = imgs.float()
        imgs, target = Variable(imgs).cuda(), Variable(target).cuda()
        score = self.model(imgs)
        loss = self.loss(score, target)

        # Backward
        loss.backward()
        self.optim.step()
        self.model.zero_grad()

        if step % self.args.display_freq == 0:
            # ``loss.data[0]`` fails on 0-dim loss tensors in current PyTorch
            # ("invalid index of a 0-dim tensor"); prefer ``item()`` and fall
            # back to the legacy indexing only where ``item`` does not exist.
            loss_value = loss.item() if hasattr(loss, 'item') else loss.data[0]

            # compute metrics
            # TODO more general metrics
            acc = self.accuracy(score, target)
            recall, prec, f1, kap = self.metrics_mc(score, target)

            print('Training - Loss: {:.4f} - Acc: {:.4f} - Precision: {:.4f} - Recall: {:.4f} - f1score:{:.4f} - kappa:{:.4f}' \
                  .format(loss_value, acc, prec, recall, f1, kap))

            # Record to tensorboard
            # TODO more general metrics
            scalars = [loss_value, acc, prec, recall, f1, kap]
            names = ['loss', 'acc', 'precision', 'recall', 'f1score', 'kappa']
            write_scalars(self.writer, scalars, names, step, 'train')

            # debug info
            if self.args.debug:
                print('lebel: {}'.format(target.cpu().data.tolist()))
                print('pred : {}'.format(score.max(1)[1].cpu().data.tolist()))

        # Drop the references to this step's tensors before returning.
        del imgs, score, target, loss
コード例 #5
0
def train(model, train_loader, optimizer, criterion, n_labels):
    """Run one training epoch for a single-output segmentation model.

    Targets are converted with ``common.to_one_hot_3d`` before the loss is
    computed.  ``epoch`` and ``device`` are not parameters; they are resolved
    from the enclosing scope.

    Returns:
        OrderedDict with ``Train Loss``, ``Train dice0`` and ``Train dice1``;
        ``Train dice2`` is added whenever ``n_labels`` is not 2.
    """
    current_lr = optimizer.state_dict()['param_groups'][0]['lr']
    print("=======Epoch:{}=======lr:{}".format(epoch, current_lr))

    model.train()
    train_loss = metrics.LossAverage()
    train_dice = metrics.DiceAverage(n_labels)

    progress = tqdm(enumerate(train_loader), total=len(train_loader))
    for _, (data, target) in progress:
        data = data.float()
        target = common.to_one_hot_3d(target.long(), n_labels)
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()

        prediction = model(data)
        batch_loss = criterion(prediction, target)
        batch_loss.backward()
        optimizer.step()

        train_loss.update(batch_loss.item(), data.size(0))
        train_dice.update(prediction, target)

    summary = OrderedDict()
    summary['Train Loss'] = train_loss.avg
    summary['Train dice0'] = train_dice.avg[0]
    summary['Train dice1'] = train_dice.avg[1]
    if n_labels != 2:
        summary['Train dice2'] = train_dice.avg[2]
    return summary
コード例 #6
0
def main():
    """Train ``main_model`` from rewards obtained by training a task model.

    Outer loop (until ``frame_idx`` reaches ``args.max_frames``): for each of
    ``args.step_each_frame`` steps, sample an action from the distribution
    returned by ``main_model`` for a random state, generate images from it
    through the Unity environment, reload the task model's saved initial
    weights, train it and use the value returned by ``train_task_model`` as
    the reward.  ``main_model`` is then updated with the loss
    ``-(log_prob * advantage).mean()``, where ``advantage`` is the reward
    minus the baseline from ``compute_returns``.

    Side effects: writes the synthetic list file, ``task_model_init.pth``,
    ``logs.txt`` and one ``main_model_<frame_idx>.pth`` snapshot per outer
    iteration under ``args.snapshot_dir``.
    """
    # get unity environment
    env, brain = get_unity_envs()

    # get arguments
    args = get_arguments()
    print(args)

    # set gpu environment
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    cudnn.enabled = True
    cudnn.benchmark = True
    cuda = torch.cuda.is_available()

    # set random seed
    rn = set_seeds(args.random_seed, cuda)

    # make directory
    os.makedirs(args.snapshot_dir, exist_ok=True)

    # get validation dataset
    val_set = get_validation_dataset(args)
    print("len of test set: ", len(val_set))
    val_loader = data.DataLoader(val_set,
                                 batch_size=args.real_batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    # generate training list
    # Writes i + 1 (one per line) for every i in [0, syn_img_num) that is not
    # a multiple of 10; the file is overwritten on each run.
    with open(args.syn_list_path, "w") as fp:
        for i in range(args.syn_img_num):
            if i % 10 != 0:
                fp.write(str(i + 1) + '\n')

    # get main model
    main_model = MLP(args.num_inputs, args.num_outputs, args.hidden_size)
    if args.resume != "":
        main_model.load_state_dict(torch.load(args.resume))

    # get task model
    if args.task_model_name == "FCN8s":
        task_model = FCN8s_sourceonly(n_class=args.num_classes)
        vgg16 = VGG16(pretrained=True)
        task_model.copy_params_from_vgg16(vgg16)
    else:
        # Only "FCN8s" is accepted as a task model name.
        raise ValueError("Specified model name: FCN8s")

    # save initial task model
    # This snapshot is reloaded before every task-model training run below.
    torch.save(task_model.state_dict(),
               os.path.join(args.snapshot_dir, "task_model_init.pth"))

    if cuda:
        main_model = main_model.cuda()
        task_model = task_model.cuda()

    # get optimizer
    main_optimizer = optim.Adam(main_model.parameters(), lr=args.main_lr)
    # NOTE(review): task_optimizer is created once and reused across the
    # weight reloads below, so any internal optimizer state it keeps is
    # presumably carried over between runs -- confirm this is intended.
    task_optimizer = optim.SGD(task_model.parameters(),
                               lr=args.task_lr,
                               momentum=0.9,
                               weight_decay=1e-4)

    frame_idx = 0
    whole_start_time = time.time()
    while frame_idx < args.max_frames:

        # Per-iteration buffers, one entry per inner step.
        log_probs = []
        rewards = []

        start_time = time.time()

        for i_step in range(1, args.step_each_frame + 1):

            # get initial attribute list
            state = np.random.rand(1, args.num_inputs)
            state = torch.from_numpy(state).float()

            if cuda:
                state = state.cuda()

            # get modified attribute list
            dist = main_model(state)
            action = dist.sample()

            action_actual = action.float() / 10.0  # [0, 0.9]

            # generate images by attribute list
            print("action: " + str(action_actual.cpu().numpy()))
            get_images_by_attributes(args, i_step, env, brain,
                                     action_actual[0].cpu().numpy())

            train_set = get_training_dataset(args, i_step)
            train_loader = data.DataLoader(train_set,
                                           batch_size=args.syn_batch_size,
                                           shuffle=True,
                                           num_workers=args.num_workers,
                                           pin_memory=True)

            # train the task model using synthetic dataset
            # Reset to the saved initial weights so every run starts equal.
            task_model.load_state_dict(
                torch.load(
                    os.path.join(args.snapshot_dir, "task_model_init.pth")))

            reward = train_task_model(train_loader, val_loader, task_model,
                                      task_optimizer, args, cuda)
            log_prob = dist.log_prob(action)[0]

            log_probs.append(log_prob)
            rewards.append(torch.FloatTensor([reward]))

            frame_idx += 1

            # The very first reward seeds the baseline's starting value.
            if frame_idx == 1:
                moving_start = torch.FloatTensor([reward])

        baseline = compute_returns(rewards, moving_start)
        # The last baseline entry becomes the start value for the next round.
        moving_start = baseline[-1]

        log_probs = torch.cat(log_probs)
        baseline = torch.cat(baseline).detach()
        rewards = torch.cat(rewards).detach()

        advantage = rewards - baseline
        if cuda:
            advantage = advantage.cuda()

        loss = -(log_probs * advantage.detach()).mean()

        # NOTE: ``state`` and ``action`` logged here are those of the last
        # inner step only; rewards and baseline cover all inner steps.
        with open(os.path.join(args.snapshot_dir, "logs.txt"), 'a') as fp:
            fp.write(
                "frame idx: {0:4d}, state: {1:s}, action: {2:s}, reward: {3:s}, baseline: {4:s}, loss: {5:.2f} \n"
                .format(frame_idx, str(state.cpu()[0].numpy()),
                        str(action.cpu()[0].numpy()), str(rewards.numpy()),
                        str(baseline.numpy()), loss.item()))

        print("optimize the main model parameters")
        main_optimizer.zero_grad()
        loss.backward()
        main_optimizer.step()

        elapsed_time = time.time() - start_time
        print("[frame: {0:3d}], [loss: {1:.2f}], [time: {2:.1f}]".format(
            frame_idx, loss.item(), elapsed_time))

        # Snapshot the main model after every update.
        torch.save(
            main_model.state_dict(),
            os.path.join(args.snapshot_dir, "main_model_%d.pth" % frame_idx))

    elapsed_time = time.time() - whole_start_time
    print("whole time: {0:.1f}".format(elapsed_time))
    # NOTE(review): not reached if anything above raises, which would leave
    # the environment open.
    env.close()
コード例 #7
0
ファイル: main.py プロジェクト: mikochou/ECG-tianchi
def train(train_loader, val_loader, model, optimizer, criterion, lr_scheduler,
          device, opt):
    '''
    model training

    Runs ``opt.epochs`` epochs.  Batches are pulled from a ``data_prefetcher``
    until it returns ``None``; ``lr_scheduler.update(i, epoch)`` is called
    before every batch with the 1-based batch index ``i``.  Progress is
    printed and logged every 20 batches.  After each epoch the model is
    evaluated with ``val`` and, once ``epoch > int(opt.epochs * 0.8)``, a
    checkpoint is saved whenever the validation score beats the best so far.
    When ``opt`` has a ``fast_train`` attribute, evaluation is skipped for
    the early epochs and the train-set evaluation is skipped entirely.

    :param: train_loader: dataloader, val_loader: dataloader, model: cpkt,
    optimizer: optimizer, criterion: weighted_binary_crossentropy, lr_scheduler: LRScheduler,
    device: device, opt: dict
    :return: None
    '''
    total_step = len(train_loader)
    best_acc = -1
    losses = AverageMeter()
    batch_time = AverageMeter()
    end = time.time()
    print_freq = 20
    iter_per_epoch = len(train_loader)
    iter_sum = iter_per_epoch * opt.epochs
    fast_train = hasattr(opt, 'fast_train')
    writer = SummaryWriter(opt.model_save_path)

    for epoch in range(opt.epochs):
        model.train()

        prefetcher = data_prefetcher(train_loader)
        datas, ages, sexs, labels = prefetcher.next()
        i = 0  # number of batches processed so far in this epoch
        while datas is not None:
            i += 1
            lr_scheduler.update(i, epoch)
            datas = datas.to(device)
            ages = ages.to(device)
            sexs = sexs.to(device)
            labels = labels.to(device)
            # Forward pass
            outputs = model(datas, ages, sexs)
            loss = criterion(outputs, labels)
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Update running meters
            batch_time.update(time.time() - end)
            losses.update(loss.item(), datas.size(0))
            # Print. ``i`` is already 1-based here, so it is used directly
            # for both the cadence and the displayed step (the old ``i + 1``
            # reported one step too many).
            if i % print_freq == 0:
                iter_used = epoch * iter_per_epoch + i
                used_time = batch_time.sum
                total_time = used_time / iter_used * iter_sum
                used_time = str(datetime.timedelta(seconds=used_time))
                total_time = str(datetime.timedelta(seconds=total_time))
                print(
                    'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, LR:{:.5f}, Time[{:.7s}/{:.7s}]'
                    .format(epoch + 1, opt.epochs, i, total_step,
                            loss.item(), optimizer.param_groups[0]['lr'],
                            used_time, total_time),
                    flush=True)
                writer.add_scalar('Learning_rate',
                                  optimizer.param_groups[0]['lr'], iter_used)
                writer.add_scalar('Train/Avg_Loss', losses.avg, iter_used)
            end = time.time()
            datas, ages, sexs, labels = prefetcher.next()

        # Global step at the end of this epoch.  Previously this name was
        # only assigned inside the print branch, so the epoch-level scalars
        # below raised NameError on loaders shorter than the print interval
        # and were otherwise logged at the last *printed* step.
        iter_used = epoch * iter_per_epoch + i

        if not fast_train:
            # acc in train set
            acc_train = val(train_loader, model, device)
            print('Train Accuracy: {} %'.format(acc_train), flush=True)
            writer.add_scalar('Train/F1_Score', acc_train, iter_used)
            # acc in validation set
            acc_val = val(val_loader, model, device)
            if acc_val > best_acc:
                # Save the model checkpoint
                best_acc = acc_val
                if epoch > int(opt.epochs * 0.8):
                    # NOTE(review): ``args`` is not a parameter of this
                    # function -- presumably a module-level namespace;
                    # confirm it should not be ``opt``.
                    save_name = args.model + '_e{}.ckpt'.format(epoch)
                    save_path = opt.model_save_path + save_name
                    torch.save(model.state_dict(), save_path)
            print('Validation Accuracy: {} %'.format(acc_val), flush=True)
            writer.add_scalar('Validation/F1_Score', acc_val, iter_used)
        else:
            if epoch > int(opt.epochs * 0.8):
                acc_val = val(val_loader, model, device)
                if acc_val > best_acc:
                    best_acc = acc_val
                    # NOTE(review): see the remark on ``args`` above.
                    save_name = args.model + '_e{}.ckpt'.format(epoch)
                    save_path = opt.model_save_path + save_name
                    torch.save(model.state_dict(), save_path)
    return
コード例 #8
0
def train(data_loader, model, criterion, optim, epoch, snap_shot):
    """Train for one epoch with gradient accumulation and log the averages.

    The loss of every batch is divided by ``args.accumulate_step`` and
    back-propagated; the optimizer steps (and gradients are cleared) once
    every ``args.accumulate_step`` batches.  The model is put in eval mode
    when ``args.fine_tune`` is set, otherwise in train mode.

    When ``(epoch + 1) % args.val_freq == 0`` or on the last epoch, one
    summary line is appended to a text file under ``logdir/<args.model>/``
    (inside a ``<args.split>`` sub-directory unless ``args.use_total``).

    ``args`` is not a parameter; it is resolved from the enclosing scope.
    Note that the loss meter records the already-scaled loss.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    dice_cof = AverageMeter()

    if args.fine_tune:
        model.eval()
    else:
        model.train()

    end = time.time()
    for i, (inputs, label) in enumerate(data_loader):

        data_time.update(time.time() - end)

        label = torch.autograd.Variable(label['mask'])
        # ``async`` is a reserved word since Python 3.7, so the former
        # ``.cuda(async=True)`` no longer even parses; ``non_blocking`` is
        # the replacement keyword in PyTorch.
        input_var = inputs.cuda(non_blocking=True)
        target_var = label.cuda(non_blocking=True)
        output = model.forward(input_var)
        loss = criterion(output, target_var)

        score = dice_overall_v2(output, target_var)

        # Scale so the accumulated gradient matches one larger batch.
        loss = loss / args.accumulate_step
        loss.backward()

        if (i + 1) % args.accumulate_step == 0:
            optim.step()
            optim.zero_grad()

        losses.update(loss.data, inputs.size(0))
        dice_cof.update(score.data, inputs.size(0))

        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print(
                ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
                 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                 'Dice_coff {dice_cof.val:.3f} ({dice_cof.avg:.3f})\t'.format(
                     epoch,
                     i,
                     len(data_loader),
                     batch_time=batch_time,
                     data_time=data_time,
                     loss=losses,
                     dice_cof=dice_cof,
                     lr=optim.param_groups[-1]['lr'])))

    if (epoch + 1) % args.val_freq == 0 or epoch == args.epochs - 1:

        template = 'Epoch: {}, Loss {:.2f}, Dice_cof@1 {:.3f}\n'
        log_name = 'train_txt_%d_%d.txt' % (snap_shot, args.image_size)
        if args.use_total:
            log_path = os.path.join('logdir', args.model, log_name)
        else:
            log_path = os.path.join('logdir', args.model, str(args.split),
                                    log_name)
        with open(log_path, 'a') as f:
            f.write(template.format(epoch, losses.avg, dice_cof.avg))
コード例 #9
0
ファイル: triplet_train.py プロジェクト: Xavierxhq/tableware
def train(model, optimizer, criterion, epoch, print_freq, data_loader, data_pth):
    """Train for one pass over ``data_loader`` with a three-input criterion.

    A distance matrix is computed once up front through ``Evaluator``.  For
    every batch, ``_get_avg_feature_for_labels`` is re-run for labels 1..40
    and its result is used to look up the first criterion input
    (``get_center_anchor``).  The second input is, on a random draw, either
    looked up via ``get_furthest_from_center`` or computed by the model from
    the second image returned by ``triplet_example``.  The third input is
    always computed by the model from the third image.

    Returns:
        True if the most recent loss value dropped below ``1e-5`` at any
        step, otherwise False.
    """
    model.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    start = time.time()
    evaluator = Evaluator(model)
    distmat = evaluator.calDistmat(data_loader)
    is_add_margin = False

    prefix = None
    num_batches = len(data_loader)
    for i, inputs in enumerate(data_loader):
        data_time.update(time.time() - start)

        # parse data
        imgs, pids = inputs
        labels = [pid.item() for pid in pids]

        # Refreshed on every batch.
        prefix = _get_avg_feature_for_labels(model, list(range(1, 41)),
                                             data_pth)

        # The first image returned by triplet_example is not used.
        _, img2, img3 = triplet_example(imgs, pids, distmat)

        first_feat = torch.stack(get_center_anchor(labels, prefix))
        if random.randint(1, 10) > 5:
            second_feat = torch.stack(get_furthest_from_center(labels, prefix))
        else:
            second_feat = model(img2.cuda())
        third_feat = model(img3.cuda())

        loss = criterion(first_feat, second_feat, third_feat)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - start)
        losses.update(loss.item())

        start = time.time()

        if (i + 1) % print_freq == 0:
            message = ('Epoch: [{}][{}/{}]\t'
                       'Batch Time {:.3f} ({:.3f})\t'
                       'Data Time {:.3f} ({:.3f})\t'
                       'Loss {:.6f} ({:.6f})\t')
            print(message.format(epoch, i + 1, num_batches,
                                 batch_time.val, batch_time.mean,
                                 data_time.val, data_time.mean,
                                 losses.val, losses.mean))
        if losses.val < 1e-5:
            is_add_margin = True

    first_group = optimizer.param_groups[0]
    summary = ('Epoch: [{}]\tEpoch Time {:.3f} s\tLoss {:.6f}\t'
               'Lr {:.2e}')
    print(summary.format(epoch, batch_time.sum, losses.mean,
                         first_group['lr']))
    print()
    return is_add_margin
コード例 #10
0
ファイル: train.py プロジェクト: luyuzhe111/Glomeruli
def train(train_loader, model, optimizer, criterion, device, args):
    """Train for one epoch with optional circlemix / cutmix / cutout.

    One random number ``r`` is drawn per batch and compared, in order,
    against ``args.circlemix_prob``, ``args.cutmix_prob`` and
    ``args.cutout_prob``; the first threshold above ``r`` selects the
    augmentation, otherwise the batch is used as-is.  When
    ``args.optimizer == 'SAM'`` the update uses the optimizer's
    ``first_step`` / ``second_step`` pair with a second forward-backward
    pass; otherwise a plain ``step`` followed by ``zero_grad``.

    Fixes relative to the previous version:
      * the circlemix angles no longer overwrite ``end`` (the batch
        timestamp), which corrupted ``batch_time`` for those batches;
      * the cutmix SAM second pass now uses the same ``lam`` weighting as the
        first pass (it had the two weights swapped);
      * the SAM second pass runs the model once instead of twice.

    Returns:
        Tuple ``(losses.avg, top1.avg, f1_avg, f1s)`` where ``f1s`` is the
        list of per-class F1 scores and ``f1_avg`` their unweighted mean.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    end = time.time()
    top1 = AverageMeter()
    tbar = tqdm(train_loader, desc='\r')

    model.train()
    preds = []
    gts = []
    use_sam = args.optimizer == 'SAM'
    for inputs, targets, _image_path in tbar:
        # measure data loading time
        data_time.update(time.time() - end)

        inputs, targets = inputs.to(device), targets.to(device)

        # Defaults describe an unmixed batch: a single target, no blending.
        target_a, target_b = targets, None
        weight_a = weight_b = None

        r = np.random.rand(1)
        if args.circlemix_prob > r:
            rand_index = torch.randperm(inputs.size()[0])
            target_b = targets[rand_index]

            r1 = np.random.randint(0, 360)
            r2 = np.random.randint(0, 360)
            # Dedicated names: ``end`` must stay the batch timestamp.
            start_angle, end_angle = min(r1, r2), max(r1, r2)
            lam = (end_angle - start_angle) / 360

            height = inputs.shape[2]
            width = inputs.shape[3]

            assert height == width, 'height does not equal to width'
            side = height

            mask = np.zeros((side, side), np.uint8)
            vertices = polygon_vertices(side, start_angle, end_angle)

            roi_mask = cv2.fillPoly(mask, np.array([vertices]), 255)
            # Repeat the 2-D mask over channels, then over the batch.
            roi_mask_rgb = np.repeat(roi_mask[np.newaxis, :, :],
                                     inputs.shape[1],
                                     axis=0)
            roi_mask_batch = np.repeat(roi_mask_rgb[np.newaxis, :, :, :],
                                       inputs.shape[0],
                                       axis=0)
            roi_mask_batch = torch.from_numpy(roi_mask_batch)

            roi_mask_batch = roi_mask_batch.to(device)
            rand_index = rand_index.to(device)

            inputs2 = inputs[rand_index].clone()
            inside = roi_mask_batch > 0
            inputs[inside] = inputs2[inside]

            # The masked region (fraction ``lam``) comes from the shuffled
            # samples, so target_b is weighted by ``lam``.
            weight_a, weight_b = 1. - lam, lam

        elif args.cutmix_prob > r:
            lam = np.random.beta(args.beta, args.beta)
            rand_index = torch.randperm(inputs.size()[0]).to(device)
            target_b = targets[rand_index]
            bbx1, bby1, bbx2, bby2 = rand_bbox(inputs.size(), lam)
            inputs[:, :, bbx1:bbx2, bby1:bby2] = inputs[rand_index, :,
                                                        bbx1:bbx2, bby1:bby2]

            # adjust lambda to exactly match pixel ratio
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) /
                       (inputs.size()[-1] * inputs.size()[-2]))

            # Here ``lam`` is the fraction that stays original.
            weight_a, weight_b = lam, 1. - lam

        elif args.cutout_prob > r:
            lam = np.random.beta(args.beta, args.beta)
            bbx1, bby1, bbx2, bby2 = rand_bbox(inputs.size(), lam)
            inputs[:, :, bbx1:bbx2, bby1:bby2] = 0

        else:
            inputs = inputs.float()

        outputs = model(inputs)
        loss = _blend_loss(criterion, outputs, target_a, target_b, weight_a,
                           weight_b)
        _apply_update(
            optimizer, loss,
            lambda: _blend_loss(criterion, model(inputs), target_a, target_b,
                                weight_a, weight_b), use_sam)

        score = F.softmax(outputs, dim=1)
        pred = score.data.max(1)[1]

        preds.extend(pred.tolist())
        gts.extend(targets.tolist())

        [
            acc1,
        ] = accuracy(outputs, targets, topk=(1, ))
        losses.update(loss.item(), inputs.size(0))
        top1.update(acc1.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        tbar.set_description('\r Train Loss: %.3f | Top1: %.3f' %
                             (losses.avg, top1.avg))

    f1s = list(f1_score(gts, preds, average=None))
    f1_avg = sum(f1s) / len(f1s)

    return losses.avg, top1.avg, f1_avg, f1s


def _blend_loss(criterion, outputs, target_a, target_b, weight_a, weight_b):
    """Return the loss on target_a, blended with target_b when one is given."""
    if target_b is None:
        return criterion(outputs, target_a)
    return (criterion(outputs, target_a) * weight_a +
            criterion(outputs, target_b) * weight_b)


def _apply_update(optimizer, loss, recompute_loss, use_sam):
    """Back-propagate loss and update weights (two-step update for SAM)."""
    loss.backward()
    if use_sam:
        optimizer.first_step(zero_grad=True)
        # second forward-backward pass: one full forward at the new weights
        recompute_loss().backward()
        optimizer.second_step(zero_grad=True)
    else:
        optimizer.step()
        optimizer.zero_grad()