コード例 #1
0
def train2(train_loader, model, criterion, optimizer, epoch, args):
    """Train a linear classifier on multi-crop inputs, one SGD step per crop.

    Each element of ``images`` is a crop of the same batch; every crop gets
    its own forward/backward/step. Accuracy is reported on the first crop's
    logits only, while the loss meter averages over all per-crop updates.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    mAP = AverageMeter("mAP", ":6.2f")
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5, mAP],
                             prefix="Epoch: [{}]".format(epoch))
    """
    Switch to eval mode:
    Under the protocol of linear classification on frozen features/models,
    it is not legitimate to change any part of the pre-trained model.
    BatchNorm in train mode may revise running mean/std (even if it receives
    no gradient), which are part of the model parameters too.
    """
    model.eval()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            for k in range(len(images)):
                images[k] = images[k].cuda(args.gpu, non_blocking=True)

        target = target.cuda(args.gpu, non_blocking=True)

        # One full optimizer step per crop; remember the first crop's logits
        # so accuracy is always measured on a consistent view.
        first_output = None
        for k, image in enumerate(images):
            optimizer.zero_grad()
            output = model(image)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            losses.update(loss.item(), image.size(0))
            if k == 0:
                first_output = output

        images = images[0]
        output = first_output

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        top1.update(acc1.item(), images.size(0))
        top5.update(acc5.item(), images.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
コード例 #2
0
def update_network_multi(model, images, args, Memory_Bank, losses, top1, top5,
                         optimizer, criterion, mem_losses):
    """Multi-crop contrastive step: SGD update of the encoder, then a manual
    momentum-SGD update of the memory bank from a closed-form gradient.

    ``images[0]`` is the key view; ``images[1:]`` are query views. The meters
    (losses, top1, top5, mem_losses) are updated in place using the first
    query view only.
    """
    # update network
    # negative logits: NxK
    image_size = len(images)
    q_list, k = model(im_q=images[1:image_size], im_k=images[0])
    # Gather keys from every worker so each query scores against the
    # full global batch of keys.
    k = concat_all_gather(k)
    l_pos_list = []
    for q in q_list:
        # l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # N x (global batch): similarity of each query with every key.
        l_pos = torch.einsum('nc,ck->nk', [q, k.T])
        l_pos_list.append(l_pos)

    d_norm, d, l_neg_list = Memory_Bank(q_list)

    loss = 0
    cur_batch_size = l_pos_list[0].shape[0]
    cur_gpu = args.gpu
    # This worker's own keys form a contiguous slice of the gathered batch,
    # so the positive for local sample n is column (gpu * batch + n).
    # NOTE(review): assumes args.gpu equals the distributed rank — verify.
    choose_match = cur_gpu * cur_batch_size
    labels = torch.arange(choose_match,
                          choose_match + cur_batch_size,
                          dtype=torch.long).cuda()
    # NOTE: this loop index k shadows the gathered-key tensor above; the
    # tensor is no longer needed at this point.
    for k in range(len(l_pos_list)):
        logits = torch.cat([l_pos_list[k], l_neg_list[k]], dim=1)
        logits /= args.moco_t
        loss += criterion(logits, labels)
        if k == 0:
            # acc1/acc5 are (K+1)-way contrast classifier accuracy
            # measure accuracy and record loss
            acc1, acc5 = accuracy(logits, labels, topk=(1, 5))
            losses.update(loss.item(), images[0].size(0))
            top1.update(acc1.item(), images[0].size(0))
            top5.update(acc5.item(), images[0].size(0))
    # compute gradient and do SGD step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # update memory bank
    g_sum = 0
    with torch.no_grad():
        for k in range(len(l_pos_list)):
            logits = torch.cat([l_pos_list[k], l_neg_list[k]],
                               dim=1) / args.mem_t
            total_bsize = logits.shape[1] - args.cluster
            p_qd = nn.functional.softmax(logits, dim=1)[:, total_bsize:]  # n*k
            # Closed-form gradient of the contrastive loss w.r.t. the
            # (normalized) dictionary, accumulated over all query views.
            g = torch.einsum(
                'cn,nk->ck',
                [q_list[k].T, p_qd]) / logits.shape[0] - torch.mul(
                    torch.mean(torch.mul(p_qd, l_neg_list[k]), dim=0), d_norm)
            g_sum += -torch.div(g, torch.norm(d, dim=0)) / args.mem_t  # c*k
            if k == 0:
                # Diagnostic: average probability mass assigned to the
                # leading (within-batch) columns for the first query view.
                logits = torch.softmax(logits, dim=1)
                batch_prob = torch.sum(logits[:, :logits.size(0)], dim=1)
                batch_prob = torch.mean(batch_prob)
                mem_losses.update(batch_prob.item(), logits.size(0))
        # Average the dictionary gradient over workers, then apply a manual
        # momentum-SGD update with weight decay to the bank parameters.
        g_sum = all_reduce(g_sum) / torch.distributed.get_world_size()
        Memory_Bank.v.data = args.momentum * Memory_Bank.v.data + g_sum + args.mem_wd * Memory_Bank.W.data
        Memory_Bank.W.data = Memory_Bank.W.data - args.memory_lr * Memory_Bank.v.data
コード例 #3
0
def init_memory(train_loader, model, Memory_Bank, criterion, optimizer, epoch,
                args):
    """Warm up the encoder and fill the memory bank with key features.

    Runs standard InfoNCE steps batch by batch, writing each batch's
    gathered keys into consecutive columns of ``Memory_Bank.W`` until
    ``args.cluster`` columns are filled, then re-initializes the key
    encoder from the query encoder.
    """
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader), [losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))
    # switch to train mode
    model.train()
    for i, (images, _) in enumerate(train_loader):
        # measure data loading time
        if args.gpu is not None:
            for k in range(len(images)):
                images[k] = images[k].cuda(args.gpu, non_blocking=True)

        # compute output
        if args.multi_crop:
            # Multi-crop: all but the last crop are queries, last is the key;
            # only the first query is used below.
            q_list, k = model(im_q=images[0:-1], im_k=images[-1])
            q = q_list[0]
        elif not args.sym:
            q, k = model(im_q=images[0], im_k=images[1])
        else:
            # Symmetric model returns four tensors; only the first
            # (presumably the query prediction — verify against the model)
            # and the key are needed here.
            q, _, _, k = model(im_q=images[0], im_k=images[1])
        d_norm, d, l_neg = Memory_Bank(q, init_mem=True)

        # Positive logit: similarity of each query with its own key.
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # logits: Nx(1+K)

        logits = torch.cat([l_pos, l_neg], dim=1)
        logits /= args.moco_t
        # The positive sits at column 0 for every sample.
        labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
        loss = criterion(logits, labels)
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc1, acc5 = accuracy(logits, labels, topk=(1, 5))

        losses.update(loss.item(), images[0].size(0))
        top1.update(acc1.item(), images[0].size(0))
        top5.update(acc5.item(), images[0].size(0))
        if i % args.print_freq == 0:
            progress.display(i)

        # fill the memory bank
        # Write this batch's gathered keys (one per column) into the next
        # slice of W; the final slice may be partial when args.cluster is
        # not a multiple of the global batch size.
        output = concat_all_gather(k)
        batch_size = output.size(0)
        start_point = i * batch_size
        end_point = min((i + 1) * batch_size, args.cluster)
        Memory_Bank.W.data[:, start_point:end_point] = output[:end_point -
                                                              start_point].T
        if (i + 1) * batch_size >= args.cluster:
            break
    # Re-initialize the momentum (key) encoder from the query encoder.
    for param_q, param_k in zip(model.module.encoder_q.parameters(),
                                model.module.encoder_k.parameters()):
        param_k.data.copy_(param_q.data)  # initialize
コード例 #4
0
def testing(val_loader, model, criterion, args):
    """Multi-crop evaluation: per-crop softmax fused by element-wise max.

    Accuracy is averaged across distributed workers; returns the running
    top-1 average and prints a final accuracy recomputed from raw counts.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    mAP = AverageMeter("mAP", ":6.2f")
    progress = ProgressMeter(len(val_loader),
                             [batch_time, losses, top1, top5, mAP],
                             prefix='Test: ')

    # Frozen statistics during evaluation.
    model.eval()
    correct_count = 0
    count_all = 0
    with torch.no_grad():
        tick = time.time()
        for step, (images, target) in enumerate(val_loader):
            target = target.cuda(args.gpu, non_blocking=True)
            # Class probabilities for every crop of the batch.
            probs = [torch.softmax(model(crop), dim=1) for crop in images]
            # Fuse crops by taking the element-wise maximum probability.
            output, _ = torch.max(torch.stack(probs, dim=0), dim=0)
            images = images[0]
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # Average each metric over all distributed workers.
            acc1 = torch.mean(concat_all_gather(acc1.unsqueeze(0)),
                              dim=0,
                              keepdim=True)
            acc5 = torch.mean(concat_all_gather(acc5.unsqueeze(0)),
                              dim=0,
                              keepdim=True)
            correct_count += float(acc1[0]) * images.size(0)
            count_all += images.size(0)
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))
            loss = criterion(output, target)
            losses.update(loss.item(), images.size(0))
            # Track per-batch wall time.
            batch_time.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0:
                progress.display(step)

        print(
            ' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f} mAP {mAP.avg:.3f} '.
            format(top1=top1, top5=top5, mAP=mAP))
        final_accu = correct_count / count_all
        print("$$our final calculated accuracy %.7f" % final_accu)
    return top1.avg
