Example #1
0
def abstract_load_keras_model(folder,
                              dataset,
                              num_layer,
                              activation,
                              hidden_neurons,
                              mode=None):
    """
    Resolve a friendly parameter set to a weight-file path and load the Keras model stored there.

    The path has the form
    ``models_weights/models_<folder>/<dataset>_<num_layer>layer_<activation>_<hidden_neurons>``,
    followed by ``_<mode>`` when ``mode`` is given.

    :param folder: 'fastlin' or 'recurjac'
    :param dataset: 'cifar' or 'mnist' (looked up as 'cifar10' / 'mnist' for the input shape)
    :param num_layer: positive ``int`` (``bool`` is rejected)
    :param activation: 'relu' or 'leaky' or 'tanh'
    :param hidden_neurons: positive ``int`` (``bool`` is rejected)
    :param mode: None or 'best' or 'adv_retrain' or 'distill'
    :return: the result of ``load_keras_model(input_shape, path)``
    :raises AssertionError: if any argument is outside the accepted values
    """
    assert folder in ['fastlin', 'recurjac']
    assert dataset in ['cifar', 'mnist']
    assert type(num_layer) is int and num_layer > 0
    assert activation in ['relu', 'leaky', 'tanh']
    assert type(hidden_neurons) is int and hidden_neurons > 0
    assert mode is None or mode in ['best', 'adv_retrain', 'distill']

    stem = f'{dataset}_{num_layer}layer_{activation}_{hidden_neurons}'
    if mode is not None:
        stem = f'{stem}_{mode}'
    weights_path = f'models_weights/models_{folder}/{stem}'

    # The input-shape helper expects the long dataset name for CIFAR.
    shape_key = {'cifar': 'cifar10', 'mnist': 'mnist'}[dataset]
    return load_keras_model(datasets.get_input_shape(shape_key), weights_path)
Example #2
0
    def __init__(self, dataset, model):
        """
        Convert the PyTorch model to Keras and load it into CNN-Cert's ``nl.CNNModel``.

        The tool can only build its model from a file, so the converted Keras model is
        compiled, saved to ``tmp/tmp.pt`` and read back from there.

        :param dataset: dataset name passed to ``get_num_classes`` / ``get_input_shape``
        :param model: the PyTorch model (stored by the parent constructor as ``self.model``)
        :raises Exception: if the converted Keras model does not behave like the original
        """
        super(CNNCertBase, self).__init__(dataset, model)

        self.num_classes = get_num_classes(dataset)

        input_shape = get_input_shape(dataset)
        # (C, H, W) -> (H, W, C): the Keras side works channels-last.
        new_input_shape = (input_shape[1], input_shape[2], input_shape[0])
        self.k_model = sequential_torch2keras(model_transform(self.model, input_shape), dataset)

        mismatch_msg = "Somehow the transformed model behaves differently from the original model."

        global graph
        global sess
        with sess.as_default():
            with graph.as_default():
                # Save the transformed Keras model to a temporary place so that the tool can
                # read it: the tool can only init its model from a file.
                sgd = SGD(lr=0.01, decay=1e-5, momentum=0.9, nesterov=True)
                self.k_model.compile(loss=fn,
                                     optimizer=sgd,
                                     metrics=['accuracy'])
                self.k_model.save('tmp/tmp.pt')

                self.new_model = nl.CNNModel('tmp/tmp.pt', new_input_shape)
                self.weights = self.new_model.weights
                self.biases = self.new_model.biases

                try:
                    consistent = check_consistency(self.model, self.k_model, input_shape)
                except Exception as err:
                    raise Exception(mismatch_msg) from err
                # check_consistency reports the result through its return value (the other
                # converters require it to be True), so a False result must not pass silently.
                if not consistent:
                    raise Exception(mismatch_msg)

        self.LP = False
        self.LPFULL = False
        self.method = "ours"
        self.dual = False
Example #3
0
    def __init__(self, dataset, model):
        """
        Create a dedicated TF graph/session and a cleverhans PGD attack for the model.

        :param dataset: dataset name, used to look up the input shape
        :param model: the PyTorch model (stored by the parent constructor as ``self.model``)
        """
        super(PGDAdaptor, self).__init__(dataset, model)

        # Cap this session at half of the GPU memory, growing allocations on demand.
        gpu_opts = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
        self.config = tf.ConfigProto(gpu_options=gpu_opts)
        self.config.gpu_options.allow_growth = True
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph, config=self.config)

        shape = get_input_shape(dataset)
        placeholder_shape = (None, shape[0], shape[1], shape[2])

        with self.sess.graph.as_default(), self.sess.as_default():
            self.tf_model = convert_pytorch_model_to_tf(self.model)
            self.ch_model = CallableModelWrapper(self.tf_model, output_layer='logits')
            # Batch of inputs laid out like the dataset's input shape, batch size left open.
            self.x_op = tf.placeholder(tf.float32, shape=placeholder_shape)
            self.attk = ProjectedGradientDescent(self.ch_model, sess=self.sess)

        # NOTE(review): starts empty here; presumably filled lazily elsewhere — confirm.
        self.adv_preds_ops = {}
