HEAD = nn.DataParallel(HEAD, device_ids = GPU_ID)
		HEAD = HEAD.to(DEVICE)
	else:
		# single-GPU setting
		BACKBONE = BACKBONE.to(DEVICE)
		HEAD = HEAD.to(DEVICE)


	#======= train & validation & save checkpoint =======#
	# How often (in batches) to display training loss & accuracy: roughly 100
	# reports per epoch. Clamped to >= 1 because integer division yields 0 for a
	# loader with fewer than 100 batches, and a display frequency of 0 is
	# meaningless (any `batch % DISP_FREQ` check would raise ZeroDivisionError).
	DISP_FREQ = max(1, len(train_loader) // 100)

	# Warm-up length: the first 1/25 of the epochs (0 when NUM_EPOCH < 25),
	# expressed both in epochs and in batches; both are handed to OneEpoch.
	NUM_EPOCH_WARM_UP = NUM_EPOCH // 25
	NUM_BATCH_WARM_UP = len(train_loader) * NUM_EPOCH_WARM_UP
	# NOTE(review): `batch` is not read by this loop; kept in case OneEpoch or
	# later code in this scope relies on it — confirm before removing.
	batch = 0  # batch index

	for epoch in range(NUM_EPOCH):  # start training process
		# Step the LR schedule when the epoch reaches one of the first three
		# stage boundaries (STAGES[0], STAGES[1], STAGES[2]) after warm-up.
		# A boundary listed more than once steps the schedule once per listing.
		# Alternatively, adjust the LR manually (with slight modification) once
		# a plateau is observed.
		for _stage_idx in range(3):
			if epoch == STAGES[_stage_idx]:
				schedule_lr(OPTIMIZER)

		BACKBONE.train()  # set to training mode
		HEAD.train()

		OneEpoch(epoch, train_loader, OPTIMIZER, DISP_FREQ, NUM_EPOCH_WARM_UP, NUM_BATCH_WARM_UP)

		# training statistics per epoch (buffer for visualization)
    batch = 0  # batch index
    for epoch in range(NUM_EPOCH):  # start training process

        BACKBONE.train()  # set to training mode
        HEAD.train()

        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        for inputs, labels in tqdm(iter(train_loader)):

            if batch == STAGES[
                0]:  # adjust LR for each training stage after warm up, you can also choose to adjust LR manually (with slight modification) once plaueau observed
                schedule_lr(OPTIMIZER)
            if batch == STAGES[1]:
                schedule_lr(OPTIMIZER)
            if batch == STAGES[2]:
                schedule_lr(OPTIMIZER)

            # compute output
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE).long()
            features = BACKBONE(inputs)
            outputs = HEAD(features, labels)
            loss = LOSS(outputs, labels)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(outputs.data, labels, topk=(1, 5))
            losses.update(loss.data.item(), inputs.size(0))