コード例 #5
0
def update_network(model, images, args, Memory_Bank, losses, top1, top5,
                   optimizer, criterion, mem_losses):
    """Single-pair contrastive step: SGD update of the encoder followed by
    a manual momentum-SGD update of the memory bank.

    Returns ``(l_neg, logits)`` where ``logits`` are the mem_t-scaled
    softmax probabilities used for the diagnostic meter.
    """
    # update network
    # negative logits: NxK

    q, k = model(im_q=images[0], im_k=images[1])
    # Gather keys across workers so positives index into the global batch.
    k = concat_all_gather(k)
    l_pos = torch.einsum('nc,ck->nk', [q, k.T])

    d_norm, d, l_neg = Memory_Bank(q)

    # logits: Nx(1+K)

    logits = torch.cat([l_pos, l_neg], dim=1)
    logits /= args.moco_t

    cur_batch_size = logits.shape[0]
    cur_gpu = args.gpu
    # The positive for local sample n sits at column (gpu * batch + n) of
    # the gathered-key block. NOTE(review): assumes args.gpu equals the
    # distributed rank — verify.
    choose_match = cur_gpu * cur_batch_size
    labels = torch.arange(choose_match,
                          choose_match + cur_batch_size,
                          dtype=torch.long).cuda()
    total_bsize = logits.shape[1] - args.cluster
    loss = criterion(logits, labels)

    # acc1/acc5 are (K+1)-way contrast classifier accuracy
    # measure accuracy and record loss
    acc1, acc5 = accuracy(logits, labels, topk=(1, 5))
    losses.update(loss.item(), images[0].size(0))
    top1.update(acc1.item(), images[0].size(0))
    top5.update(acc5.item(), images[0].size(0))

    # compute gradient and do SGD step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # update memory bank
    with torch.no_grad():
        logits = torch.cat([l_pos, l_neg], dim=1) / args.mem_t
        p_qd = nn.functional.softmax(logits, dim=1)[:, total_bsize:]
        # Closed-form gradient of the loss w.r.t. the normalized dictionary.
        g = torch.einsum('cn,nk->ck',
                         [q.T, p_qd]) / logits.shape[0] - torch.mul(
                             torch.mean(torch.mul(p_qd, l_neg), dim=0), d_norm)
        g = -torch.div(g, torch.norm(d, dim=0)) / args.mem_t  # c*k
        # Average over workers before the momentum/weight-decay update.
        g = all_reduce(g) / torch.distributed.get_world_size()
        Memory_Bank.v.data = args.momentum * Memory_Bank.v.data + g + args.mem_wd * Memory_Bank.W.data
        Memory_Bank.W.data = Memory_Bank.W.data - args.memory_lr * Memory_Bank.v.data
    # Diagnostic on the detached mem_t logits (created under no_grad):
    # average probability mass on the leading within-batch columns.
    logits = torch.softmax(logits, dim=1)
    batch_prob = torch.sum(logits[:, :logits.size(0)], dim=1)
    batch_prob = torch.mean(batch_prob)
    mem_losses.update(batch_prob.item(), logits.size(0))
    return l_neg, logits
