示例#1
0
文件: train_v2v.py 项目: zyg11/VKD
def main():
    """Train the model returned by ``get_model`` and periodically evaluate it.

    Builds the configuration, data loaders, model (wrapped in
    ``nn.DataParallel``), Adam optimizer and ``MultiStepLR`` scheduler, then
    trains for ``args.num_epochs`` epochs on the sum of an
    ``OnlineTripletLoss`` and a cross-entropy classification loss.

    Every ``args.eval_epoch_interval`` epochs (skipping epoch 0) the model is
    evaluated, and every ``args.save_epoch_interval`` epochs (skipping epoch 0)
    a checkpoint is saved. Both happen before that epoch's training, so they
    reflect ``e`` completed epochs. A final evaluation and the ``chk_end``
    checkpoint are produced after training.
    """
    conf = Conf()
    args = parse(conf)
    device = conf.get_device()

    conf.suppress_random(set_determinism=args.set_determinism)
    saver = Saver(conf.log_path, args.exp_name)

    train_loader, query_loader, gallery_loader, queryimg_loader, galleryimg_loader = \
        get_dataloaders(args.dataset_name, conf.nas_path, device, args)

    num_pids = train_loader.dataset.get_num_pids()

    net = nn.DataParallel(get_model(args, num_pids))
    net = net.to(device)

    saver.write_logs(net.module, vars(args))

    opt = Adam(net.parameters(), lr=1e-4, weight_decay=args.wd)
    milestones = list(
        range(args.first_milestone, args.num_epochs, args.step_milestone))
    scheduler = lr_scheduler.MultiStepLR(opt,
                                         milestones=milestones,
                                         gamma=args.gamma)

    triplet_loss = OnlineTripletLoss('soft', True, reduction='mean').to(device)
    class_loss = nn.CrossEntropyLoss(reduction='mean').to(device)

    def evaluate(step):
        # A fresh Evaluator is built for every evaluation, as before. `step` is
        # the number of completed training epochs and is the logged iteration.
        ev = Evaluator(net, query_loader, gallery_loader, queryimg_loader,
                       galleryimg_loader, DATA_CONFS[args.dataset_name],
                       device)
        ev.eval(saver, step, args.verbose)

    print("EXP_NAME: ", args.exp_name)

    for e in range(args.num_epochs):

        if e % args.eval_epoch_interval == 0 and e > 0:
            evaluate(e)

        if e % args.save_epoch_interval == 0 and e > 0:
            saver.save_net(net.module, f'chk_{e // args.save_epoch_interval}')

        avm = AvgMeter(['triplet', 'class'])

        # Re-enter training mode after any evaluation above. The batch loop
        # below never switches the module mode itself, so once per epoch is
        # enough (previously this ran on every batch).
        net.train()

        for x, y, _cams in train_loader:
            x, y = x.to(device), y.to(device)

            opt.zero_grad()
            embeddings, f_class = net(x, return_logits=True)

            triplet_loss_batch = triplet_loss(embeddings, y)
            class_loss_batch = class_loss(f_class, y)
            loss = triplet_loss_batch + class_loss_batch

            avm.add([triplet_loss_batch.item(), class_loss_batch.item()])

            loss.backward()
            opt.step()

        if e % args.print_epoch_interval == 0:
            parts = [f"Epoch: {e}"]
            for name, value in avm():
                parts.append(f"{name} {value:.2f}")
                saver.dump_metric_tb(value, e, 'losses', f"avg_{name}")
            saver.dump_metric_tb(opt.param_groups[0]['lr'], e, 'lr', 'lr')
            print(" - ".join(parts))

        # Per-epoch LR schedule, stepped after the optimizer updates.
        scheduler.step()

    # Final evaluation after all epochs. Log it at `args.num_epochs` (completed
    # epochs, matching the in-loop convention). The loop variable is undefined
    # when num_epochs == 0, and otherwise equals num_epochs - 1, which would
    # collide with the step of the last in-loop evaluation.
    evaluate(args.num_epochs)

    saver.save_net(net.module, 'chk_end')
    saver.writer.close()
示例#2
0
                   DATA_CONFS[args.dataset_name], device)

    print('v' * 100)
    ev.eval(saver=None, iteration=None, verbose=True, do_tb=False)
    print('v' * 100)

    student_net.reinit_layers(args.reinit_l4, args.reinit_l3)

    saver = Saver(conf.log_path, args.exp_name)
    saver.write_logs(student_net, vars(args))

    d_trainer: DistillationTrainer = DistillationTrainer(train_loader, query_loader,
                                   gallery_loader, queryimg_loader, galleryimg_loader, conf.get_device(),
                                   saver, args, conf)

    print("EXP_NAME: ", args.exp_name)

    for idx_iteration in range(args.num_generations):
        print(f'starting generation {idx_iteration+1}')
        print('#'*100)
        teacher_net = d_trainer(teacher_net, student_net)
        d_trainer.evaluate(teacher_net)
        teacher_net.teacher_mode()

        student_net = deepcopy(teacher_net)
        saver.save_net(student_net, f'chk_di_{idx_iteration + 1}')

        student_net.reinit_layers(args.reinit_l4, args.reinit_l3)

    saver.writer.close()