Example #4
0
    def __init__(self, dataset, model):
        """
        Convert the PyTorch model to Keras and wrap it in a CNN-Cert ``Model``.

        Also detects the single activation type used by ``self.model`` and selects the
        matching linear-bounds function.

        :param dataset: dataset name passed to ``get_num_classes`` / ``get_input_shape``
        :param model: the PyTorch sequential model (stored by the parent constructor)
        :raises AssertionError: if the model does not use exactly one supported activation type
        :raises Exception: if the converted Keras model does not behave like the original
        """
        super(CNNCertBase, self).__init__(dataset, model)

        self.num_classes = get_num_classes(dataset)

        # Collect the CNN-Cert activation name of every activation layer. ReLU is mapped to
        # 'ada' (presumably CNN-Cert's adaptive ReLU bounds — confirm), so the 'relu' branch
        # below is not reachable through this mapping. CNN-Cert also knows 'arctan', but
        # PyTorch has no corresponding layer, so it is never produced here.
        found = set()
        for layer in self.model:
            if isinstance(layer, nn.ReLU):
                found.add('ada')
            elif isinstance(layer, nn.Sigmoid):
                found.add('sigmoid')
            elif isinstance(layer, nn.Tanh):
                found.add('tanh')
        assert len(found) == 1, f'Expected exactly one activation type, found {sorted(found)}'
        self.activation = found.pop()

        input_shape = get_input_shape(dataset)
        # (C, H, W) -> (H, W, C): the Keras side works channels-last.
        new_input_shape = (input_shape[1], input_shape[2], input_shape[0])
        self.k_model = sequential_torch2keras(self.model, dataset)

        mismatch_msg = "Somehow the transformed model behaves differently from the original model."

        global graph
        global sess
        with sess.as_default():
            with graph.as_default():
                # Keras' summary() prints by itself and returns None; wrapping it in print()
                # only adds a stray "None" line.
                self.k_model.summary()
                try:
                    consistent = check_consistency(self.model, self.k_model, input_shape)
                except Exception as err:
                    raise Exception(mismatch_msg) from err
                # An explicit check instead of `assert`, which is stripped under `python -O`.
                if not consistent:
                    raise Exception(mismatch_msg)

                self.new_model = Model(self.k_model, new_input_shape)

        # Set correct linear_bounds function
        self.linear_bounds = None
        if self.activation == 'relu':
            self.linear_bounds = relu_linear_bounds
        elif self.activation == 'ada':
            self.linear_bounds = ada_linear_bounds
        elif self.activation == 'sigmoid':
            self.linear_bounds = sigmoid_linear_bounds
        elif self.activation == 'tanh':
            self.linear_bounds = tanh_linear_bounds
        elif self.activation == 'arctan':
            self.linear_bounds = atan_linear_bounds
Example #5
0
    def __init__(self, dataset, model):
        """
        Export the model to ONNX, simplify it, and load the result into ERAN.

        Intermediate files are written to ``tmp/tmp.onnx`` (raw export) and
        ``tmp/tmp_simp.onnx`` (simplified).

        :param dataset: dataset name, used to look up the input shape
        :param model: the PyTorch model (stored by the parent constructor as ``self.model``)
        :raises AssertionError: if onnx-simplifier reports that its check failed
        """
        super(ERANBase, self).__init__(dataset, model)

        self.model.eval()

        raw_path = 'tmp/tmp.onnx'
        simplified_path = 'tmp/tmp_simp.onnx'

        shape = get_input_shape(dataset)
        batch_dims = (1, shape[0], shape[1], shape[2])
        # A single random CUDA sample is used as the tracing input for the export.
        dummy = torch.randn(*batch_dims, requires_grad=True).cuda()
        # One forward pass before exporting; its output is not used afterwards.
        self.model(dummy)

        # Only the default torch.onnx.export options are used.
        torch.onnx.export(self.model, dummy, raw_path)

        simplified, ok = onnxsim.simplify(raw_path,
                                          check_n=3,
                                          perform_optimization=True,
                                          skip_fuse_bn=True,
                                          input_shapes={None: batch_dims})
        assert ok, "Simplify ONNX model failed"
        onnx.save(simplified, simplified_path)

        self.new_model, self.is_conv = read_onnx_net(simplified_path)

        self.config = {}

        self.eran = ERAN(self.new_model, is_onnx=True)