コード例 #6
0
def validate(val_loader, model, criterion, args):
    """Single-view validation loop.

    Averages per-batch accuracy across distributed workers and returns
    the running top-1 average.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    mAP = AverageMeter("mAP", ":6.2f")
    progress = ProgressMeter(len(val_loader),
                             [batch_time, losses, top1, top5, mAP],
                             prefix='Test: ')

    # Frozen statistics during evaluation.
    model.eval()
    with torch.no_grad():
        tick = time.time()
        for step, (images, target) in enumerate(val_loader):
            target = target.cuda(args.gpu, non_blocking=True)
            output = model(images)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # Average each metric over all distributed workers.
            acc1 = torch.mean(concat_all_gather(acc1.unsqueeze(0)),
                              dim=0,
                              keepdim=True)
            acc5 = torch.mean(concat_all_gather(acc5.unsqueeze(0)),
                              dim=0,
                              keepdim=True)
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))
            loss = criterion(output, target)
            losses.update(loss.item(), images.size(0))
            # Track per-batch wall time.
            batch_time.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0:
                progress.display(step)

        # TODO: this should also be done with the ProgressMeter
        print(
            ' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f} mAP {mAP.avg:.3f} '.
            format(top1=top1, top5=top5, mAP=mAP))

    return top1.avg
コード例 #7
0
def update_sym_network(model, images, args, Memory_Bank, losses, top1, top5,
                       optimizer, criterion, mem_losses):
    """Symmetric contrastive step: both views serve as query and key.

    Loss is the average of the two directions' InfoNCE terms; the memory
    bank then gets a manual momentum-SGD update from the averaged
    closed-form gradient. Returns ``(l_neg1, logits1)`` for the first
    direction (mem_t-scaled softmax probabilities).
    """
    # update network
    # negative logits: NxK
    model.zero_grad()
    q_pred, k_pred, q, k = model(im_q=images[0], im_k=images[1])
    # Gather both embeddings so positives index into the global batch.
    q = concat_all_gather(q)
    k = concat_all_gather(k)
    l_pos1 = torch.einsum('nc,ck->nk', [q_pred, k.T])
    l_pos2 = torch.einsum('nc,ck->nk', [k_pred, q.T])

    d_norm1, d1, l_neg1 = Memory_Bank(q_pred)
    d_norm2, d2, l_neg2 = Memory_Bank(k_pred)
    # logits: Nx(1+K)

    logits1 = torch.cat([l_pos1, l_neg1], dim=1)
    logits1 /= args.moco_t
    logits2 = torch.cat([l_pos2, l_neg2], dim=1)
    logits2 /= args.moco_t

    cur_batch_size = logits1.shape[0]
    cur_gpu = args.gpu
    # Positive for local sample n is column (gpu * batch + n).
    # NOTE(review): assumes args.gpu equals the distributed rank — verify.
    choose_match = cur_gpu * cur_batch_size
    labels = torch.arange(choose_match,
                          choose_match + cur_batch_size,
                          dtype=torch.long).cuda()

    loss = 0.5 * criterion(logits1, labels) + 0.5 * criterion(logits2, labels)

    # acc1/acc5 are (K+1)-way contrast classifier accuracy
    # measure accuracy and record loss
    # Both directions are recorded with the same combined loss value.
    acc1, acc5 = accuracy(logits1, labels, topk=(1, 5))
    losses.update(loss.item(), images[0].size(0))
    top1.update(acc1.item(), images[0].size(0))
    top5.update(acc5.item(), images[0].size(0))
    acc1, acc5 = accuracy(logits2, labels, topk=(1, 5))
    losses.update(loss.item(), images[0].size(0))
    top1.update(acc1.item(), images[0].size(0))
    top5.update(acc5.item(), images[0].size(0))

    # compute gradient and do SGD step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # update memory bank
    with torch.no_grad():
        # update memory bank

        # logits: Nx(1+K)

        logits1 = torch.cat([l_pos1, l_neg1], dim=1)
        logits1 /= args.mem_t
        # negative logits: NxK
        # logits: Nx(1+K)

        logits2 = torch.cat([l_pos2, l_neg2], dim=1)
        logits2 /= args.mem_t
        total_bsize = logits1.shape[1] - args.cluster
        # Closed-form gradient of each direction's loss w.r.t. the
        # normalized dictionary.
        p_qd1 = nn.functional.softmax(logits1, dim=1)[:, total_bsize:]
        g1 = torch.einsum(
            'cn,nk->ck', [q_pred.T, p_qd1]) / logits1.shape[0] - torch.mul(
                torch.mean(torch.mul(p_qd1, l_neg1), dim=0), d_norm1)
        # NOTE(review): d_norm1/d1 are reused for the second direction
        # below; likely intentional since both Memory_Bank calls read the
        # same bank weights — confirm d_norm2/d2 are indeed identical.
        p_qd2 = nn.functional.softmax(logits2, dim=1)[:, total_bsize:]
        g2 = torch.einsum(
            'cn,nk->ck', [k_pred.T, p_qd2]) / logits2.shape[0] - torch.mul(
                torch.mean(torch.mul(p_qd2, l_neg2), dim=0), d_norm1)
        g = -0.5 * torch.div(
            g1, torch.norm(d1, dim=0)) / args.mem_t - 0.5 * torch.div(
                g2, torch.norm(d1, dim=0)) / args.mem_t  # c*k
        # Average over workers, then manual momentum-SGD with weight decay.
        g = all_reduce(g) / torch.distributed.get_world_size()
        Memory_Bank.v.data = args.momentum * Memory_Bank.v.data + g + args.mem_wd * Memory_Bank.W.data
        Memory_Bank.W.data = Memory_Bank.W.data - args.memory_lr * Memory_Bank.v.data
        # Diagnostic: average probability mass on the leading within-batch
        # columns, averaged over both directions.
        logits1 = torch.softmax(logits1, dim=1)
        batch_prob1 = torch.sum(logits1[:, :logits1.size(0)], dim=1)
        logits2 = torch.softmax(logits2, dim=1)
        batch_prob2 = torch.sum(logits2[:, :logits2.size(0)], dim=1)
        batch_prob = 0.5 * torch.mean(batch_prob1) + 0.5 * torch.mean(
            batch_prob2)
        mem_losses.update(batch_prob.item(), logits1.size(0))
    return l_neg1, logits1
コード例 #8
0
def testing2(val_loader, model, criterion, args):
    """Multi-crop evaluation that fuses crops by averaging softmax
    probabilities, with accuracy averaged across distributed workers.

    Returns the running top-1 average and prints a final accuracy
    recomputed from raw correct/total counts.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    mAP = AverageMeter("mAP", ":6.2f")
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5, mAP],
        prefix='Test: ')

    # Frozen statistics during evaluation.
    model.eval()
    correct_count = 0
    count_all = 0
    with torch.no_grad():
        tick = time.time()
        for step, (images, target) in enumerate(val_loader):
            # VOC2007 crops are kept on CPU here; otherwise move every
            # crop to the selected device.
            if args.gpu is not None and args.dataset != "VOC2007":
                for idx in range(len(images)):
                    images[idx] = images[idx].cuda(args.gpu,
                                                   non_blocking=True)

            target = target.cuda(args.gpu, non_blocking=True)
            # Fuse crops by averaging their class probabilities.
            probs = [torch.softmax(model(crop), dim=1) for crop in images]
            output = torch.mean(torch.stack(probs, dim=0), dim=0)
            images = images[0]
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # Average each metric over all distributed workers.
            acc1 = torch.mean(concat_all_gather(acc1), dim=0, keepdim=True)
            acc5 = torch.mean(concat_all_gather(acc5), dim=0, keepdim=True)
            correct_count += float(acc1[0]) * images.size(0)
            count_all += images.size(0)
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))
            loss = criterion(output, target)
            losses.update(loss.item(), images.size(0))
            # Track per-batch wall time.
            batch_time.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0:
                progress.display(step)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f} mAP {mAP.avg:.3f} '
              .format(top1=top1, top5=top5, mAP=mAP))
        final_accu = correct_count / count_all
        print("$$our final average accuracy %.7f" % final_accu)
    return top1.avg
