示例#1
0
     sorted({min(net.depth_list), max(net.depth_list)}),
     'pixelshuffle_depth_list':
     sorted({
         min(net.pixelshuffle_depth_list),
         max(net.pixelshuffle_depth_list)
     }),
 }
 # Dispatch on the progressive-shrinking stage selected by args.task. Each
 # branch passes the shared train/validate callables and validate_func_dict
 # to the routine for that stage.
 if args.task == 'kernel':
     # Stage 'kernel': validate on every kernel size in args.ks_list rather
     # than only the min/max pair.
     validate_func_dict['ks_list'] = sorted(args.ks_list)
     if run_manager.start_epoch == 0:
         # Starting from epoch 0 (presumably a fresh, non-resumed run): load the
         # initial weights and log a baseline validation before training.
         # The checkpoint used to be downloaded from the OnceForAll release:
         # model_path = download_url('https://hanlab.mit.edu/files/OnceForAll/ofa_checkpoints/ofa_D4_E6_K7',
         #                           model_dir='.torch/ofa_checkpoints/%d' % hvd.rank())
         model_path = './exp/sr_bn_mse_normal2pixelshuffle/checkpoint/model_best.pth.tar'  # TODO: hard-coded; change this path to match your experiment
         load_models(run_manager, run_manager.net, model_path=model_path)
         # NOTE(review): assumes validate() returns a 3-tuple matching this
         # format string (two floats and a string) — TODO confirm.
         run_manager.write_log(
             '%.3f\t%.3f\t%s' % validate(run_manager, **validate_func_dict),
             'valid')
     # The lambda forwards (run_manager, epoch, is_test) to validate() together
     # with the fixed sub-network settings above; presumably train() invokes it
     # as its validation callback.
     train(
         run_manager, args, lambda _run_manager, epoch, is_test: validate(
             _run_manager, epoch, is_test, **validate_func_dict))
 elif args.task == 'depth':
     from ofa.elastic_nn.training.progressive_shrinking import supporting_elastic_depth
     # NOTE: the initial model path is set inside supporting_elastic_depth;
     # edit it there whenever it needs to change.
     supporting_elastic_depth(train, run_manager, args, validate_func_dict)
 elif args.task == 'expand':
     from ofa.elastic_nn.training.progressive_shrinking import supporting_elastic_expand
     # NOTE: the initial model path is set inside supporting_elastic_expand;
     # edit it there whenever it needs to change.
     supporting_elastic_expand(train, run_manager, args, validate_func_dict)
 elif args.task == 'pixelshuffle_depth':
     from ofa.elastic_nn.training.progressive_shrinking import supporting_elastic_pixelshuffle_depth
     # NOTE: the initial model path is set inside
     # supporting_elastic_pixelshuffle_depth; edit it there whenever it needs to change.
示例#2
0
    # Presumably synchronises state across distributed workers before training.
    distributed_run_manager.broadcast()

    # A positive kd_ratio (presumably knowledge distillation) means the teacher
    # network needs its weights, which are loaded from args.teacher_path.
    if args.kd_ratio > 0:
        load_models(distributed_run_manager, args.teacher_model,
                    model_path=args.teacher_path)

    # Training / validation entry points for progressive shrinking.
    from ofa.elastic_nn.training.progressive_shrinking import validate, train

    # Sub-network settings used during validation. For most elastic dimensions
    # only the smallest and largest choice is kept (deduplicated, ascending).
    # width_mult_list holds the *indices* of the first and last width
    # multipliers, and image_size_list is the set {224} when args.image_size
    # is a single int, otherwise the list [160, 224].
    validate_func_dict = dict(
        image_size_list=({224} if isinstance(args.image_size, int)
                         else sorted({160, 224})),
        width_mult_list=sorted({0, len(args.width_mult_list) - 1}),
        ks_list=sorted({min(args.ks_list), max(args.ks_list)}),
        expand_ratio_list=sorted({min(args.expand_list), max(args.expand_list)}),
        depth_list=sorted({min(net.depth_list), max(net.depth_list)}),
    )
    # Dispatch to the progressive-shrinking stage selected by args.task. Each
    # branch passes the shared train/validate callables and validate_func_dict
    # to the routine for that stage.
    if args.task == 'kernel':
        # Stage 'kernel': validate on every kernel size in args.ks_list rather
        # than only the min/max pair.
        validate_func_dict['ks_list'] = sorted(args.ks_list)
        if distributed_run_manager.start_epoch == 0:
            # Starting from epoch 0: fetch the released ofa_D4_E6_K7 checkpoint
            # into a per-rank directory (hvd.rank()), load it, and log a
            # baseline validation of those weights before training.
            model_path = download_url('https://hanlab.mit.edu/files/OnceForAll/ofa_checkpoints/ofa_D4_E6_K7',
                                      model_dir='.torch/ofa_checkpoints/%d' % hvd.rank())
            load_models(distributed_run_manager, distributed_run_manager.net, model_path=model_path)
            # NOTE(review): assumes validate() returns a 4-tuple matching this
            # format string (three floats and a string) — TODO confirm.
            distributed_run_manager.write_log('%.3f\t%.3f\t%.3f\t%s' %
                                              validate(distributed_run_manager, **validate_func_dict), 'valid')
        # The lambda forwards (run_manager, epoch, is_test) to validate() with
        # the fixed sub-network settings; presumably train() calls it to validate.
        train(distributed_run_manager, args,
              lambda _run_manager, epoch, is_test: validate(_run_manager, epoch, is_test, **validate_func_dict))
    elif args.task == 'depth':
        from ofa.elastic_nn.training.progressive_shrinking import supporting_elastic_depth
        supporting_elastic_depth(train, distributed_run_manager, args, validate_func_dict)
    elif args.task == 'expand':
        from ofa.elastic_nn.training.progressive_shrinking import supporting_elastic_expand
        supporting_elastic_expand(train, distributed_run_manager, args, validate_func_dict)
    else:
        # Reject unknown tasks explicitly instead of silently running the
        # elastic-expand stage for any unrecognised value of args.task.
        raise NotImplementedError('Unsupported task: %r' % (args.task,))