def main():
    """Parse CLI arguments, optionally seed RNGs, and run ``main_worker``.

    The run is wrapped in an ``MLlogger`` rooted at
    ``os.path.join(home, 'mxt-sim/mllog_runs')`` under experiment
    ``args.experiment``. The run name is built from the architecture, the
    dataset and a ``W<bit_weights>A<bit_act>`` bit-width tag.
    """
    args = parser.parse_args()

    # NOTE(review): temporary override ("temp moran") that disables seeding
    # regardless of --seed. Warn instead of silently discarding the flag;
    # delete these lines to re-enable deterministic runs.
    if args.seed is not None:
        warnings.warn('Ignoring --seed={}: seeding is temporarily disabled '
                      'in main().'.format(args.seed))
    args.seed = None

    # Seed every RNG the run may touch. This was previously split across two
    # redundant blocks. It is unreachable while the override above is present.
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    with MLlogger(os.path.join(home, 'mxt-sim/mllog_runs'),
                  args.experiment,
                  args,
                  name_args=[
                      args.arch, args.dataset,
                      "W{}A{}".format(args.bit_weights, args.bit_act)
                  ]) as ml_logger:
        main_worker(args, ml_logger)
def main():
    """Parse arguments (optionally layered over a JSON config) and run ``main_worker``.

    If ``args.config_file`` is set, the JSON object it contains is installed
    as parser defaults and the command line is parsed again, so flags given
    explicitly on the command line still take precedence over config values.

    Derived settings:
      * ``args.save``: falls back to a ``%Y-%m-%d_%H-%M-%S`` timestamp when
        it is the empty string.
      * ``args.save_path``: ``path.join(args.results_dir, args.save)``.
      * The experiment name is ``args.exp``, or ``"<model>_<dataset>"`` when
        that is empty.
    """
    args = parser.parse_args()
    if args.config_file is not None:
        # JSON text is UTF-8 by specification; do not rely on the locale codec.
        with open(args.config_file, encoding='utf-8') as f:
            config_dict = json.load(f)
        parser.set_defaults(**config_dict)
        args = parser.parse_args()

    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    # Compare by value: `is ''` tests object identity and depends on string interning.
    if args.save == '':
        args.save = time_stamp
    args.save_path = path.join(args.results_dir, args.save)

    if args.exp == '':
        exp = str(args.model) + '_' + str(args.dataset)
    else:
        exp = args.exp

    with MLlogger(path.join(args.results_dir, 'mlruns'),
                  exp,
                  args,
                  name_args=[args.model, args.dataset]) as ml_logger:
        main_worker(args, ml_logger)
def main():
    """Entry point: parse CLI args, apply optional seeding, and launch ``main_worker``.

    ``args.post_relu`` is derived as the negation of ``args.pre_relu``. When
    ``args.seed`` is set, the Python, NumPy and torch (CPU and all CUDA
    devices) RNGs are seeded, and cuDNN is switched to deterministic,
    non-benchmark mode.
    """
    args = parser.parse_args()
    args.post_relu = not args.pre_relu

    seed = args.seed
    if seed is not None:
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                        torch.cuda.manual_seed_all):
            seed_fn(seed)
        cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    bits_tag = "W{}A{}".format(args.bit_weights, args.bit_act)
    run_name_args = [args.arch, args.dataset, bits_tag]
    log_root = os.path.join(home, 'mxt-sim/mllog_runs')
    with MLlogger(log_root, args.experiment, args,
                  name_args=run_name_args) as ml_logger:
        main_worker(args, ml_logger)
    # Apply `scales` as the clipping values, then log the calibration loss.
    mq.set_clipping(scales, inf_model.device)
    loss = inf_model.evaluate_calibration()
    loss_value = loss.item()
    ml_logger.log_metric('Loss {}'.format(args.min_method), loss_value, step='auto')

    # evaluate
    acc = inf_model.validate()
    ml_logger.log_metric('Acc {}'.format(args.min_method), acc, step='auto')
    data['powell'] = {'alpha': scales, 'loss': loss_value, 'acc': acc}

    # Save scales to <proj_root_dir>/data/scales_<arch>_W<bits>A<bits>.pkl.
    # The context manager closes the file even if pickling raises.
    f_name = "scales_{}_W{}A{}.pkl".format(args.arch, args.bit_weights, args.bit_act)
    with open(os.path.join(proj_root_dir, 'data', f_name), 'wb') as f:
        pickle.dump(data, f)
    print("Data saved to {}".format(f_name))


if __name__ == '__main__':
    args = parser.parse_args()

    # The calibration batch size defaults to, and is capped at, the regular
    # batch size. The calibration set size defaults to the batch size.
    if args.cal_batch_size is None:
        args.cal_batch_size = args.batch_size
    elif args.cal_batch_size > args.batch_size:
        print("Changing cal_batch_size parameter from {} to {}".format(args.cal_batch_size, args.batch_size))
        args.cal_batch_size = args.batch_size
    if args.cal_set_size is None:
        args.cal_set_size = args.batch_size

    # NOTE(review): main is called with (args, ml_logger) here, but the main()
    # definitions visible in this file take no parameters; confirm which
    # definition this entry point is meant to use.
    _log_dir = os.path.join(home, 'mxt-sim/mllog_runs')
    _run_tag = "W{}A{}".format(args.bit_weights, args.bit_act)
    with MLlogger(_log_dir, args.experiment, args,
                  name_args=[args.arch, args.dataset, _run_tag]) as ml_logger:
        main(args, ml_logger)
            'bcorr_weight': args.bias_corr_weight,
            'vcorr_weight': args.var_corr_weight,
            'logger': logger,
            'measure_entropy': args.measure_entropy,
            'mtd_quant': args.mid_thread_quant
        },
        'qmanager':{
            'rho_act': args.rho_act,
            'rho_weight': args.rho_weight
        }
    }  # TODO: add params for bfloat
    return qparams


if __name__ == '__main__':
    # NOTE(review): `args` is read but not assigned in this block; presumably
    # it is parsed at module level elsewhere in the file. Confirm.
    if args.stats_mode != 'collect':
        # Logged run: the experiment name defaults to the architecture.
        experiment = args.arch if args.mlf_experiment is None else args.mlf_experiment
        with MLlogger(os.path.join(home, 'mlruns_mxt'), experiment, args,
                      name_args=[args.arch, "W{}A{}".format(args.qweight, args.qtype)]) as ml_logger:
            with QM(args, get_params(ml_logger)):
                im = InferenceModel(ml_logger)
                im.run()

            if args.measure_entropy:
                # Use `meter_id`, not `id`: at module scope the loop variable
                # would rebind the builtin `id` for the entire module.
                for meter_id in ml_logger.metters:
                    print("Average bit rate: {} - {}".format(meter_id, ml_logger.metters[meter_id].avg))
    else:
        # 'collect' mode: run without an MLlogger.
        with QM(args, get_params()):
            im = InferenceModel()
            im.run()