Example #6
0
    def __init__(self, dataset, model):
        """
        Convert the model to Keras and precompute the layers, functions and gradients used later.

        Sets:
          * ``self.model``, ``self.activation``, ``self.activation_param``: results of ``torch2keras``
          * ``self.W``: the last Dense layer; ``self.U``: all earlier Dense layers
          * ``self.layer_outputs``: one Keras function per Conv2D/Dense layer, each returning
            that layer's output for a given input
          * ``self.gradients``: the gradient of each output unit w.r.t. the model input

        :param dataset: dataset name, used to look up the input shape
        :param model: the PyTorch model to convert
        :raises Exception: if the converted Keras model does not behave like ``model``
        """
        mismatch_msg = "Somehow the transformed model behaves differently from the original model."
        input_shape = datasets.get_input_shape(dataset)

        self.model, self.activation, self.activation_param = torch2keras(
            dataset, model_transform(model, input_shape))

        # Keras' summary() prints by itself and returns None, so it is not wrapped in print().
        self.model.summary()

        with sess.as_default():
            with graph.as_default():
                try:
                    consistent = check_consistency(model, self.model, input_shape, 'channels_first')
                except Exception as err:
                    raise Exception(mismatch_msg) from err
                # An explicit check instead of `assert`, which is stripped under `python -O`.
                if not consistent:
                    raise Exception(mismatch_msg)

                # extract weights: the last Dense layer is W, the preceding ones form U
                dense_layers = [layer for layer in self.model.layers
                                if isinstance(layer, keras.layers.Dense)]
                self.W = dense_layers[-1]
                self.U = dense_layers[:-1]

                # save the output of intermediate Conv2D / Dense layers
                model_input = self.model.layers[0].input
                self.layer_outputs = [
                    K.function([model_input], [layer.output])
                    for layer in self.model.layers
                    if isinstance(layer, (keras.layers.Conv2D, keras.layers.Dense))
                ]

                # one gradient tensor per output unit
                self.gradients = [
                    K.gradients(self.model.output[:, i], self.model.input)[0]
                    for i in range(self.model.output.shape[1])
                ]
Example #7
0
def _torch_padding_to_keras(padding):
    """
    Map a PyTorch ``padding`` attribute to a Keras padding mode.

    ``nn.Conv2d`` stores ``padding`` as a tuple (or as the string 'valid'/'same'), whereas
    ``nn.MaxPool2d`` / ``nn.AvgPool2d`` keep the value the user passed, usually a plain int.
    All of these forms are accepted. Zero padding maps to 'valid'; any other numeric padding
    is approximated by Keras 'same' padding.
    """
    if isinstance(padding, str):
        return padding
    if isinstance(padding, (tuple, list)):
        padding = padding[0]
    return 'valid' if padding == 0 else 'same'


def sequential_torch2keras(torch_model, dataset):
    """
    Transform a sequential PyTorch model into a channels-last Keras ``Sequential`` model.

    Convolution kernels are transposed to Keras' (kh, kw, in, out) layout. The first Dense
    layer after a flatten layer has its input rows permuted, so that it consumes the
    channels-last flattened features in the same order as PyTorch's channels-first flatten.

    Side effect: (re)creates the module-level ``graph`` and ``sess`` globals; the returned
    model is built inside them.

    :param torch_model: the ``nn.Sequential`` model to transform (weights are copied through
        ``.cpu()``, so any device works)
    :param dataset: the dataset name passed to ``get_input_shape``, typically 'mnist' or 'cifar10'
    :return: the transformed Keras model
    :raises AssertionError: if ``torch_model`` is not ``nn.Sequential``, or if a conv/pool layer
        appears after the flatten layer
    :raises Exception: for unsupported layer types
    """
    global graph
    global sess
    graph = tf.Graph()
    sess = tf.Session(graph=graph,
                      config=tf.ConfigProto(gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.5)))

    with sess.as_default():
        with graph.as_default():
            ret = keras.Sequential()

            assert isinstance(torch_model, nn.Sequential)

            input_shape = get_input_shape(dataset)
            # (C, H, W) -> (H, W, C): every Keras layer built here is channels-last.
            new_input_shape = (input_shape[1], input_shape[2], input_shape[0])

            # before meeting the flatten layer, we transform each channel-first layer to the corresponding channel-last one
            # after meeting the flatten layer, we analyze the first layer and transform the weight matrix
            meet_flatten = False
            shape_before_flatten = None
            transposed = False
            first_layer = True

            for layer in torch_model:
                # only the first Keras layer declares the input shape
                if first_layer:
                    kwargs = {'input_shape': new_input_shape}
                    first_layer = False
                else:
                    kwargs = {}

                if isinstance(layer, nn.Conv2d):
                    # we don't permit conv2d layer after flatten
                    assert meet_flatten is False
                    # Non-zero zero-padding is assumed to correspond to the "same" padding of
                    # keras.Conv2D, since cnn-cert only supports keras.Conv2D and not separate
                    # zero-padding layers.
                    new_layer = keras.layers.Conv2D(layer.out_channels, layer.kernel_size, layer.stride,
                                                    _torch_padding_to_keras(layer.padding),
                                                    'channels_last',
                                                    use_bias=layer.bias is not None,
                                                    **kwargs)
                    ret.add(new_layer)

                    # PyTorch kernels are (out, in, kh, kw); Keras expects (kh, kw, in, out).
                    new_weights = [layer.weight.cpu().detach().numpy().transpose(2, 3, 1, 0)]
                    if layer.bias is not None:
                        new_weights.append(layer.bias.cpu().detach().numpy())
                    new_layer.set_weights(new_weights)

                elif isinstance(layer, nn.AvgPool2d):
                    # we don't permit avgpool2d layer after flatten
                    assert meet_flatten is False

                    new_layer = keras.layers.AvgPool2D(layer.kernel_size, layer.stride,
                                                       _torch_padding_to_keras(layer.padding),
                                                       data_format='channels_last',
                                                       **kwargs)
                    ret.add(new_layer)

                elif isinstance(layer, nn.MaxPool2d):
                    # we don't permit maxpool2d layer after flatten
                    assert meet_flatten is False

                    new_layer = keras.layers.MaxPool2D(layer.kernel_size, layer.stride,
                                                       _torch_padding_to_keras(layer.padding),
                                                       data_format='channels_last', **kwargs)
                    ret.add(new_layer)

                elif isinstance(layer, nn.ReLU):
                    ret.add(keras.layers.Activation('relu', **kwargs))

                elif isinstance(layer, nn.Tanh):
                    ret.add(keras.layers.Activation('tanh', **kwargs))

                elif isinstance(layer, nn.Sigmoid):
                    ret.add(keras.layers.Activation('sigmoid', **kwargs))

                elif isinstance(layer, models.zoo.Flatten):
                    meet_flatten = True
                    transposed = False
                    if 'input_shape' in kwargs:
                        shape_before_flatten = new_input_shape
                    else:
                        shape_before_flatten = ret.output_shape[1:]
                    ret.add(keras.layers.Flatten(data_format='channels_last', **kwargs))

                elif isinstance(layer, (nn.Linear, FlattenConv2D)):
                    # PyTorch Linear weights are (out, in); the Keras Dense kernel is (in, out).
                    weights = [layer.weight.cpu().detach().numpy().T]
                    if layer.bias is not None:
                        weights.append(layer.bias.cpu().detach().numpy())

                    # use_bias must match the PyTorch layer, or set_weights gets the wrong
                    # number of arrays for bias-free layers.
                    new_layer = keras.layers.Dense(layer.out_features,
                                                   use_bias=layer.bias is not None,
                                                   **kwargs)
                    ret.add(new_layer)

                    if meet_flatten and not transposed:
                        # Keras flattened (h, w, c) with c varying fastest, PyTorch flattened
                        # (c, h, w). Kernel row (i * w + j) * c + k must therefore take the
                        # PyTorch row k * h * w + i * w + j.
                        h, w, c = shape_before_flatten
                        mapping = [k * h * w + i * w + j for i in range(h) for j in range(w) for k in range(c)]
                        weights[0] = weights[0][mapping]
                        transposed = True
                    new_layer.set_weights(weights)

                elif isinstance(layer, nn.Dropout):
                    # kwargs carries input_shape when Dropout is the first layer
                    ret.add(keras.layers.Dropout(layer.p, **kwargs))

                else:
                    raise Exception(f'Unsupported layer type {layer.__class__.__name__}')

    return ret