コード例 #9
0
def validate(val_loader, model, criterion, args):
    """Single-view validation with worker-averaged accuracy metrics.

    Returns the running top-1 average over the whole loader.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    mAP = AverageMeter("mAP", ":6.2f")
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5, mAP],
        prefix='Test: ')

    # Frozen statistics during evaluation.
    model.eval()
    with torch.no_grad():
        tick = time.time()
        for step, (images, target) in enumerate(val_loader):
            # VOC2007 batches stay on CPU here; everything else moves to
            # the selected device.
            if args.gpu is not None and args.dataset != "VOC2007":
                images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

            # Forward pass and metrics.
            output = model(images)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # Average each metric over all distributed workers.
            acc1 = torch.mean(concat_all_gather(acc1), dim=0, keepdim=True)
            acc5 = torch.mean(concat_all_gather(acc5), dim=0, keepdim=True)
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))
            loss = criterion(output, target)
            losses.update(loss.item(), images.size(0))
            # Track per-batch wall time.
            batch_time.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0:
                progress.display(step)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f} mAP {mAP.avg:.3f} '
              .format(top1=top1, top5=top5, mAP=mAP))

    return top1.avg
コード例 #10
0
def train(train_loader, model, criterion, optimizer, epoch, args, lr_scheduler):
    """One epoch of linear-classification training on frozen features.

    The backbone is kept in eval mode on purpose: under the linear-probing
    protocol the pre-trained model must not change, and BatchNorm in train
    mode would still revise its running mean/std even without gradients.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    mAP = AverageMeter("mAP", ":6.2f")
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5, mAP],
        prefix="Epoch: [{}]".format(epoch))

    model.eval()
    batch_total = len(train_loader)
    tick = time.time()
    for step, (images, target) in enumerate(train_loader):
        # Time spent waiting on the data loader.
        data_time.update(time.time() - tick)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        target = target.cuda(args.gpu, non_blocking=True)

        # Forward pass and loss.
        output = model(images)
        loss = criterion(output, target)

        # Record accuracy and loss.
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1.item(), images.size(0))
        top5.update(acc5.item(), images.size(0))

        # SGD step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # When SGDR is enabled, step the scheduler with a fractional epoch.
        if args.sgdr != 0:
            lr_scheduler.step(epoch + step / batch_total)

        # Track per-batch wall time.
        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
