Example #1
    def ODE(
        self,
        t_,
        y,
    ):
        rx, ry, rz, vx, vy, vz, mass = y
        r = np.array([rx, ry, rz])
        v = np.array([vx, vy, vz])

        #norm of the position vector (np.linalg is NumPy's linear-algebra submodule)
        norm_r = np.linalg.norm(r)

        #Newton's law of gravitation; r is a vector, so the acceleration a is a vector
        a = -r * self.cb['mu'] / norm_r**3

        #J2 perturbation acceleration
        if self.perts['J2']:

            z2 = r[2]**2
            r2 = norm_r**2
            tx = r[0] / norm_r * (5 * z2 / r2 - 1)
            ty = r[1] / norm_r * (5 * z2 / r2 - 1)
            tz = r[2] / norm_r * (5 * z2 / r2 - 3)

            #add the J2 term to the acceleration vector
            a += 1.5 * self.cb['J2'] * self.cb['mu'] * self.cb[
                'radius']**2.0 / norm_r**4.0 * np.array([tx, ty, tz])

        #calculate aerodynamic drag
        if self.perts['aerodrag']:

            #altitude above the surface, and the air density at that altitude
            z = norm_r - self.cb['radius']
            rho = t.calc_atmospheric_density(z)

            #motion of the s/c with respect to the rotating atmosphere
            v_rel = v - np.cross(self.cb['atm_rot_vector'], r)

            #aerodynamic drag calculation
            drag = -v_rel * 0.5 * rho * t.norm(
                v_rel) * self.perts['Cd'] * self.perts['A'] / mass

            #add the drag to the acceleration vector
            a += drag

        #mass flow rate is zero unless thrust is on (avoids a NameError in the return)
        mass_flow = 0.0

        #calculate thrust
        if self.perts['thrust']:

            #thrust acceleration from Newton's second law; t.normed(v) gives the
            #unit vector along the velocity (thrust direction), /1000 converts to km/s^2
            a += (self.perts['thrust_direction'] * t.normed(v) *
                  self.perts['thrust'] / mass) / 1000.0

            #mass flow rate: mdot = -thrust / (isp * g0)
            mass_flow = -self.perts['thrust'] / (self.perts['isp'] * 9.81)

        #return the state derivative: velocity, acceleration, and mass flow rate
        return [vx, vy, vz, a[0], a[1], a[2], mass_flow]
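A minimal standalone sketch of the same two-body-plus-J2 acceleration, assuming Earth constants in km-based units (MU, J2, RADIUS here are illustrative stand-ins for self.cb):

import numpy as np

MU = 398600.4418       # km^3/s^2, assumed Earth gravitational parameter
J2 = 1.08262668e-3     # assumed Earth J2 coefficient
RADIUS = 6378.137      # km, assumed Earth equatorial radius

def two_body_j2_accel(r):
    """Point-mass gravity plus the J2 zonal perturbation, in km/s^2."""
    norm_r = np.linalg.norm(r)
    a = -MU * r / norm_r**3
    z2, r2 = r[2]**2, norm_r**2
    tx = r[0] / norm_r * (5 * z2 / r2 - 1)
    ty = r[1] / norm_r * (5 * z2 / r2 - 1)
    tz = r[2] / norm_r * (5 * z2 / r2 - 3)
    a += 1.5 * J2 * MU * RADIUS**2 / norm_r**4 * np.array([tx, ty, tz])
    return a

print(two_body_j2_accel(np.array([RADIUS + 500.0, 0.0, 0.0])))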
Example #2
def train_revision(model, train_loader, epoch, optimizer, W_group,
                   basis_matrix_group, batch_size, num_classes, basis):
    print('Training %s...' % model_str)

    train_total = 0
    train_correct = 0

    for i, (data, labels) in enumerate(train_loader):

        data = data.cuda()
        labels = labels.cuda()
        loss = 0.
        # Forward + Backward + Optimize
        optimizer.zero_grad()
        _, logits, revision = model(data, revision=True)

        logits_ = F.softmax(logits, dim=1)
        logits_correction_total = torch.zeros(len(labels), num_classes)
        for j in range(len(labels)):
            idx = i * batch_size + j
            # per-example transition matrix: a combination of the basis
            # matrices plus the learned revision term, row-normalised
            matrix = matrix_combination(basis_matrix_group, W_group, idx,
                                        num_classes, basis)
            matrix = torch.from_numpy(matrix).float().cuda()
            matrix = tools.norm(matrix + revision)

            logits_single = logits_[j, :].unsqueeze(0)
            logits_correction = logits_single.mm(matrix)
            pro1 = logits_single[:, labels[j]]       # clean posterior at the label
            pro2 = logits_correction[:, labels[j]]   # corrected (noisy) posterior
            beta = pro1 / pro2                       # importance weight
            logits_correction = torch.log(logits_correction + 1e-12)
            logits_single = torch.log(logits_single + 1e-12)
            loss_ = beta * F.nll_loss(logits_single, labels[j].unsqueeze(0))
            loss += loss_
            logits_correction_total[j, :] = logits_correction
        logits_correction_total = logits_correction_total.cuda()
        loss = loss / len(labels)
        prec1, = accuracy(logits_correction_total, labels, topk=(1, ))
        train_total += 1
        train_correct += prec1

        loss.backward()
        optimizer.step()

        if (i + 1) % args.print_freq == 0:
            print(
                'Epoch [%d/%d], Iter [%d/%d] Train Accuracy: %.4F, Loss: %.4f'
                % (epoch + 1, args.n_epoch_3, i + 1,
                   len(train_dataset) // batch_size, prec1, loss.item()))

    train_acc = float(train_correct) / float(train_total)
    return train_acc
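To see what the per-sample weight beta does, here is a toy numpy illustration of the same ratio p(label|x) / p_corrected(label|x); the 3-class softmax output and transition matrix are made up for the example:

import numpy as np

probs = np.array([0.7, 0.2, 0.1])            # softmax output p(y|x)
T = np.array([[0.8, 0.1, 0.1],               # assumed row-stochastic noise matrix
              [0.1, 0.8, 0.1],
              [0.1, 0.1, 0.8]])
noisy_probs = probs @ T                       # corrected posterior, as logits_single.mm(matrix) above
label = 0                                     # observed (noisy) label
beta = probs[label] / noisy_probs[label]      # the importance weight used in the loss
print(beta)                                   # ~1.19: upweighted when the label agrees with the clean posterior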
Example #3
def val_revision(model, train_loader, epoch, W_group, basis_matrix_group,
                 batch_size, num_classes, basis):

    val_total = 0
    val_correct = 0

    model.eval()
    for i, (data, labels) in enumerate(train_loader):
        data = data.cuda()
        labels = labels.cuda()
        loss = 0.
        # Forward only (no backward pass during validation)

        _, logits, revision = model(data, revision=True)

        logits_ = F.softmax(logits, dim=1)
        logits_correction_total = torch.zeros(len(labels), num_classes)
        for j in range(len(labels)):
            idx = i * batch_size + j
            matrix = matrix_combination(basis_matrix_group, W_group, idx,
                                        num_classes, basis)
            matrix = torch.from_numpy(matrix).float().cuda()
            matrix = tools.norm(matrix + revision)
            logits_single = logits_[j, :].unsqueeze(0)
            logits_correction = logits_single.mm(matrix)
            pro1 = logits_single[:, labels[j]]
            pro2 = logits_correction[:, labels[j]]
            beta = pro1 / pro2  # no backward pass here, so no Variable wrapper needed
            logits_correction = torch.log(logits_correction + 1e-12)
            loss_ = beta * F.nll_loss(logits_correction,
                                      labels[j].unsqueeze(0))
            loss += loss_
            logits_correction_total[j, :] = logits_correction
        logits_correction_total = logits_correction_total.cuda()
        prec1, = accuracy(logits_correction_total, labels, topk=(1, ))
        val_total += 1
        val_correct += prec1
        if (i + 1) % args.print_freq == 0:
            print(
                'Epoch [%d/%d], Iter [%d/%d] Val Accuracy: %.4F, Loss: %.4f' %
                (epoch + 1, args.n_epoch_3, i + 1,
                 len(val_dataset) // batch_size, prec1, loss.item()))

    val_acc = float(val_correct) / float(val_total)

    return val_acc
Example #4
def Squeeze_excitation_layer(input_x, out_dim, ratio, layer_name, is_training):
    with tf.variable_scope(layer_name):

        squeeze = slim.avg_pool2d(input_x,
                                  input_x.get_shape()[1:3],
                                  padding='VALID',
                                  scope=layer_name + 'AvgPool')

        excitation = slim.fully_connected(squeeze,
                                          int(out_dim / ratio),
                                          activation_fn=None,
                                          scope=layer_name +
                                          '_fully_connected1')
        excitation = slim.dropout(excitation,
                                  0.5,
                                  is_training=is_training,
                                  scope=layer_name + 'Dropout_1')
        excitation = tf.nn.relu(excitation, name=layer_name + '_relu')
        excitation = slim.fully_connected(excitation,
                                          int(out_dim),
                                          activation_fn=None,
                                          scope=layer_name +
                                          '_fully_connected2')
        excitation = slim.dropout(excitation,
                                  0.5,
                                  is_training=is_training,
                                  scope=layer_name + 'Dropout_2')
        excitation = tf.nn.sigmoid(excitation, name=layer_name + '_sigmoid')

        # broadcast the channel weights and rescale the input feature maps
        excitation = tf.reshape(excitation, [-1, 1, 1, out_dim])
        scale = input_x * excitation
        # average over channels to get a single spatial map; gauss() is an
        # external smoothing helper from the surrounding project
        group_net = tf.reduce_mean(scale, axis=3, keep_dims=True)
        group_net = gauss(group_net)

        group_net = slim.flatten(group_net)
        temp_list = list()
        for i in range(batch_size):
            temp_img = group_net[i, :]
            temp_img = tf.reshape(norm(temp_img), (8, 8))
            temp_img = tf.expand_dims(temp_img, axis=0)
            temp_list.append(temp_img)
        group_net = tf.concat(temp_list, axis=0)
        group_net = tf.expand_dims(group_net, axis=3)
        return scale, group_net
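The squeeze-and-excitation part above pools each channel to a scalar, runs it through a bottleneck, and gates the channels with a sigmoid; a plain-numpy sketch of that gating, with illustrative shapes and random weights:

import numpy as np

x = np.random.rand(1, 8, 8, 16)                            # NHWC feature maps
squeeze = x.mean(axis=(1, 2))                              # global average pool -> (1, 16)
w1 = np.random.randn(16, 4)                                # bottleneck, ratio = 4
w2 = np.random.randn(4, 16)
excite = 1.0 / (1.0 + np.exp(-(np.maximum(squeeze @ w1, 0) @ w2)))  # FC-ReLU-FC-sigmoid
scale = x * excite.reshape(1, 1, 1, 16)                    # per-channel gating
print(scale.shape)                                          # (1, 8, 8, 16)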
Example #5
    def propagate_orbit(self):
        print('Propagating orbit...')

        #propagate orbit. check for max time and stop conditions at each time step
        while self.solver.successful(
        ) and self.step < self.n_steps and self.check_stop_conditions():

            # propagate the integrator one step
            self.solver.integrate(self.solver.t + self.time_step)
            self.step += 1
            self.ts[self.step] = self.solver.t
            self.y[self.step] = self.solver.y
            self.alts[self.step] = t.norm(
                self.solver.y[:3]) - self.cb['radius']

        #trim the preallocated arrays to the completed steps and split the state into position, velocity, and mass
        self.ts = self.ts[:self.step]
        self.rs = self.y[:self.step, :3]
        self.vs = self.y[:self.step, 3:6]
        self.masses = self.y[:self.step, 6]
        self.alts = self.alts[:self.step]
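The surrounding class is not shown in full; a self-contained sketch of the same propagation loop for plain two-body motion with scipy's ode (the Earth mu and the initial state are illustrative):

import numpy as np
from scipy.integrate import ode

MU = 398600.4418  # km^3/s^2, assumed Earth gravitational parameter

def two_body(t, y):
    r, v = y[:3], y[3:]
    return np.concatenate([v, -MU * r / np.linalg.norm(r)**3])

y0 = [6778.0, 0.0, 0.0, 0.0, 7.6686, 0.0]   # ~400 km circular orbit
solver = ode(two_body)
solver.set_integrator('lsoda')
solver.set_initial_value(y0, 0.0)

dt, n_steps = 10.0, 600
ys = np.zeros((n_steps + 1, 6))
ys[0] = y0
step = 0
while solver.successful() and step < n_steps:
    solver.integrate(solver.t + dt)
    step += 1
    ys[step] = solver.y
print(ys[-1][:3])   # position after 100 minutes of flight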
Example #6
    def _draw_arrow(self, start, end, colour, width=1, rel_spoke_len=0.2):
        """ internal function for drawing arrows
        draws directly to screen coordiantes,
        so make sure start and end have been converted

        plot._draw_arrow( start: np.array, end: np.array, colour: rgb or name
                width=1: int, rel_spoke_len=0.2: float)

        rel_spoke_len is the length of the spokes relative
                        to the length of the line from
                        start to end
        """
        if str(colour) in clrs.info:
            colour = clrs[colour]

        strt, endp = np.array(start), np.array(end)
        diff = endp - strt
        norm_diff = tools.norm(diff)
        perp = tools.perp(tools.normalize(diff))

        midpoint = strt + diff * 0.7

        pnt2 = midpoint - (perp * norm_diff * rel_spoke_len)
        pnt1 = midpoint + (perp * norm_diff * rel_spoke_len)

        # draw the arrow; aaline is antialiased but only supports width 1
        if width != 1:
            pgdraw.line(self.screen, colour, strt, endp, width)  # main line
            pgdraw.line(self.screen, colour, pnt1, endp, width)  # spoke one
            pgdraw.line(self.screen, colour, pnt2, endp, width)  # spoke two
        else:
            pgdraw.aaline(self.screen, colour, strt, endp, width)  # main line
            pgdraw.aaline(self.screen, colour, pnt1, endp, width)  # spoke one
            pgdraw.aaline(self.screen, colour, pnt2, endp, width)  # spoke two
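tools.normalize and tools.perp are not shown; plausible 2-D implementations consistent with their use above (unit-length direction and its 90-degree rotation) would be:

import numpy as np

def normalize(v):
    """Return v scaled to unit length (returned unchanged if zero)."""
    n = np.linalg.norm(v)
    return v / n if n else v

def perp(v):
    """Rotate a 2-D vector 90 degrees counter-clockwise."""
    return np.array([-v[1], v[0]])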
Example #7
    def forward(self, x, togray, sl_enc):

        # convert from GAN-style normalisation ([-1, 1]) to ResNet-50 ImageNet input
        x = F.interpolate(x, [256, 128], mode='bilinear')
        x = denorm(x, mean=[0.5] * 3, std=[0.5] * 3)
        x = norm(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # optionally convert to grayscale
        if togray:
            x = rgb2gray(x, de_norm=False, norm=False)

        # forward
        feature_maps = self.resnet_conv1(x)

        # return
        if sl_enc:
            return feature_maps, None, None
        else:
            features_vectors = self.gap(self.resnet_conv2(feature_maps)).squeeze()
            bned_features, cls_score = self.classifier(features_vectors)
            if self.training:
                return feature_maps, features_vectors, cls_score
            else:
                return feature_maps, bned_features, None
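denorm and norm are per-channel renormalisation helpers; a minimal sketch of what they presumably do, matching the calls above (x is an NCHW tensor):

import torch

def denorm(x, mean, std):
    """Undo per-channel normalisation: x * std + mean."""
    mean = torch.tensor(mean, device=x.device).view(1, -1, 1, 1)
    std = torch.tensor(std, device=x.device).view(1, -1, 1, 1)
    return x * std + mean

def norm(x, mean, std):
    """Apply per-channel normalisation: (x - mean) / std."""
    mean = torch.tensor(mean, device=x.device).view(1, -1, 1, 1)
    std = torch.tensor(std, device=x.device).view(1, -1, 1, 1)
    return (x - mean) / std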
Example #8
    def __init__(self,
                 initial_state,
                 time_span,
                 time_step,
                 coes=False,
                 deg=True,
                 mass0=0,
                 perts=null_perts(),
                 cb=pd.earth,
                 propagator='lsoda',
                 sc={}):

        #if classical orbital elements (COEs) were passed, convert them to position and velocity vectors
        if coes:
            self.r0, self.v0, _ = t.coes2rv(initial_state,
                                            deg=deg,
                                            mu=cb['mu'])

        #otherwise the input is already a position and velocity state vector
        else:
            self.r0 = np.array(initial_state[:3])
            self.v0 = np.array(initial_state[3:])
        self.cb = cb
        self.time_step = time_step
        self.mass0 = mass0
        self.time_span = time_span

        #ceil rounds the float up to the next whole number; int() converts it to an integer
        self.n_steps = int(np.ceil(self.time_span / self.time_step)) + 1

        #initialise arrays
        self.ts = np.zeros((self.n_steps + 1, 1))
        self.y = np.zeros((self.n_steps + 1, 7))
        self.alts = np.zeros((self.n_steps + 1))

        #7 states (rx, ry, rz, vx, vy, vz, mass); preallocating lets each step overwrite existing memory instead of growing a list
        self.propagator = propagator
        self.step = 0

        #initial condition at first step
        self.y[0, :] = self.r0.tolist() + self.v0.tolist() + [self.mass0]
        self.alts[0] = t.norm(self.r0) - self.cb['radius']

        # initialise the ODE solver; LSODA is fast and high order
        self.solver = ode(self.ODE)

        # LSODA switches between Adams-Bashforth multistep and BDF methods
        self.solver.set_integrator(self.propagator)

        # initial state at t0 defined
        self.solver.set_initial_value(self.y[0, :], 0)

        self.perts = perts

        #store the stop-conditions dictionary
        self.stop_conditions_dict = sc

        #dictionary mapping stop-condition names to internal check methods
        self.stop_conditions_map = {'min_alt': self.check_min_alt}

        #create stop conditions function list with deorbit always checked
        self.stop_condition_functions = [self.check_deorbit]

        #fill in the rest of the stop conditions.
        for key in self.stop_conditions_dict.keys():
            if key in self.stop_conditions_map:
                self.stop_condition_functions.append(
                    self.stop_conditions_map[key])

        #propagate the orbit
        self.propagate_orbit()
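A hedged usage sketch: assuming this __init__ belongs to a class named OrbitPropagator (the class name is a guess; null_perts and the km-based state come from the code above), a one-day propagation with J2 enabled might look like:

# OrbitPropagator is an assumed name for the class defined above
perts = null_perts()
perts['J2'] = True

r0 = [6778.0, 0.0, 0.0]   # km
v0 = [0.0, 7.6686, 0.0]   # km/s, roughly circular at 400 km altitude

op = OrbitPropagator(r0 + v0,
                     time_span=24 * 3600.0,  # one day
                     time_step=10.0,
                     mass0=100.0,
                     perts=perts)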
Example #9
            if index <= index_num:
                A[epoch][index * args.batch_size:(index + 1) *
                         args.batch_size, :] = out
            else:
                A[epoch][index_num * args.batch_size:len(train_data), :] = out

val_acc_array = np.array(val_acc_list)
model_index = np.argmax(val_acc_array)

A_save_dir = prob_save_dir + '/' + 'prob.npy'
np.save(A_save_dir, A)
prob_ = np.load(A_save_dir)

transition_matrix_ = tools.fit(prob_[model_index, :, :], args.num_classes,
                               estimate_state)
transition_matrix = tools.norm(transition_matrix_)

matrix_path = matrix_save_dir + '/' + 'transition_matrix.npy'
np.save(matrix_path, transition_matrix)
T = torch.from_numpy(transition_matrix).float().cuda()

# initial parameters

estimate_model_path = model_save_dir + '/' + 'epoch_%s.pth' % (model_index + 1)
model.load_state_dict(torch.load(estimate_model_path))

print('Estimate finish.....Training......')

for epoch in range(args.n_epoch):
    print('epoch {}'.format(epoch + 1))
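tools.fit is not shown; in anchor-point estimators the (i, j) entry of the transition matrix is read off the predicted noisy posterior of the sample the model is most confident belongs to class i. A numpy sketch of that idea (an assumption about tools.fit; the estimate_state argument is not modelled, and row normalisation stands in for tools.norm):

import numpy as np

def fit_transition_matrix(probs, num_classes):
    """probs: (n_samples, num_classes) softmax outputs on the noisy set."""
    T = np.zeros((num_classes, num_classes))
    for i in range(num_classes):
        anchor = np.argmax(probs[:, i])        # sample most confidently class i
        T[i] = probs[anchor]                   # its noisy posterior becomes row i
    return T / T.sum(axis=1, keepdims=True)    # row-normalise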
Example #10
opstrs = ['x', 'n']
s0 = Signal(x=x_guess, n=n_guess)

##define quantum basis and the initial state.
qbasis_full = get_hcb_basis(L)
qbasis_symm = get_hcb_basis(L, pblock=1, kblock=0)
from tools import get_Z2, get_uniform_state
psi0_full = (1.0 / np.sqrt(2)) * (get_Z2(qbasis_full, dtype, which=0) +
                                  get_Z2(qbasis_full, dtype, which=1))
proj = np.transpose(qbasis_symm.get_proj(dtype))
psi_target = proj.dot(psi0_full)
psi0 = proj.dot(get_uniform_state(qbasis_full, dtype, which=0))

from tools import norm
from numpy.testing import assert_almost_equal
assert_almost_equal(norm(psi0), 1.0, decimal=10)

##construct the cost function
from crab import UniformFieldFidelityCC
cConstr = UniformFieldFidelityCC(qbasis_symm, static, psi0, [ti, tf],
                                 psi_target, opstrs)

##set up fourier basis.
Nmode = 3
bx = RandomFourierBasis(Nmode, T, bc='unit', label='x')
bn = RandomFourierBasis(Nmode, T, bc='unit', label='n')
msb = MultipleSignalBasis(x=bx, n=bn)

options_nm = dict(maxfev=100)
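The assert checks that the initial state is normalised; tools.norm here presumably computes the l2 norm of a (possibly complex) state vector, e.g.:

import numpy as np

def norm(psi):
    """l2 norm of a possibly complex state vector."""
    return np.sqrt(np.vdot(psi, psi).real)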
Example #11
def get_twist_writhe(spline1,
                     spline2,
                     npoints=1000,
                     circular=False,
                     integral_type="simple"):
    """
    return the twist and writhe for a given configuration and 
    Using integral_type = 'simple' 

    args:
    spline1: list of 3 splines corresponding to x, y and z spline through strand 1's backbone
    spline2: list of 3 splines corresponding to x, y and z spline through strand 2's backbone -- NB the splines should run in the same direction, i.e. one must reverse one of the splines if they come from get_base_spline (e.g. use get_base_spline(reverse = True))

    npoints: number of points for the discrete integration
    """

    import scipy.interpolate

    s1xx, s1yy, s1zz = spline1[0]
    s2xx, s2yy, s2zz = spline2[0]

    smin = spline1[1][0]
    smax = spline1[1][-1]

    # bpi is the base pair index parameter that is common to both splines
    bpi = np.linspace(smin, smax, npoints)

    # find the midpoint between the input splines, as a function of base pair index
    m1xx = (scipy.interpolate.splev(bpi, s1xx) +
            scipy.interpolate.splev(bpi, s2xx)) / 2
    m1yy = (scipy.interpolate.splev(bpi, s1yy) +
            scipy.interpolate.splev(bpi, s2yy)) / 2
    m1zz = (scipy.interpolate.splev(bpi, s1zz) +
            scipy.interpolate.splev(bpi, s2zz)) / 2

    # contour_len[ii] is contour length along the midpoint curve of point ii
    delta_s = [
        np.sqrt((m1xx[ii + 1] - m1xx[ii])**2 + (m1yy[ii + 1] - m1yy[ii])**2 +
                (m1zz[ii + 1] - m1zz[ii])**2) for ii in range(len(bpi) - 1)
    ]
    contour_len = np.cumsum(delta_s)
    contour_len = np.insert(contour_len, 0, 0)

    # ss is a linear sequence from first contour length element (which is 0) to last contour length element inclusive
    ss = np.linspace(contour_len[0], contour_len[-1], npoints)

    # get the splines as a function of contour length
    msxx = scipy.interpolate.splrep(contour_len, m1xx, k=3, s=0, per=circular)
    msyy = scipy.interpolate.splrep(contour_len, m1yy, k=3, s=0, per=circular)
    mszz = scipy.interpolate.splrep(contour_len, m1zz, k=3, s=0, per=circular)

    xx = scipy.interpolate.splev(ss, msxx)
    yy = scipy.interpolate.splev(ss, msyy)
    zz = scipy.interpolate.splev(ss, mszz)

    # find the tangent of the midpoint spline.
    # the tangent t(s) is d/ds [r(s)], where r(s) = (mxx(s), myy(s), mzz(s)). So the tangent is t(s) = d/ds [r(s)] = (d/ds [mxx(s)], d/ds [myy(s)], d/ds [mzz(s)])
    # get discrete array of normalised tangent vectors; __call__(xxx, 1) returns the first derivative
    # the tangent vector is a unit vector
    dmxx = scipy.interpolate.splev(ss, msxx, 1)
    dmyy = scipy.interpolate.splev(ss, msyy, 1)
    dmzz = scipy.interpolate.splev(ss, mszz, 1)

    tt = list(range(len(ss)))
    for ii in range(len(ss)):
        tt[ii] = np.array([dmxx[ii], dmyy[ii], dmzz[ii]])

    # get the normal vector via n(s) = dt(s)/ds = d^2/ds^2[r(s)]
    ddmxx = scipy.interpolate.splev(ss, msxx, 2)
    ddmyy = scipy.interpolate.splev(ss, msyy, 2)
    ddmzz = scipy.interpolate.splev(ss, mszz, 2)
    nn = list(range(len(ss)))
    for ii in range(len(ss)):
        nn[ii] = np.array([ddmxx[ii], ddmyy[ii], ddmzz[ii]])

    # we also need the 'normal' vector u(s) which points between the base pairs. (or between the spline fits through the backbones in this case)
    # n.b. these uxx, uyy, uzz are not normalised
    uxx_bpi = scipy.interpolate.splev(bpi, s2xx) - scipy.interpolate.splev(
        bpi, s1xx)
    uyy_bpi = scipy.interpolate.splev(bpi, s2yy) - scipy.interpolate.splev(
        bpi, s1yy)
    uzz_bpi = scipy.interpolate.splev(bpi, s2zz) - scipy.interpolate.splev(
        bpi, s1zz)

    # get the normal vector spline as a function of contour length
    suxx = scipy.interpolate.splrep(contour_len,
                                    uxx_bpi,
                                    k=3,
                                    s=0,
                                    per=circular)
    suyy = scipy.interpolate.splrep(contour_len,
                                    uyy_bpi,
                                    k=3,
                                    s=0,
                                    per=circular)
    suzz = scipy.interpolate.splrep(contour_len,
                                    uzz_bpi,
                                    k=3,
                                    s=0,
                                    per=circular)

    # evaluate the normal vector spline as a function of contour length
    uxx = scipy.interpolate.splev(ss, suxx)
    uyy = scipy.interpolate.splev(ss, suyy)
    uzz = scipy.interpolate.splev(ss, suzz)

    uu = list(range(len(ss)))
    for ii in list(range(len(ss))):
        uu[ii] = np.array([uxx[ii], uyy[ii], uzz[ii]])
        uu[ii] = uu[ii] - np.dot(tt[ii], uu[ii]) * tt[ii]
        # the normal vector should be normalised
        uu[ii] = tools.norm(uu[ii])

    # and finally we need the derivatives of that vector u(s). It takes a bit of work to get a spline of the normalised version of u from the unnormalised one
    nuxx = [vec[0] for vec in uu]
    nuyy = [vec[1] for vec in uu]
    nuzz = [vec[2] for vec in uu]
    nusxx = scipy.interpolate.splrep(ss, nuxx, k=3, s=0, per=circular)
    nusyy = scipy.interpolate.splrep(ss, nuyy, k=3, s=0, per=circular)
    nuszz = scipy.interpolate.splrep(ss, nuzz, k=3, s=0, per=circular)
    duxx = scipy.interpolate.splev(ss, nusxx, 1)
    duyy = scipy.interpolate.splev(ss, nusyy, 1)
    duzz = scipy.interpolate.splev(ss, nuszz, 1)
    duu = list(range(len(ss)))
    for ii in list(range(len(ss))):
        duu[ii] = np.array([duxx[ii], duyy[ii], duzz[ii]])

    ds = float(contour_len[-1] - contour_len[0]) / (npoints - 1)
    # do the integration w.r.t. s

    twist, writhe = discrete_dbl_int(tt,
                                     uu,
                                     duu,
                                     xx,
                                     yy,
                                     zz,
                                     ds,
                                     ds,
                                     ss,
                                     circular=circular)

    return twist, writhe
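discrete_dbl_int is not shown; the quantities being summed are the standard ones, Tw = (1/2pi) * integral of (t x u) . u' ds and Wr = (1/4pi) * double integral of (t(s) x t(s')) . (r(s) - r(s')) / |r(s) - r(s')|^3 ds ds'. A direct numpy sketch of those discrete sums (an assumption about the helper; the writhe loop is quadratic in npoints):

import numpy as np

def discrete_twist_writhe(tt, uu, duu, xx, yy, zz, ds):
    """Discrete twist and writhe; tt, uu, duu are lists of 3-vectors."""
    n = len(tt)
    twist = sum(np.dot(np.cross(tt[i], uu[i]), duu[i]) for i in range(n))
    twist *= ds / (2 * np.pi)

    rr = np.stack([xx, yy, zz], axis=1)
    writhe = 0.0
    for i in range(n):
        for j in range(i + 1, n):
            d = rr[i] - rr[j]
            writhe += np.dot(np.cross(tt[i], tt[j]), d) / np.linalg.norm(d)**3
    writhe *= 2 * ds * ds / (4 * np.pi)   # factor 2: the integrand is symmetric in (s, s')
    return twist, writhe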
Example #12
def collision(polygons):
    """Adjusts velocity and rotation of two polygon objects if they collide"""

    intersectionSets = isIntersecting(polygons)
    if not intersectionSets[0] and not intersectionSets[1]:
        return

    # Computing velocity flip direction
    flipVectors = flipPoints(polygons, intersectionSets)

    # Alternate method of computing the rebound direction (disabled; doesn't work well)
    if False:
        flipVectors = flipPolys(polygons)
        if flipVectors[0]:
            return

        flipVectors = [
            sum(flipVectors[0]) / len(flipVectors[0]),
            sum(flipVectors[1]) / len(flipVectors[1])
        ]  # average the flip vectors rather than taking just the first entry

    # Flipping Velocity
    newVelocities = []
    for basis1, polygon in zip(flipVectors, polygons):
        flipA, flipB, basis2 = projection_single(
            basis1, polygon.velocity + polygon.angVelocity)
        newVelocities.append(-abs(flipA) * basis1 + flipB * basis2)

    # Normalising velocity to conserve momentum and kinetic energy
    oldVelMags = [norm(polygons[i].velocity) for i in range(2)
                  ]  # Magnitude of velocities before collision
    newVelMags = [
        norm(newVelocities[i]) for i in range(2)
    ]  # When flipped the magnitude of the vectors change to these values
    reboundVelMags = rebound(
        [polygons[0].mass, polygons[1].mass], oldVelMags
    )  # New vectors should have these magnitudes to conserve momentum

    if newVelMags[0] == 0:
        newVelocities[0] = 0
    else:
        newVelocities[0] = newVelocities[0] * reboundVelMags[0] / newVelMags[0]
    if newVelMags[1] == 0:
        newVelocities[1] = 0
    else:
        newVelocities[1] = newVelocities[1] * reboundVelMags[1] / newVelMags[1]

    #normConstants = [reboundVelMags[i]/newVelMags[i] for i in range(2)] # Values to use to normalise the flipped vectors to the correct magnitudes
    #newVelocities = [newVelocities[i]*normConstants[i] for i in range(2)]

    # Computing rotation
    torques = []
    for i in range(2):
        torques.append(
            collisionTorque(polygons[i], newVelocities[i],
                            intersectionSets[i]))

    if not intersectionSets[0]:
        torques[0] = -1 * torques[1]
    elif not intersectionSets[1]:
        torques[1] = -1 * torques[0]

    # Setting new velocity and rotation
    for i in range(2):
        polygons[i].velocity = newVelocities[i]
        polygons[i].angVelocity = torques[i]
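rebound is not shown, but the comment says the new magnitudes should conserve momentum and kinetic energy; for a head-on elastic collision that gives the standard one-dimensional formulas. A sketch under that assumption:

def rebound(masses, speeds):
    """Post-collision speed magnitudes for a 1-D elastic collision."""
    m1, m2 = masses
    v1, v2 = speeds
    u1 = ((m1 - m2) * v1 + 2 * m2 * v2) / (m1 + m2)
    u2 = ((m2 - m1) * v2 + 2 * m1 * v1) / (m1 + m2)
    return [abs(u1), abs(u2)]   # magnitudes only; direction comes from the flip step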
Example #13
def main():

    print('Estimate transition matrix......Waiting......')

    for epoch in range(args.n_epoch_estimate):

        print('epoch {}'.format(epoch + 1))
        model.train()
        train_loss = 0.
        train_acc = 0.
        val_loss = 0.
        val_acc = 0.

        for batch_x, batch_y in train_loader:
            batch_x = batch_x.cuda()
            batch_y = batch_y.cuda()
            optimizer_es.zero_grad()
            out = model(batch_x, revision=False)
            loss = loss_func_ce(out, batch_y)
            train_loss += loss.item()
            pred = torch.max(out, 1)[1]
            train_correct = (pred == batch_y).sum()
            train_acc += train_correct.item()
            loss.backward()
            optimizer_es.step()

        torch.save(model.state_dict(),
                   model_save_dir + '/' + 'epoch_%d.pth' % (epoch + 1))
        print('Train Loss: {:.6f}, Acc: {:.6f}'.format(
            train_loss / (len(train_data)) * args.batch_size,
            train_acc / (len(train_data))))

        with torch.no_grad():
            model.eval()
            for batch_x, batch_y in val_loader:
                batch_x = batch_x.cuda()
                batch_y = batch_y.cuda()
                out = model(batch_x, revision=False)
                loss = loss_func_ce(out, batch_y)
                val_loss += loss.item()
                pred = torch.max(out, 1)[1]
                val_correct = (pred == batch_y).sum()
                val_acc += val_correct.item()

        print('Val Loss: {:.6f}, Acc: {:.6f}'.format(
            val_loss / (len(val_data)) * args.batch_size,
            val_acc / (len(val_data))))
        val_acc_list.append(val_acc / (len(val_data)))

        with torch.no_grad():
            model.eval()
            for index, (batch_x, batch_y) in enumerate(estimate_loader):
                batch_x = batch_x.cuda()
                out = model(batch_x, revision=False)
                out = F.softmax(out, dim=1)
                out = out.cpu()
                if index <= index_num:
                    A[epoch][index * args.batch_size:(index + 1) *
                             args.batch_size, :] = out
                else:
                    A[epoch][index_num *
                             args.batch_size:len(train_data), :] = out

    val_acc_array = np.array(val_acc_list)
    model_index = np.argmax(val_acc_array)

    A_save_dir = prob_save_dir + '/' + 'prob.npy'
    np.save(A_save_dir, A)
    prob_ = np.load(A_save_dir)

    transition_matrix_ = tools.fit(prob_[model_index, :, :], args.num_classes,
                                   estimate_state)
    transition_matrix = tools.norm(transition_matrix_)

    matrix_path = matrix_save_dir + '/' + 'transition_matrix.npy'
    np.save(matrix_path, transition_matrix)
    T = torch.from_numpy(transition_matrix).float().cuda()

    True_T = tools.transition_matrix_generate(noise_rate=args.noise_rate,
                                              num_classes=args.num_classes)
    estimate_error = tools.error(T.cpu().numpy(), True_T)
    print('The estimation error is %s' % (estimate_error))
    # initial parameters

    estimate_model_path = model_save_dir + '/' + 'epoch_%s.pth' % (
        model_index + 1)
    model.load_state_dict(torch.load(estimate_model_path))

    print('Estimate finish.....Training......')
    val_acc_list_r = []

    for epoch in range(args.n_epoch):
        print('epoch {}'.format(epoch + 1))
        # training-----------------------------
        train_loss = 0.
        train_acc = 0.
        val_loss = 0.
        val_acc = 0.
        eval_loss = 0.
        eval_acc = 0.
        scheduler.step()
        model.train()
        for batch_x, batch_y in train_loader:
            batch_x = batch_x.cuda()
            batch_y = batch_y.cuda()
            optimizer.zero_grad()
            out = model(batch_x, revision=False)
            prob = F.softmax(out, dim=1)
            prob = prob.t()
            loss = loss_func_reweight(out, T, batch_y)
            out_forward = torch.matmul(T.t(), prob)
            out_forward = out_forward.t()
            train_loss += loss.item()
            pred = torch.max(out_forward, 1)[1]
            train_correct = (pred == batch_y).sum()
            train_acc += train_correct.item()
            loss.backward()
            optimizer.step()
        with torch.no_grad():
            model.eval()
            for batch_x, batch_y in val_loader:
                batch_x = batch_x.cuda()
                batch_y = batch_y.cuda()
                out = model(batch_x, revision=False)
                prob = F.softmax(out, dim=1)
                prob = prob.t()
                loss = loss_func_reweight(out, T, batch_y)
                out_forward = torch.matmul(T.t(), prob)
                out_forward = out_forward.t()
                val_loss += loss.item()
                pred = torch.max(out_forward, 1)[1]
                val_correct = (pred == batch_y).sum()
                val_acc += val_correct.item()

        torch.save(model.state_dict(),
                   model_save_dir + '/' + 'epoch_r%d.pth' % (epoch + 1))
        print('Train Loss: {:.6f}, Acc: {:.6f}'.format(
            train_loss / (len(train_data)) * args.batch_size,
            train_acc / (len(train_data))))
        print('Val Loss: {:.6f}, Acc: {:.6f}'.format(
            val_loss / (len(val_data)) * args.batch_size,
            val_acc / (len(val_data))))
        val_acc_list_r.append(val_acc / (len(val_data)))

        with torch.no_grad():
            model.eval()
            for batch_x, batch_y in test_loader:
                batch_x = batch_x.cuda()
                batch_y = batch_y.cuda()
                out = model(batch_x, revision=False)
                loss = loss_func_ce(out, batch_y)
                eval_loss += loss.item()
                pred = torch.max(out, 1)[1]
                eval_correct = (pred == batch_y).sum()
                eval_acc += eval_correct.item()

            print('Test Loss: {:.6f}, Acc: {:.6f}'.format(
                eval_loss / (len(test_data)) * args.batch_size,
                eval_acc / (len(test_data))))

    val_acc_array_r = np.array(val_acc_list_r)
    reweight_model_index = np.argmax(val_acc_array_r)

    reweight_model_path = model_save_dir + '/' + 'epoch_r%s.pth' % (
        reweight_model_index + 1)
    model.load_state_dict(torch.load(reweight_model_path))
    nn.init.constant_(model.T_revision.weight, 0.0)

    print('Revision......')

    for epoch in range(args.n_epoch_revision):

        print('epoch {}'.format(epoch + 1))
        # training-----------------------------
        train_loss = 0.
        train_acc = 0.
        val_loss = 0.
        val_acc = 0.
        eval_loss = 0.
        eval_acc = 0.
        model.train()
        for batch_x, batch_y in train_loader:
            batch_x = batch_x.cuda()
            batch_y = batch_y.cuda()
            optimizer_revision.zero_grad()
            out, correction = model(batch_x, revision=True)
            prob = F.softmax(out, dim=1)
            prob = prob.t()
            loss = loss_func_revision(out, T, correction, batch_y)
            out_forward = torch.matmul((T + correction).t(), prob)
            out_forward = out_forward.t()
            train_loss += loss.item()
            pred = torch.max(out_forward, 1)[1]
            train_correct = (pred == batch_y).sum()
            train_acc += train_correct.item()
            loss.backward()
            optimizer_revision.step()

        with torch.no_grad():
            model.eval()
            for batch_x, batch_y in val_loader:
                batch_x = batch_x.cuda()
                batch_y = batch_y.cuda()
                out, correction = model(batch_x, revision=True)
                prob = F.softmax(out, dim=1)
                prob = prob.t()
                loss = loss_func_revision(out, T, correction, batch_y)
                out_forward = torch.matmul((T + correction).t(), prob)
                out_forward = out_forward.t()
                val_loss += loss.item()
                pred = torch.max(out_forward, 1)[1]
                val_correct = (pred == batch_y).sum()
                val_acc += val_correct.item()

        estimate_error = tools.error(True_T,
                                     (T + correction).cpu().detach().numpy())
        print('Estimate error: {:.6f}'.format(estimate_error))
        print('Train Loss: {:.6f}, Acc: {:.6f}'.format(
            train_loss / (len(train_data)) * args.batch_size,
            train_acc / (len(train_data))))
        print('Val Loss: {:.6f}, Acc: {:.6f}'.format(
            val_loss / (len(val_data)) * args.batch_size,
            val_acc / (len(val_data))))

        with torch.no_grad():
            for batch_x, batch_y in test_loader:
                batch_x = batch_x.cuda()
                batch_y = batch_y.cuda()
                out, _ = model(batch_x, revision=True)
                loss = loss_func_ce(out, batch_y)
                eval_loss += loss.item()
                pred = torch.max(out, 1)[1]
                eval_correct = (pred == batch_y).sum()
                eval_acc += eval_correct.item()

            print('Test Loss: {:.6f}, Acc: {:.6f}'.format(
                eval_loss / (len(test_data)) * args.batch_size,
                eval_acc / (len(test_data))))
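loss_func_reweight and loss_func_revision are not shown; the forward pass above multiplies the softmax by the (optionally revised) transition matrix, which is the forward-correction pattern. A hedged PyTorch sketch of such a loss, not necessarily the exact one used here:

import torch
import torch.nn.functional as F

def forward_correction_loss(out, T, target):
    """Cross-entropy computed on the noisy posterior softmax(out) @ T."""
    prob = F.softmax(out, dim=1)          # clean posterior, (batch, C)
    noisy_prob = prob @ T                 # mix through the transition matrix
    return F.nll_loss(torch.log(noisy_prob + 1e-12), target)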