Example #8
0
    if args.mode == 'transform':
        # For every (dataset, model, weight) combination, load the best checkpoint, apply
        # model_transform, and save the transformed weights together with the checkpoint stats.
        for now_dataset in args.dataset:
            print(f'Now on {now_dataset}', file=sys.stderr)
            for now_model in args.model:
                print(f'  Now on {now_model}', file=sys.stderr)
                for now_w, now_w_fname, _ in weights[now_dataset]:
                    print(f'    Now on {now_w}', file=sys.stderr)

                    m = load_model('exp', now_dataset, now_model)
                    m, flag, others = try_load_weight(
                        m, f'{now_dataset}_{now_model}_{now_w_fname}_best')
                    assert flag, f'Could not load weights {now_dataset}_{now_model}_{now_w_fname}_best'

                    t1 = time.time()
                    print('      Transforming...')
                    newm = model_transform(m, get_input_shape(now_dataset))
                    print(f'      Saving... {time.time() - t1:.3f} s')
                    torch.save(
                        {
                            'state_dict': newm.state_dict(),
                            'acc': others['acc'],
                            'robacc': others['robacc'],
                            'epoch': others['epoch'],
                            'normalized': others['normalized'],
                            # the dataset this checkpoint belongs to (not the 'normalized' flag)
                            'dataset': now_dataset,
                        },
                        f'{SAVE_PATH}/transformed_{now_dataset}_{now_model}_{now_w_fname}_best.pth'
                    )
                    print(f'      Done {time.time() - t1:.3f} s')
    if args.mode == 'test':
        raise NotImplementedError
Example #9
0
 def __init__(self, dataset, model, timeout=30):
     """
     Initialise the adaptor and prepare its solver for the dataset's input shape.

     :param dataset: dataset name, used to look up the input shape
     :param model: the model under verification (handled by the parent constructor)
     :param timeout: accepted but currently unused; the line that would have passed it
         to the solver settings (``cp.settings.SOLVE_TIME``) was commented out
     """
     super(FastMILPAdaptor, self).__init__(dataset, model)
     self.prepare_solver(get_input_shape(dataset))
Example #10
0
 def __init__(self, dataset, model):
     """
     Initialise the adaptor and build a ``FastIntervalBound`` for ``self.model``.

     The bound receives the dataset's input shape together with ``self.in_min`` and
     ``self.in_max``. These two are not set here, so presumably the parent constructor
     sets them (confirm).

     :param dataset: dataset name, used to look up the input shape
     :param model: the model under verification (handled by the parent constructor)
     """
     super(IBPAdaptor, self).__init__(dataset, model)
     shape = get_input_shape(dataset)
     self.bound = FastIntervalBound(self.model, shape, self.in_min, self.in_max)