コード例 #11
0
ファイル: train.py プロジェクト: maodeshu/CLSA
def train(train_loader, model, criterion, optimizer, epoch, args, log_path):
    """Run one CLSA pre-training epoch.

    :param train_loader:  data loader
    :param model: training model
    :param criterion: loss function
    :param optimizer: SGD optimizer
    :param epoch: current epoch
    :param args: config parameter
    :param log_path: path the per-iteration progress record is written to
    :return: average top-1 contrastive accuracy for the epoch
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, _) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            for k in range(len(images)):
                images[k] = images[k].cuda(args.gpu, non_blocking=True)

        # Crop layout: [key, weak queries..., strong queries...] with the
        # same number of weak and strong queries. Split outside the GPU
        # branch so CPU runs (args.gpu is None) don't hit a NameError.
        crop_copy_length = (len(images) - 1) // 2
        image_k = images[0]
        image_q = images[1:1 + crop_copy_length]
        image_strong = images[1 + crop_copy_length:]

        output, target, output2, target2 = model(image_q, image_k, image_strong)
        loss_contrastive = 0
        loss_weak_strong = 0
        if epoch == 0 and i == 0:
            print("-" * 100)
            print("contrastive loss count %d" % len(output))
            print("weak strong loss count %d" % len(output2))
            print("-" * 100)
        # Standard contrastive (InfoNCE) term for every weak query.
        for k in range(len(output)):
            loss_contrastive += criterion(output[k], target[k])
        # DDM loss: cross-entropy between strong-crop predictions and the
        # weak-crop target distributions.
        for k in range(len(output2)):
            loss_weak_strong += -torch.mean(
                torch.sum(torch.log(output2[k]) * target2[k], dim=1))
        loss = loss_contrastive + args.alpha * loss_weak_strong

        # acc1/acc5 are (K+1)-way contrast classifier accuracy, measured
        # on the first query view only.
        acc1, acc5 = accuracy(output[0], target[0], topk=(1, 5))
        losses.update(loss.item(), images[0].size(0))
        top1.update(acc1[0], images[0].size(0))
        top5.update(acc5[0], images[0].size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            progress.write_record(i, log_path)
    return top1.avg