Example #11
0
def torch2keras(dataset, model):
    """
    Convert a sequential PyTorch model into a channels-first Keras ``Sequential`` model.

    Layers are built inside the module-level ``sess``/``graph``. Dropout layers are skipped,
    because the converted model is only used for evaluation.

    :param dataset: dataset name, used to look up the input shape
    :param model: iterable of PyTorch layers (Flatten, Linear/FlattenConv2D, ReLU, Tanh,
        LeakyReLU, Dropout, Conv2d)
    :return: ``(keras_model, activation, activation_param)``, where ``activation`` is the single
        activation name used ('relu', 'tanh' or 'leaky') and ``activation_param`` is the
        LeakyReLU negative slope (``None`` if no LeakyReLU is present)
    :raises NotImplementedError: for unsupported layer types
    :raises AssertionError: if the model does not use exactly one activation type
    """
    with sess.as_default():
        with graph.as_default():
            input_shape = datasets.get_input_shape(dataset)
            ans = keras.Sequential()
            activation, activation_param = list(), None

            # Only the first layer actually added to `ans` declares the input shape. Skipped
            # layers (Dropout) must not consume it, and every layer kind must forward it.
            kwargs = {'input_shape': input_shape}

            # n counts every PyTorch layer (Dropout included) and is used in layer names
            for n, layer in enumerate(model, start=1):
                if isinstance(layer, nn.Dropout):
                    # ignore dropout layer since we only use the model for evaluation here
                    continue

                if isinstance(layer, Flatten):
                    ans.add(keras.layers.Flatten('channels_last', **kwargs))
                elif isinstance(layer, (nn.Linear, FlattenConv2D)):
                    dense = keras.layers.Dense(layer.out_features,
                                               use_bias=layer.bias is not None,
                                               **kwargs)
                    ans.add(dense)
                    # PyTorch Linear weights are (out, in); the Keras Dense kernel is (in, out).
                    dense_weights = [layer.weight.t().cpu().detach().numpy()]
                    if layer.bias is not None:
                        dense_weights.append(layer.bias.cpu().detach().numpy())
                    dense.set_weights(dense_weights)
                elif isinstance(layer, nn.ReLU):
                    ans.add(keras.layers.Activation('relu', name=f'relu_{n}', **kwargs))
                    activation.append('relu')
                elif isinstance(layer, nn.Tanh):
                    ans.add(keras.layers.Activation('tanh', name=f'tanh_{n}', **kwargs))
                    activation.append('tanh')
                elif isinstance(layer, nn.LeakyReLU):
                    ans.add(keras.layers.LeakyReLU(alpha=layer.negative_slope, name=f'leaky_{n}', **kwargs))
                    activation.append('leaky')
                    activation_param = layer.negative_slope
                elif isinstance(layer, nn.Conv2d):
                    # Conv2d padding is a tuple, or the strings 'valid'/'same' in newer PyTorch.
                    if isinstance(layer.padding, str):
                        padding = layer.padding
                    else:
                        padding = 'valid' if layer.padding[0] == 0 else 'same'
                    new_layer = keras.layers.Conv2D(layer.out_channels, layer.kernel_size, layer.stride,
                                                    padding,
                                                    'channels_first',
                                                    use_bias=layer.bias is not None,
                                                    **kwargs)
                    ans.add(new_layer)
                    # Keras kernels are (kh, kw, in, out) regardless of data_format.
                    new_weights = [layer.weight.cpu().detach().numpy().transpose(2, 3, 1, 0)]
                    if layer.bias is not None:
                        new_weights.append(layer.bias.cpu().detach().numpy())
                    new_layer.set_weights(new_weights)
                else:
                    raise NotImplementedError(f'Unsupported layer type {layer.__class__.__name__}')

                # a layer has been added, so later layers must not redeclare the input shape
                kwargs = {}

    # only one type of activation is permitted
    activation = list(set(activation))
    assert len(activation) == 1, f'Expected exactly one activation type, found {sorted(activation)}'
    activation = activation[0]

    return ans, activation, activation_param
Example #12
0
                m = model.load_model('exp', now_ds_name, now_model).cuda()

                d = torch.load(
                    f'{SAVE_PATH}/{now_ds_name}_{now_model}_{now_weight_filename}_best.pth'
                )
                # print('acc:', d['acc'], 'robacc:', d['robacc'], 'epoch:', d['epoch'], 'normalized:', d['normalized'],
                #       'dataset:', d['dataset'], file=sys.stderr)

                model_parameters = filter(lambda p: p.requires_grad,
                                          m.parameters())
                params = sum([np.prod(p.size()) for p in model_parameters])
                # print(params)

                neurons, numparams = param_summary(
                    m, get_input_shape(now_ds_name))
                # print(neurons, numparams)
                struc = struc_summary(m)
                # print(struc)

                del m

                robacc_s = f"${d['robacc'] * 100.:4.2f}\%$" if d[
                    'robacc'] > 0. else "/"

                if i == 0:
                    tab_body_1.append([
                        dataset_show_names[now_ds_name],
                        model_show_names[now_model],
                        nicenum(neurons),
                        nicenum(numparams[0]), struc, sources[now_model]
Example #13
0
def main(train_method, dataset, model_name, params):
    # prepare dataset and normalize settings
    normalize = None
    if params.get('normalized', False):
        if dataset == 'mnist':
            normalize = (_MNIST_MEAN, _MNIST_STDDEV)
        elif dataset == 'cifar10':
            normalize = (_CIFAR10_MEAN, _CIFAR10_STDDEV)
        elif dataset == 'imagenet':
            normalize = (_IMAGENET_MEAN, _IMAGENET_STDDEV)
    train_set = get_dataset(dataset, 'train', normalize)
    test_set = get_dataset(dataset, 'test', normalize)

    # read input shape (c, h, w)
    input_shape = get_input_shape(dataset)

    # read params
    batch_size = params['batch_size']
    optimizer_name = params.get('optimizer', 'sgd')
    if optimizer_name == 'sgd':
        lr = params.get('learning_rate', 0.1)
        momentum = params.get('momentum', 0.1)
        weight_decay = params.get('weight_decay', 5e-4)
    elif optimizer_name == 'adam':
        lr = params.get('learning_rate', 0.1)
    else:
        raise NotImplementedError
    cur_lr = lr
    print('default learning rate =', cur_lr, file=stderr)
    start_epoch = 0
    epochs = params.get('epochs', 0)
    eps = normed_eps = params['eps']
    if train_method == 'adv':
        # Note: for adversarial training, in training phase, we use the manual implementation version for precision,
        # and use the clearhans implementation in test phase for precision
        eps_iter_coef = params['eps_iter_coef']
        clip_min = params['clip_min']
        clip_max = params['clip_max']
        if normalize is not None:
            mean, std = normalize
            clip_min = (clip_min - max(mean)) / min(std) - 1e-6
            clip_max = (clip_max - min(mean)) / min(std) + 1e-6
            normed_eps = eps / min(std)
        nb_iter = params['nb_iter']
        rand_init = params['rand_init']

        adv_params = {
            'eps': normed_eps,
            'clip_min': clip_min,
            'clip_max': clip_max,
            'eps_iter': eps_iter_coef * eps,
            'nb_iter': nb_iter,
            'rand_init': rand_init
        }
    elif train_method == 'certadv':
        # Note: for certified adversarially trained models, we test its accuracy still using PGD attack
        eps_iter_coef = params['eps_iter_coef']
        clip_min = params['clip_min']
        clip_max = params['clip_max']
        if normalize is not None:
            mean, std = normalize
            clip_min = (clip_min - max(mean)) / min(std) - 1e-6
            clip_max = (clip_max - min(mean)) / min(std) + 1e-6
            normed_eps = eps / min(std)
        nb_iter = params['nb_iter']
        rand_init = params['rand_init']

        adv_params = {
            'eps': normed_eps,
            'clip_min': clip_min,
            'clip_max': clip_max,
            'eps_iter': eps_iter_coef * eps,
            'nb_iter': nb_iter,
            'rand_init': rand_init
        }
        print(adv_params, file=stderr)

    # prepare loader
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size,
                                               shuffle=True,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size,
                                              shuffle=True,
                                              pin_memory=True)

    # stats
    train_tot = len(train_set)
    test_tot = len(test_set)

    best_acc = 0.0
    best_robacc = 0.0

    # load model
    m = model.load_model('exp', dataset, model_name).cuda()
    print(m)

    if train_method == 'adv' and params['retrain']:
        # retrain from the best clean model
        clean_model_name = f'{dataset}_{model_name}_clean_0_best'
        new_m, stats = try_load_weight(m, clean_model_name)
        assert stats == True, "Could not load pretrained clean model."
        if isinstance(new_m[0], NormalizeLayer):
            # squeeze the normalize layer out
            new_m = new_m[1]
        m = new_m
    elif train_method == 'certadv':
        configdir = params['configpath']
        ds_mapping = {'cifar10': 'cifar', 'mnist': 'mnist'}
        ds_multiplier = {'cifar10': 255., 'mnist': 10.}
        configfilename = f'exp_{ds_mapping[dataset]}{int(round(eps * ds_multiplier[dataset]))}.json'
        with open(os.path.join(configdir, configfilename), 'r') as f:
            real_config = json.load(f)
        epochs = real_config['training_params']['epochs']
        start_epoch = epochs - 1
        model_path = os.path.join(
            os.path.join(real_config['path_prefix'],
                         real_config['models_path']), f'{model_name}_best.pth')
        d = torch.load(model_path)
        print(f'certadv load from {model_path}', file=stderr)
        m.load_state_dict(d['state_dict'])

    # open file handler
    save_name = f'{ds}_{model_name}_{now_method}_{eps}'
    mode = 'a'
    if os.path.exists(f'{SAVE_PATH}/{save_name}_train.log') or os.path.exists(
            f'{SAVE_PATH}/{save_name}_test.log'):
        choice = getpass.getpass(
            f'Log exists. Do you want to rewrite it? (Y/others) ')
        if choice == 'Y':
            mode = 'w'
            print('Rewrite log', file=stderr)
        else:
            mode = 'a'
    train_log = open(f'{SAVE_PATH}/{save_name}_train.log', mode)
    test_log = open(f'{SAVE_PATH}/{save_name}_test.log', mode)

    # special treatment for model G - layerwise training
    if model_name == 'G' and train_method == 'adv':
        new_last_layer = nn.Linear(1024, 10)

    # start
    for epoch in range(start_epoch, epochs):

        if epoch % LR_REDUCE == 0 and epoch > 0:
            # learning rate reduced to LR_REDUCE_RATE every LR_REDUCE epochs
            cur_lr *= LR_REDUCE_RATE
            print(f'  reduce learning rate to {cur_lr}', file=stderr)

        # special treatment for model G - layerwise training
        if model_name == 'G' and train_method == 'adv':
            new_m = list()
            tmp_cnt = 0
            for l in m:
                new_m.append(l)
                if isinstance(l, nn.Linear) and l.out_features == 1024:
                    tmp_cnt += 1
                if tmp_cnt > epoch / 5:
                    if l.out_features == 1024:
                        new_m.append(nn.ReLU())
                        new_m.append(new_last_layer)
                    break
            new_m = nn.Sequential(*new_m).cuda()
            m, new_m = new_m, m
            print(m, file=stderr)
            cur_lr = lr
            print(f'  learning rate restored to {cur_lr}', file=stderr)

        # init optimizer
        if optimizer_name == 'adam':
            opt = optim.Adam(m.parameters(), lr=cur_lr)
        elif optimizer_name == 'sgd':
            opt = optim.SGD(m.parameters(),
                            lr=cur_lr,
                            momentum=momentum,
                            weight_decay=weight_decay)
        else:
            raise Exception("Fail to create the optimizer")

        cur_idx = 0
        cur_acc = 0.0
        cur_robacc = 0.0

        batch_tot = 0
        batch_acc_tot = 0
        batch_robacc_tot = 0

        clean_ce = 0.0
        adv_ce = 0.0

        # now eps
        now_eps = normed_eps * min((epoch + 1) / EPS_WARMUP_EPOCHS, 1.0)
        # =========== Training ===========
        print(f'Epoch {epoch}: training', file=stderr)
        if train_method != 'clean':
            print(f'  Training eps={now_eps:.3f}', file=stderr)
        m.train()

        for i, (X, y) in enumerate(train_loader):

            if DEBUG and i > 10:
                break

            start_t = time.time()

            X_clean, y_clean = X.cuda(), y.cuda().long()
            clean_out = m(Variable(X_clean))
            clean_ce = nn.CrossEntropyLoss()(clean_out, Variable(y_clean))

            batch_tot = X.size(0)
            batch_acc_tot = (
                clean_out.data.max(1)[1] == y_clean).float().sum().item()

            if train_method == 'clean':
                opt.zero_grad()
                clean_ce.backward()
                opt.step()

            elif train_method == 'adv':
                X_pgd = Variable(X, requires_grad=True)
                for _ in range(nb_iter):
                    opt_pgd = optim.Adam([X_pgd], lr=1e-3)
                    opt.zero_grad()
                    loss = nn.CrossEntropyLoss()(m(X_pgd.cuda()),
                                                 Variable(y_clean))
                    loss.backward()
                    eta = now_eps * eps_iter_coef * X_pgd.grad.data.sign()
                    X_pgd = Variable(X_pgd.data + eta, requires_grad=True)
                    eta = torch.clamp(X_pgd.data - X, -now_eps, now_eps)
                    X_pgd.data = X + eta
                    X_pgd.data = torch.clamp(X_pgd.data, clip_min, clip_max)

                # print(X_pgd.data, la.norm((X_pgd.data - X).numpy().reshape(-1), np.inf), file=stderr)
                adv_out = m(Variable(X_pgd.data).cuda())
                adv_ce = nn.CrossEntropyLoss()(adv_out, Variable(y_clean))
                batch_robacc_tot = (
                    adv_out.data.max(1)[1] == y_clean).float().sum()

                opt.zero_grad()
                adv_ce.backward()
                opt.step()

            elif train_method == 'certadv':
                # no action to do for training
                adv_ce = torch.Tensor([0.0]).cuda()
                pass

            end_t = time.time()

            clean_ce = clean_ce.detach().cpu().item()
            if train_method != 'clean':
                adv_ce = adv_ce.detach().cpu().item()

            runtime = end_t - start_t
            cur_acc = (cur_acc * cur_idx + batch_acc_tot) / (cur_idx +
                                                             batch_tot)
            if train_method != 'clean':
                cur_robacc = (cur_robacc * cur_idx +
                              batch_robacc_tot) / (cur_idx + batch_tot)
            cur_idx += batch_tot

            print(
                f'{epoch} {cur_idx} {cur_acc} {cur_robacc} {batch_acc_tot/batch_tot:.3f} {batch_robacc_tot/batch_tot:.3f}'
                f' {clean_ce:.3f} {adv_ce:.3f} {runtime:.3f}',
                file=train_log)
            if i % STEP == 0 or cur_idx == train_tot:
                print(
                    f'  [train] {epoch}/{cur_idx} acc={cur_acc:.3f}({batch_acc_tot/batch_tot:.3f}) '
                    f'robacc={cur_robacc:.3f}({batch_robacc_tot/batch_tot:.3f}) ce={clean_ce:.3f} adv_ce={adv_ce:.3f} time={runtime:.3f}',
                    file=stderr)

        train_log.flush()

        # =========== Testing ===========
        print(f'Epoch {epoch}: testing', file=stderr)
        m.eval()
        torch.set_grad_enabled(False)

        cur_idx = 0
        cur_acc = 0.0
        cur_robacc = 0.0

        batch_tot = 0
        batch_acc_tot = 0
        batch_robacc_tot = 0

        clean_ce = 0.0
        adv_ce = 0.0

        if train_method in ['adv', 'certadv']:
            tf_model = convert_pytorch_model_to_tf(m)
            ch_model = CallableModelWrapper(tf_model, output_layer='logits')
            x_op = tf.placeholder(tf.float32,
                                  shape=(None, ) + tuple(input_shape))
            sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=0.5)))
            attk = ProjectedGradientDescent(ch_model, sess=sess)
            adv_x = attk.generate(x_op, **adv_params)
            adv_preds_op = tf_model(adv_x)

        for i, (X, y) in enumerate(test_loader):

            if DEBUG and i >= 10:
                break

            start_t = time.time()

            X_clean, y_clean = X.cuda(), y.cuda().long()
            clean_out = m(Variable(X_clean))
            clean_ce = nn.CrossEntropyLoss()(clean_out, Variable(y_clean))

            batch_tot = X.size(0)
            batch_acc_tot = (
                clean_out.data.max(1)[1] == y_clean).float().sum().item()

            if train_method in ['adv', 'certadv']:

                (adv_preds, ) = sess.run((adv_preds_op, ), feed_dict={x_op: X})
                adv_preds = torch.Tensor(adv_preds)

                adv_ce = nn.CrossEntropyLoss()(adv_preds, Variable(y))
                batch_robacc_tot = (
                    adv_preds.data.max(1)[1] == y).float().sum().item()

            # elif train_method == 'certadv':
            #
            #     adv_ce, robust_err = robust_loss(m, eps,
            #                                      Variable(X_clean), Variable(y_clean),
            #                                      proj=50, norm_type='l1_median', bounded_input=True)
            #
            #     batch_robacc_tot = (1.0 - robust_err) * batch_tot

            end_t = time.time()

            clean_ce = clean_ce.detach().cpu().item()
            if train_method != 'clean':
                adv_ce = adv_ce.detach().cpu().item()

            runtime = end_t - start_t
            cur_acc = (cur_acc * cur_idx + batch_acc_tot) / (cur_idx +
                                                             batch_tot)
            if train_method != 'clean':
                cur_robacc = (cur_robacc * cur_idx +
                              batch_robacc_tot) / (cur_idx + batch_tot)
            cur_idx += batch_tot

            print(
                f'{epoch} {cur_idx} {cur_acc} {cur_robacc} {batch_acc_tot / batch_tot:.3f} {batch_robacc_tot / batch_tot:.3f}'
                f' {clean_ce} {adv_ce} {runtime:.3f}',
                file=test_log)
            if i % STEP == 0 or cur_idx == train_tot:
                print(
                    f'  [test] {epoch}/{cur_idx} acc={cur_acc:.3f}({batch_acc_tot / batch_tot:.3f}) '
                    f'robacc={cur_robacc:.3f}({batch_robacc_tot / batch_tot:.3f}) time={runtime:.3f}',
                    file=stderr)

        torch.set_grad_enabled(True)

        if model_name == 'G' and train_method == 'adv':
            # switch back
            m, new_m = new_m, m

        def save_with_configs(m, path):
            """Write a checkpoint of ``m`` plus run metadata to ``path``.

            The checkpoint is a plain dict. It contains the module's
            ``state_dict()`` together with values looked up in the enclosing
            scope at the moment of the call:

            - ``cur_acc``
            - ``cur_robacc``
            - ``epoch``
            - whether ``normalize`` is set
            - ``dataset``

            Because these are closure lookups, the saved metrics are whatever
            those variables hold when this helper is invoked.

            :param m: module whose ``state_dict()`` is stored
            :param path: destination passed straight to ``torch.save``
            """
            weights = m.state_dict()
            checkpoint = {
                'state_dict': weights,
                'acc': cur_acc,
                'robacc': cur_robacc,
                'epoch': epoch,
                # Only the presence of a normalizer is recorded, not its
                # parameters.
                'normalized': normalize is not None,
                'dataset': dataset,
            }
            torch.save(checkpoint, path)

        if not os.path.exists(f'{SAVE_PATH}/{save_name}_chkpt'):
            os.makedirs(f'{SAVE_PATH}/{save_name}_chkpt')
        save_with_configs(
            m, f'{SAVE_PATH}/{save_name}_chkpt/{save_name}_ep_{epoch:03d}.pth')
        if (train_method == 'clean'
                and cur_acc > best_acc) or (train_method != 'clean'
                                            and cur_robacc > best_robacc):
            save_with_configs(m, f'{SAVE_PATH}/{save_name}_best.pth')
            print(
                f"  Updated, acc {best_acc:.3f} => {cur_acc:.3f} robacc {best_robacc:.3f} => {cur_robacc:.3f}",
                file=stderr)
            best_acc = cur_acc
            best_robacc = cur_robacc

        test_log.flush()

        # memory clean after each batch
        torch.cuda.empty_cache()
        if train_method == 'adv':
            sess.close()

    train_log.close()
    test_log.close()