Exemplo n.º 1
0
def _load_rgb_tensor(path):
    """Load the image at ``path`` as an RGB tensor batch of one on ``device``.

    The file is opened in a ``with`` block so its handle is closed as soon as
    the pixels have been converted, instead of whenever the garbage collector
    gets to it (the training loop opens one file per step).
    Relies on the module-level ``trans2tensor`` transform and ``device``.
    """
    with Image.open(path) as img:
        rgb = img.convert("RGB")
    return trans2tensor(rgb).unsqueeze(0).to(device)


def _mean_or_zero(values):
    """Return the arithmetic mean of ``values``, or 0.0 if it is empty."""
    return sum(values) / len(values) if values else 0.0


def _new_loss_log():
    """Return an empty per-step loss accumulator keyed by TensorBoard tag."""
    return {'recon_loss': [],
            'lie_loss': [],
            'detect_loss': [],
            'real_loss': []}


def train():
    """Adversarially train ``SimpleAEGAN`` on defect-free images, logging to TensorBoard.

    Every step trains the discriminator twice (real image, then the
    reconstruction of a noised copy) and the autoencoder twice (MSE
    reconstruction, then an adversarial "lie" loss). Every ``show_interval``
    steps it logs averaged losses, sample images and one random image from
    each flaw folder; every ``10 * show_interval`` steps it also logs, per
    flaw folder, the share of up to 200 random images the discriminator
    scores above 0.5.

    Relies on module-level names defined elsewhere in this file:
    ``input_size``, ``ae_level``, ``device``, ``lr``, ``epoch``,
    ``show_interval``, ``SimpleAEGAN``, ``warmup_lr_scheduler``,
    ``trans2tensor`` and ``add_noise``.
    """
    data_path = 'C:/Users/DeepLearning/Desktop/flaw_detect/img/test/commom1'  # '../defect_generate_gan/defect_samples'
    flaw_path = [
        'C:/Users/DeepLearning/Desktop/flaw_detect/img/test/flaw01',
        'C:/Users/DeepLearning/Desktop/flaw_detect/img/test/flaw02',
        'C:/Users/DeepLearning/Desktop/flaw_detect/img/test/flaw05',
        'C:/Users/DeepLearning/Desktop/flaw_detect/img/test/flaw07',
        'C:/Users/DeepLearning/Desktop/flaw_detect/img/test/other',
        'C:/Users/DeepLearning/Desktop/flaw_detect/img/test/watermark'
    ]
    time_stamp = "{0:%Y-%m-%d-%H-%M-%S}".format(datetime.now())
    model = SimpleAEGAN(input_size=input_size, ae_level=ae_level).to(device)
    optim_G = torch.optim.Adam(model.ae.parameters(), lr=lr)
    optim_D = torch.optim.Adam(model.discriminator.parameters(), lr=lr)
    # Decay both learning rates by 10x every 3 epochs.
    lr_scheduler_G = torch.optim.lr_scheduler.StepLR(optim_G, 3, 0.1)
    lr_scheduler_D = torch.optim.lr_scheduler.StepLR(optim_D, 3, 0.1)
    criterion_G = torch.nn.MSELoss()
    criterion_D = torch.nn.BCELoss()
    writer = SummaryWriter(log_dir='runs/aegan'+time_stamp)

    for e in range(epoch):
        img_paths = os.listdir(data_path)

        # Warm-up (first epoch only): ramp the LR linearly from a very small
        # value up to the configured lr.
        lr_warmup_G = None
        lr_warmup_D = None
        if e == 0:
            warmup_factor = 1. / 1000
            warmup_iters = min(1000, len(img_paths) - 1)
            lr_warmup_G = warmup_lr_scheduler(optim_G, warmup_iters, warmup_factor)
            lr_warmup_D = warmup_lr_scheduler(optim_D, warmup_iters, warmup_factor)
        losses = _new_loss_log()
        for index, img_path in tqdm(enumerate(img_paths), total=len(img_paths)):
            img_tensor = _load_rgb_tensor(os.path.join(data_path, img_path))
            # Corrupt the input (Gaussian noise, per the original note); the
            # autoencoder is trained to recover the clean image from it.
            noise_img_tensor = add_noise(img_tensor)
            n = img_tensor.shape[0]

            # Discriminator targets: real image -> 0, reconstruction -> U(0.7, 1.0).
            gt_real_hard = torch.zeros((n, 1)).to(device)
            gt_recon_soft = torch.tensor(np.random.uniform(0.7, 1.0, (n, 1)), dtype=torch.float).to(device)
            # Generator target: make the discriminator score reconstructions as U(0.0, 0.3).
            target_recon_soft = torch.tensor(np.random.uniform(0.0, 0.3, (n, 1)), dtype=torch.float).to(device)

            global_step = e * len(img_paths) + index

            # --- Discriminator: one update on the real image, one on the reconstruction ---
            # NOTE(review): model(x, True) appears to return only the discriminator
            # probability for x, and model(x) returns (reconstruction, probability
            # of that reconstruction) -- inferred from how the results are unpacked.
            real_prob = model(img_tensor, True)
            real_loss = criterion_D(real_prob, gt_real_hard)
            optim_D.zero_grad()
            real_loss.backward()
            optim_D.step()

            _, recon_prob2 = model(noise_img_tensor)
            detect_loss = criterion_D(recon_prob2, gt_recon_soft)
            optim_D.zero_grad()
            detect_loss.backward()
            optim_D.step()

            # --- Autoencoder (generator): reconstruction update, then adversarial update ---
            recon_img, _ = model(noise_img_tensor)
            recon_loss = criterion_G(recon_img, img_tensor)
            optim_G.zero_grad()
            recon_loss.backward()
            optim_G.step()

            _, recon_prob = model(noise_img_tensor)
            lie_loss = criterion_D(recon_prob, target_recon_soft)
            optim_G.zero_grad()
            lie_loss.backward()
            optim_G.step()

            if lr_warmup_G is not None:
                lr_warmup_G.step()
            if lr_warmup_D is not None:
                lr_warmup_D.step()

            # Store plain floats: keeping the loss tensors would pin device
            # memory and autograd nodes until the next logging window.
            losses['recon_loss'].append(recon_loss.item())
            losses['lie_loss'].append(lie_loss.item())
            losses['detect_loss'].append(detect_loss.item())
            losses['real_loss'].append(real_loss.item())

            if global_step % show_interval != 0 or global_step == 0:
                continue

            # --- Periodic TensorBoard report (no gradients needed) ---
            log_step = global_step // show_interval
            with torch.no_grad():
                recon_img_nonoise, recon_prob_nonoise = model(img_tensor)
                # Average over the steps actually accumulated: the first window
                # holds show_interval + 1 entries and windows restart each epoch.
                writer.add_scalars('loss', {tag: _mean_or_zero(vals) for tag, vals in losses.items()}, log_step)
                # NOTE(review): x / 2 + 0.5 suggests tensors are normalized to
                # [-1, 1] and are mapped back to [0, 1] for display -- confirm
                # against trans2tensor.
                writer.add_image('common/origin', img_tensor.squeeze(0) / 2 + 0.5, log_step)
                writer.add_image('common/noise', noise_img_tensor.squeeze(0) / 2 + 0.5, log_step)
                writer.add_image('common/recon', recon_img.squeeze(0) / 2 + 0.5, log_step)
                writer.add_image('common/recon_nonoise', recon_img_nonoise.squeeze(0) / 2 + 0.5, log_step)
                writer.add_text('judge_common',
                                f'recon:{float(recon_prob):.2%}_'
                                f'orgin:{float(real_prob):.2%}_'
                                f'recon_nonoise:{float(recon_prob_nonoise):.2%}',
                                log_step)
                # Start a new accumulation window.
                losses = _new_loss_log()

                # Probe the model on one random image from each flaw folder.
                for p in flaw_path:
                    flaw_name = os.path.basename(p)
                    get_img_paths = os.listdir(p)
                    get_one = random.choice(get_img_paths)
                    get_img_tensor = _load_rgb_tensor(os.path.join(p, get_one))
                    get_recon_flaw, get_flaw_prob1 = model(get_img_tensor)
                    get_flaw_prob2 = model(get_img_tensor, True)
                    writer.add_image(f'recon_flaw/{flaw_name}_o', get_img_tensor.squeeze(0) / 2 + 0.5, log_step)
                    writer.add_image(f'recon_flaw/{flaw_name}', get_recon_flaw.squeeze(0) / 2 + 0.5, log_step)
                    writer.add_text(f'judge_flaw/{flaw_name}',
                                    f'recon:{float(get_flaw_prob1):.2%}_orgin:{float(get_flaw_prob2):.2%}',
                                    log_step)

                    # Every 10th report: share of up to 200 random flaw images
                    # the discriminator scores above 0.5.
                    if global_step % (show_interval * 10) == 0:
                        random.shuffle(get_img_paths)
                        test_imgs = get_img_paths[:200]
                        right = 0
                        probs = []
                        for test_img in test_imgs:
                            out = float(model(_load_rgb_tensor(os.path.join(p, test_img)), True))
                            if out > 0.5:
                                right += 1
                                # NOTE(review): only above-threshold scores are
                                # summed into 'summup' -- confirm this is intended.
                                probs.append(out)
                        # Divide by the number actually tested (a folder may hold < 200 images).
                        n_tested = len(test_imgs)
                        writer.add_scalars(f'flaw_acc/{flaw_name}', {'acc': right / n_tested,
                                                                     'summup': sum(probs) / n_tested},
                                           global_step // (show_interval * 10))
        lr_scheduler_G.step()
        lr_scheduler_D.step()
Exemplo n.º 2
0
class TensorBoardCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that sends the logs to `TensorBoard
    <https://www.tensorflow.org/tensorboard>`__.

    Args:
        tb_writer (:obj:`SummaryWriter`, `optional`):
            The writer to use. Will instantiate one if not set. The writer is
            closed and dropped at the end of training, so a later run (e.g. the
            next hyper-parameter-search trial) opens a fresh one.
    """
    def __init__(self, tb_writer=None):
        assert (
            _has_tensorboard
        ), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
        self.tb_writer = tb_writer

    def _init_summary_writer(self, args, log_dir=None):
        """Open a new ``SummaryWriter`` in ``log_dir`` (defaults to ``args.logging_dir``)."""
        log_dir = log_dir or args.logging_dir
        self.tb_writer = SummaryWriter(log_dir=log_dir)

    def on_train_begin(self, args, state, control, **kwargs):
        """On the main process, ensure a writer exists and record args, model config and hparams."""
        if not state.is_world_process_zero:
            return

        log_dir = None

        if state.is_hyper_param_search:
            trial_name = state.trial_name
            if trial_name is not None:
                # One sub-directory per hyper-parameter-search trial.
                log_dir = os.path.join(args.logging_dir, trial_name)

        # Keep a writer passed to __init__ (or still open). Previously one was
        # always created here, silently discarding the user's ``tb_writer``.
        if self.tb_writer is None:
            self._init_summary_writer(args, log_dir)

        self.tb_writer.add_text("args", args.to_json_string())
        model = kwargs.get("model")
        if model is not None and getattr(model, "config", None) is not None:
            model_config_json = model.config.to_json_string()
            self.tb_writer.add_text("model_config", model_config_json)
        self.tb_writer.add_hparams(args.to_sanitized_dict(),
                                   metric_dict={})

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Write each numeric entry of ``logs`` as a scalar at ``state.global_step``.

        Non-numeric values are skipped with a warning.
        """
        if state.is_world_process_zero and self.tb_writer is None:
            self._init_summary_writer(args)

        if self.tb_writer is None:
            return

        # ``logs`` defaults to None; treat that as "nothing to log".
        logs = rewrite_logs(logs or {})
        for k, v in logs.items():
            if isinstance(v, (int, float)):
                self.tb_writer.add_scalar(k, v, state.global_step)
            else:
                logger.warning(
                    "Trainer is attempting to log a value of "
                    '"%s" of type %s for key "%s" as a scalar. '
                    "This invocation of Tensorboard's writer.add_scalar() "
                    "is incorrect so we dropped this attribute.",
                    v,
                    type(v),
                    k,
                )
        self.tb_writer.flush()

    def on_train_end(self, args, state, control, **kwargs):
        """Close the writer and forget it so nothing is written to a closed writer later."""
        if self.tb_writer:
            self.tb_writer.close()
            self.tb_writer = None
Exemplo n.º 3
0
# Hyper-parameters arrive from ``exp_config`` as strings that are evaluated as
# Python expressions, one statement at a time and in this order (so a later
# expression can refer to a name bound by an earlier line).
# SECURITY NOTE(review): ``eval`` executes arbitrary code taken from the
# experiment config; only run this with trusted config files. If every value
# is a plain literal, ``ast.literal_eval`` would be a safe replacement --
# confirm against the config files before switching.
# NOTE(review): names suggest rm_sz is a rehearsal-memory size -- confirm.
rm_sz = eval(exp_config['rm_sz'])
momentum = eval(exp_config['momentum'])
l2 = eval(exp_config['l2'])
freeze_below_layer = eval(exp_config['freeze_below_layer'])
latent_layer_num = eval(exp_config['latent_layer_num'])
reg_lambda = eval(exp_config['reg_lambda'])
scenario = eval(exp_config['scenario'])
# The evaluated scenario value is reused as ``sub_dir`` (presumably an output
# sub-directory used later in the script -- not visible here).
sub_dir = scenario

# setting up log dir for tensorboard: logs/<exp_name>
log_dir = 'logs/' + exp_name
writer = SummaryWriter(log_dir)

# Saving params: record the raw experiment config as JSON text at step 0
hyper = json.dumps(dict(exp_config))
writer.add_text("parameters", hyper, 0)

# Other variables init
# Iteration counter; presumably advanced by the training loop later on.
tot_it_step = 0
# Starts empty; presumably filled later (rehearsal memory?) -- confirm.
rm = None

# do not remove this line
# (wall-clock start time; presumably read later to report total run time)
start_time = time.time()

# Create the dataset object
# NOTE(review): preload=False presumably means images are loaded lazily rather
# than all at once -- confirm in CORE50.
dataset = CORE50(
    root='/home/admin/ssd_data/cvpr_competition/cvpr_competition_data/',
    scenario=scenario,
    preload=False)
# Keep a handle to the image preprocessing function.
preproc = preprocess_imgs
Exemplo n.º 4
0
def _build_arg_parser():
    """Return the command-line parser for the cross-subject training script."""
    parser = argparse.ArgumentParser(
        description='cross subject domain adaptation')

    parser.add_argument('--batch-size',
                        type=int,
                        default=100,
                        metavar='N',
                        help='input batch size for training (default: %(default)s)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=100,
                        metavar='N',
                        help='input batch size for testing (default: %(default)s)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        metavar='N',
                        help='number of epochs to train (default: %(default)s)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate (default: %(default)s)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: %(default)s)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=True,
                        help='For Saving the current Model')
    return parser


def _subject_class_labels(y_data, n_subjects=108, n_trials=200):
    """Return ``label + 2 * subject_index`` for every trial as a flat array.

    ``y_data`` is reshaped to ``(n_subjects, n_trials)``, i.e. trials are
    assumed to be stored subject-major. With binary class labels (assumed --
    hence the factor 2) each (subject, class) pair gets a distinct id.
    The offsets are float64, so the result is at least float64.
    """
    import numpy as np

    offsets = 2.0 * np.arange(n_subjects)[:, None]
    combined = y_data.reshape(n_subjects, n_trials) + offsets
    return combined.reshape(n_subjects * n_trials)


def main():
    """Train a triplet-loss EEG model with a Deep4Net backbone and evaluate it each epoch.

    Uses the domain-generalization split (train on all subjects outside the
    test fold) unless ``DAsetting`` is set. Per epoch it runs ``fit`` on the
    triplet loaders, evaluates the classification head on the plain loaders,
    logs to TensorBoard and saves a checkpoint under
    ``model/<folder_name>/<comment>/``.
    """
    import os
    from datetime import datetime

    import numpy as np
    import torch
    import torch.optim as optim
    from torch.optim import lr_scheduler

    from trainer import fit

    # Training settings
    parser = _build_arg_parser()

    fold_idx = 4
    gamma = 0.7
    margin = 1.0

    DAsetting = False
    args = parser.parse_args()
    args.seed = 0
    args.use_tensorboard = True
    args.save_model = True
    n_epochs = 200
    startepoch = 0

    folder_name = 'exp11_0630'
    comment = f'deep4{fold_idx}_g_{gamma}_m_{margin}'

    # One CUDA decision used everywhere (model, loss, and the flag passed to
    # fit). Previously the model/loss went to CUDA even with --no-cuda, and
    # the loss was moved to CUDA unconditionally (crashing on CPU-only hosts).
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    cuda = use_cuda
    device = torch.device("cuda" if use_cuda else "cpu")

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    # Reproducibility: cuDNN autotuning (benchmark=True) may pick different
    # algorithms run to run, which would undo deterministic=True and the seeds.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    loging = False

    x_data, y_data = load_smt()
    x_data = x_data[:, :, :, 100:]
    # Per-trial id combining subject and class, used as the triplet label.
    y_subj = _subject_class_labels(y_data)

    # For classification data
    valtype = 'subj'

    if DAsetting:
        # Domain-adaptation setting: one held-out target subject.
        test_subj_id = 39
        test_subj = np.r_[test_subj_id:test_subj_id + 1]
        train_subj1 = np.setxor1d(np.r_[0:108], test_subj)
        train_subj2 = test_subj

        # The first n_targets trials of the target subject are used for
        # training, the remaining ones for validation.
        n_targets = 60
        trial_s = (0, 200)
        trial_t = (0, n_targets)

        trial_val = (n_targets, 200)

        dataset_train = GigaDataset(x=x_data,
                                    y=y_data,
                                    valtype=valtype,
                                    istrain=True,
                                    subj=train_subj2,
                                    trial=trial_t)
        dataset_test = GigaDataset(x=x_data,
                                   y=y_data,
                                   valtype=valtype,
                                   istrain=False,
                                   subj=test_subj,
                                   trial=trial_val)

        triplet_dataset_train = TripletGigaDA(x=x_data,
                                              y=y_subj,
                                              valtype=valtype,
                                              istrain=True,
                                              subj_s=train_subj1,
                                              trial_s=trial_s,
                                              subj_t=train_subj2,
                                              trial_t=trial_t)

        triplet_dataset_test = TripletGigaDA(x=x_data,
                                             y=y_subj,
                                             valtype=valtype,
                                             istrain=True,
                                             subj_s=train_subj1,
                                             trial_s=trial_s,
                                             subj_t=test_subj,
                                             trial_t=trial_val)

    else:
        # Domain-generalization setting: fold ``fold_idx`` holds out 9 subject
        # ids from each half of the 108 (0-53 and 54-107).
        test_subj = np.r_[fold_idx * 9:fold_idx * 9 + 9,
                          fold_idx * 9 + 54:fold_idx * 9 + 9 + 54]

        print('test subj:' + str(test_subj))
        train_subj = np.setdiff1d(np.r_[0:108], test_subj)

        trial_train = (0, 200)
        trial_val = (0, 200)

        dataset_train = GigaDataset(x=x_data,
                                    y=y_data,
                                    valtype=valtype,
                                    istrain=True,
                                    subj=train_subj,
                                    trial=trial_train)
        dataset_test = GigaDataset(x=x_data,
                                   y=y_data,
                                   valtype=valtype,
                                   istrain=False,
                                   subj=test_subj,
                                   trial=trial_val)

        triplet_dataset_train = TripletGiga2(x=x_data,
                                             y=y_subj,
                                             valtype=valtype,
                                             istrain=True,
                                             subj=train_subj,
                                             trial=trial_train)

        triplet_dataset_test = TripletGiga2(x=x_data,
                                            y=y_subj,
                                            valtype=valtype,
                                            istrain=False,
                                            subj=test_subj,
                                            trial=trial_val)

    train_loader = torch.utils.data.DataLoader(dataset_train,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset_test,
                                              batch_size=args.batch_size,
                                              shuffle=False)
    triplet_train_loader = torch.utils.data.DataLoader(
        triplet_dataset_train, batch_size=args.batch_size, shuffle=True)
    triplet_test_loader = torch.utils.data.DataLoader(
        triplet_dataset_test, batch_size=args.batch_size, shuffle=False)

    ###################################################################################################################
    # make model for metric learning
    from networks import Deep4Net_origin, Deep4Net, TripletNet
    from losses import TripletLoss_dev2

    if gamma == 1.0:
        model = Deep4Net_origin()
    else:
        embedding_net = Deep4Net()
        print(embedding_net)
        model = TripletNet(embedding_net)
    model.to(device)
    loss_fn = TripletLoss_dev2(margin, gamma).to(device)

    log_interval = args.log_interval

    # NOTE(review): the --lr / --momentum flags are not used; the optimizer
    # settings below are fixed for this experiment.
    optimizer = optim.SGD(model.parameters(),
                          lr=0.1,
                          momentum=0.9,
                          weight_decay=0.0005)
    milestones = [15, 30, 50, 120]
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=milestones,
                                         gamma=0.1)  # dropping the LR too early seems to cause underfitting

    # Classification head used for validation; it shares modules with ``model``.
    evalmodel = nn.Sequential(model.embedding_net, model.fc,
                              nn.LogSoftmax(dim=1)).to(device)

    print('____________DANet____________')
    print(model)

    model_save_path = 'model/' + folder_name + '/' + comment + '/'
    if args.save_model:
        os.makedirs(model_save_path, exist_ok=True)

    if loging:
        fname = model_save_path + datetime.today().strftime(
            "%m_%d_%H_%M") + ".txt"
        # Create (truncate) the log file without leaking an open handle.
        with open(fname, 'w'):
            pass

    writer = None
    if args.use_tensorboard:
        writer = SummaryWriter(comment=comment)
        writer.add_text('optimizer', str(optimizer))
        writer.add_text('scheduler', str(milestones))
        writer.add_text('model_save_path', model_save_path)
        writer.add_text('model', str(model))
        writer.flush()

    # Optionally resume from the checkpoint saved at ``startepoch``; new
    # checkpoints then get a "(cont)" prefix so they do not overwrite the old ones.
    if startepoch > 0:
        load_model_path = model_save_path + 'danet_' + str(gamma) + '_' + str(
            startepoch) + '.pt'
        model_save_path = model_save_path + '(cont)'
        # map_location lets a GPU-saved checkpoint load on a CPU-only host.
        model.load_state_dict(torch.load(load_model_path, map_location=device))

    try:
        # NOTE(review): trains for 100 epochs from 0 even though n_epochs=200 is
        # passed to fit and startepoch is not used here -- confirm intent.
        for epochidx in range(100):
            fit(triplet_train_loader, triplet_test_loader, model, loss_fn,
                optimizer, scheduler, epochidx, n_epochs, cuda, log_interval)
            print(epochidx)
            # ``eval`` here must be a module-level evaluation helper (it takes
            # args, model, device, loader), shadowing Python's builtin eval.
            train_loss, train_score = eval(args, evalmodel, device, train_loader)
            eval_loss, eval_score = eval(args, evalmodel, device, test_loader)

            if writer is not None:
                writer.add_scalar('Train/Loss',
                                  np.mean(train_loss) / args.batch_size, epochidx)
                writer.add_scalar('Train/Acc',
                                  np.mean(train_score) / args.batch_size, epochidx)
                writer.add_scalar('Eval/Loss',
                                  np.mean(eval_loss) / args.batch_size, epochidx)
                writer.add_scalar('Eval/Acc',
                                  np.mean(eval_score) / args.batch_size, epochidx)
                # Flush instead of close: closing every epoch starts a new
                # event file on the next write.
                writer.flush()
            if args.save_model:
                torch.save(
                    model.state_dict(), model_save_path + 'danet_' + str(gamma) +
                    '_' + str(epochidx) + '.pt')
    finally:
        if writer is not None:
            writer.close()
Exemplo n.º 5
0
def do_pretrain(args):
    """Run BERT pre-training with the ORT trainer until ``args.max_steps`` is reached.

    Each pass shuffles the regular files in ``args.input_dir`` whose name
    contains "training" and trains on every one of them in turn, while the
    next file's dataset is built in a background process.

    Returns:
        The average loss over the last logging window once the optimization
        step reaches ``args.max_steps``.

    Raises:
        ValueError: if ``args.input_dir`` contains no training files.
    """
    if is_main_process(args) and args.tensorboard_dir:
        tb_writer = SummaryWriter(log_dir=args.tensorboard_dir)
        tb_writer.add_text("args", args.to_json_string())
        tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})
    else:
        tb_writer = None

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    ort.set_seed(args.seed)

    device, args = setup_training(args)

    model = prepare_model(args, device)

    logger.info("Running training: Batch size = %d, initial LR = %f",
                args.train_batch_size, args.learning_rate)

    average_loss = 0.0
    training_steps = 0

    # One background worker prefetches the next data file.
    pool = ProcessPoolExecutor(1)
    try:
        while True:
            files = sorted(
                os.path.join(args.input_dir, f)
                for f in os.listdir(args.input_dir)
                if os.path.isfile(os.path.join(args.input_dir, f))
                and "training" in f
            )
            if not files:
                raise ValueError(
                    "no training files (names containing 'training') found in %s"
                    % args.input_dir)
            random.shuffle(files)
            num_files = len(files)

            train_dataloader, data_file = create_pretraining_dataset(
                get_data_file(0, args.world_rank, args.world_size, files),
                args.max_predictions_per_seq, args)

            # Train on every file, including the last one. Previously the final
            # prefetched file was never trained on, and a single-file directory
            # looped forever without training.
            for f_id in range(num_files):
                logger.info("data file %s", data_file)

                dataset_future = None
                if f_id + 1 < num_files:
                    dataset_future = pool.submit(
                        create_pretraining_dataset,
                        get_data_file(f_id + 1, args.world_rank,
                                      args.world_size, files),
                        args.max_predictions_per_seq,
                        args,
                    )

                train_iter = tqdm(train_dataloader, desc="Iteration"
                                  ) if is_main_process(args) else train_dataloader
                for batch in train_iter:
                    training_steps += 1
                    batch = [t.to(device) for t in batch]
                    input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch

                    loss, _, _ = model.train_step(input_ids, input_mask,
                                                  segment_ids, masked_lm_labels,
                                                  next_sentence_labels)
                    average_loss += loss.item()

                    global_step = model._train_step_info.optimization_step
                    divisor = args.log_freq * args.gradient_accumulation_steps
                    if training_steps % divisor != 0:
                        continue

                    if is_main_process(args):
                        if tb_writer:
                            lr = model.options.lr_scheduler.get_last_lr()[0]
                            tb_writer.add_scalar(
                                "train/summary/scalar/Learning_Rate", lr,
                                global_step)
                            if args.fp16:
                                tb_writer.add_scalar(
                                    "train/summary/scalar/loss_scale_25", loss,
                                    global_step)
                                # TODO: ORTTrainer to expose all_finite
                            tb_writer.add_scalar("train/summary/total_loss",
                                                 average_loss / divisor,
                                                 global_step)

                        print("Step:{} Average Loss = {}".format(
                            global_step, average_loss / divisor))

                    # NOTE(review): past force_to_stop_max_steps the writer is
                    # closed at every logging step while training continues.
                    if global_step >= args.max_steps or global_step >= force_to_stop_max_steps:
                        if tb_writer:
                            tb_writer.close()

                    if global_step >= args.max_steps:
                        if args.save_checkpoint:
                            model.save_checkpoint(
                                os.path.join(
                                    args.output_dir,
                                    "checkpoint-{}.ortcp".format(
                                        args.world_rank)))
                        return average_loss / divisor

                    average_loss = 0

                # Drop the finished loader before the next one is materialized.
                del train_iter, train_dataloader

                if dataset_future is not None:
                    train_dataloader, data_file = dataset_future.result(timeout=None)
    finally:
        # Release the worker process; do not block on a pending prefetch.
        pool.shutdown(wait=False)
Exemplo n.º 6
0
def training(retrain=None):
    """Train the module-level seq2seq ``model`` on the pre-batched dataset.

    Relies on module-level globals defined elsewhere in the file: ``model``,
    ``adam`` (optimizer), ``scheduler``, ``criterion``, ``inputs_batched`` /
    ``outputs_batched``, ``batch_size``, ``np2tens``, ``vocabulary_inversed``
    and ``dataset_name``.

    After each epoch a row is appended to ``./training/movie/training-log.csv``
    and a checkpoint dict is saved to
    ``./training/movie/<modelID>-<checkpointID>.model``.

    Args:
        retrain: Optional path to a checkpoint whose ``"model_state"`` entry is
            loaded into ``model`` before training starts.
    """
    print("Ok, we are ready to train. On your go.")

    # Intentional pause: wait for the operator to continue (`c` in pdb) before
    # training starts. Set PYTHONBREAKPOINT=0 to skip it in unattended runs.
    breakpoint()

    if retrain is not None:
        checkpoint = torch.load(retrain, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])

    epochs = 100000
    accumulate = 4  # number of batches whose gradients are summed per optimizer step

    version = "DEC212020_1_NODEC"
    modelID = str(uuid.uuid4())[-5:]
    initialRuntime = time.time()

    writer = SummaryWriter(f'./training/movie/logs/{modelID}')

    model.train()
    for epoch in range(epochs):
        checkpointID = str(uuid.uuid4())[-5:]
        batch_data_group = list(zip(inputs_batched, outputs_batched))

        random.shuffle(batch_data_group)

        batch_data_feed = tqdm(enumerate(batch_data_group),
                               total=len(inputs_batched))

        for batch, (inp, oup) in batch_data_feed:
            # Step index across all epochs; drives logging and accumulation.
            global_step = batch + epoch * len(inputs_batched)

            encinp_torch = np2tens(inp)
            decinp_torch = np2tens(oup)

            # Targets are the decoder inputs shifted left by one position and
            # right-padded with a zero column.
            padding_row = torch.zeros(batch_size, 1)
            oup_torch = (torch.cat((np2tens(oup)[:, 1:], padding_row),
                                   dim=1)).long()

            prediction = model(encinp_torch, decinp_torch, None,
                               int(batch_size))

            # Exclude positions whose target id is 0 from the loss.
            target_mask = torch.not_equal(oup_torch, 0).float()

            loss_val = criterion(prediction, oup_torch, target_mask)

            loss_val.backward()
            prediction_values = np.array(torch.argmax(prediction, 2).cpu())[:1]

            # Gradient accumulation: step the optimizer every `accumulate`
            # global steps (skipping the very first). Using the global step
            # rather than the in-epoch batch index avoids skipping the first
            # batch of every later epoch, which doubled the accumulation window
            # across epoch boundaries.
            if global_step % accumulate == 0 and global_step != 0:
                adam.step()
                adam.zero_grad()

            # Decode the first sample of the batch into tokens for a text preview;
            # ids missing from the vocabulary are shown as "<err>".
            prediction_sentences = []
            for e in prediction_values:
                prediction_value = []
                for i in e:
                    try:
                        prediction_value.append(vocabulary_inversed[i])
                    except KeyError:
                        prediction_value.append("<err>")
                prediction_sentences.append(prediction_value)

            # Each word followed by a single space (trailing space preserved).
            final_sent = "".join(word + " " for word in prediction_sentences[0])

            writer.add_scalar('Train/loss', loss_val.item(), global_step)
            writer.add_text('Train/sample', final_sent, global_step)

            batch_data_feed.set_description(
                f'| Model: {modelID}@{checkpointID} | Epoch: {epoch} | Batch: {batch} | Loss: {loss_val:.5f} |'
            )

        # CSV columns: CheckpointID, ModelID, ModelVersion, Dataset,
        # Initial Runtime, Current Time, Epoch, Loss, Checkpoint Filename, Retrain source
        initialHumanTime = datetime.fromtimestamp(initialRuntime).strftime(
            "%m/%d/%Y, %H:%M:%S")
        nowHumanTime = datetime.now().strftime("%m/%d/%Y, %H:%M:%S")

        # newline='' is required by the csv module to avoid blank rows on Windows.
        with open("./training/movie/training-log.csv", "a+", newline='') as df:
            csvfile = csv.writer(df)
            csvfile.writerow([
                checkpointID, modelID, version, dataset_name, initialHumanTime,
                nowHumanTime, epoch,
                loss_val.item(), f'{modelID}-{checkpointID}.model',
                f'{retrain}'
            ])

        torch.save(
            {
                'version': version,
                'modelID': modelID,
                'checkpointID': checkpointID,
                'datasetName': dataset_name,
                'epoch': epoch,
                'loss': loss_val,
                'model_state': model.state_dict(),
                'optimizer_state': adam.state_dict(),
                'lr': scheduler.get_last_lr()
            }, f'./training/movie/{modelID}-{checkpointID}.model')

        print(f'| EPOCH DONE | Epoch: {epoch} | Loss: {loss_val} |')
        scheduler.step()
    writer.close()
Exemplo n.º 7
0
def main():
    """Train ``Net`` on MNIST, logging to TensorBoard and TRAINS, then test once.

    Parses command-line options, trains for ``--epochs`` epochs, and saves the
    whole model object to ``model_<epoch>`` after each epoch. It then evaluates
    on the MNIST test split using ``--test-batch-size``.
    """
    # Connecting TRAINS (Task.init registers this run with the TRAINS server)
    task = Task.init(project_name='ODSC20-east',
                     task_name='pytorch with tensorboard')

    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=2,
                        metavar='N',
                        help='number of epochs to train (default: 2)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    args = parser.parse_args()
    writer = SummaryWriter('runs')
    writer.add_text('TEXT', 'This is some text', 0)
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
    # Standard MNIST mean/std normalization.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, ))
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True,
                       transform=mnist_transform),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    # Previously used args.batch_size, silently ignoring --test-batch-size.
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=mnist_transform),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs)

    model = Net()
    if args.cuda:
        model.cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum)

    for epoch in range(1, args.epochs + 1):
        train(model, epoch, train_loader, args, optimizer, writer)
        torch.save(model, 'model_{}'.format(epoch))
    test(model, test_loader, args, optimizer, writer)
    # Flush pending TensorBoard events and release the event file.
    writer.close()
Exemplo n.º 8
0
            if 'bert' in n:
                p.requires_grad = False
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                     lr=config['model']['learning_rate'])

    for name, param in model.named_parameters():
        print(name, param.shape, param.device, param.requires_grad)

    max_step = config['model']['max_step']
    check_step = config['model']['check_step']
    batch_size = config['model']['batch_size']
    model.zero_grad()
    train_slot_loss, train_intent_loss = 0, 0
    best_val_f1 = 0.

    writer.add_text('config', json.dumps(config))

    for step in range(1, max_step + 1):
        model.train()
        batched_data = dataloader.get_train_batch(batch_size)
        batched_data = tuple(t.to(DEVICE) for t in batched_data)
        word_seq_tensor, tag_seq_tensor, intent_tensor, word_mask_tensor, tag_mask_tensor, context_seq_tensor, context_mask_tensor = batched_data
        if not config['model']['context']:
            context_seq_tensor, context_mask_tensor = None, None
        _, _, slot_loss, intent_loss = model.forward(word_seq_tensor, word_mask_tensor, tag_seq_tensor, tag_mask_tensor,
                                                     intent_tensor, context_seq_tensor, context_mask_tensor)
        train_slot_loss += slot_loss.item()
        train_intent_loss += intent_loss.item()
        loss = slot_loss + intent_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
Exemplo n.º 9
0
def train(model, params, print_log=False):
    """Train ``model`` and keep the weights with the best test accuracy.

    Each epoch runs a 'train' and a 'test' pass over ``params.train_loader`` /
    ``params.test_loader``. It logs per-phase loss and accuracy to TensorBoard.
    When ``params.step_lr`` is set, the learning rate steps by a factor of 0.1.
    At the end, it writes histograms of the diagonal of ``model.H`` at
    initialization and at the best epoch, and saves the best state dict.

    Args:
        model: Network exposing ``__name__``, ``opt`` (its optimizer) and a
            square matrix ``H`` (diagonal is logged).
        params: Run configuration; reads ``learning_rate``, ``step_lr``,
            ``num_epochs`` and ``train_loader``/``test_loader``.
        print_log: If True, print per-phase statistics each epoch.
    """
    model_name = "{}_lr{}".format(model.__name__, params.learning_rate)
    writer = SummaryWriter('{}/{}/'.format(mkSavePath('runs', params), model_name))
    optimizer = model.opt

    if params.step_lr:
        scheduler = lr_scheduler.StepLR(optimizer, step_size=params.step_lr, gamma=0.1)

    def diag_H():
        # .copy(): for a CPU tensor, .numpy() shares memory and .diagonal() is
        # a view, so without a copy the snapshot would follow later updates.
        return model.H.detach().cpu().numpy().diagonal().copy()

    criterion = nn.CrossEntropyLoss()
    best_model = copy.deepcopy(model.state_dict())
    best_ep = 0
    max_acc = test_model(model, params)
    # Snapshot H before any optimizer step so 'init_H' really is the initial
    # value. If no epoch beats the initial accuracy, the saved best model is the
    # initial one, so its H is the initial H as well (previously max_H stayed None).
    init_H = diag_H()
    max_H = init_H

    for epoch in pbar(range(params.num_epochs)):
        for phase in ['train', 'test']:
            logs = {'Loss': 0.0, 'Accuracy': 0.0}
            # Set the model to the correct phase
            model.train() if phase == 'train' else model.eval()
            loader = getattr(params, phase + '_loader')

            for images, labels in loader:
                # Flatten each image to a 784-vector and move to the configured device
                images = images.reshape(-1, 28 * 28).to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):

                    # Forward pass
                    outputs = model(images)
                    loss = criterion(outputs, labels)
                    accuracy = torch.sum(torch.max(outputs, 1)[1] == labels.data).item()

                    # Accumulate sample-weighted loss and correct-prediction count
                    logs['Loss'] += images.shape[0] * loss.detach().item()
                    logs['Accuracy'] += accuracy

                    # Backward pass
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

            logs['Loss'] /= len(loader.dataset)
            logs['Accuracy'] /= len(loader.dataset)
            writer.add_scalars('Loss', {phase: logs['Loss']}, epoch+1)
            writer.add_scalars('Accuracy', {phase: logs['Accuracy']}, epoch+1)

            if print_log:
                print('\n Epoch [{}]: ({}) Loss = {:.6f}, Accuracy = {:.4f}%'
                      .format(epoch+1, phase, logs['Loss'], logs['Accuracy']*100))

            if phase == 'test' and logs['Accuracy'] > max_acc:
                max_acc = logs['Accuracy']
                best_ep = epoch + 1
                best_model = copy.deepcopy(model.state_dict())
                max_H = diag_H()

        if params.step_lr:
            scheduler.step()

    # write to tensor board
    writer.add_text('Best_Accuracy', str(max_acc), best_ep)
    writer.add_histogram('init_H', init_H)
    writer.add_histogram('max_H', max_H, best_ep)

    # save model
    PATH = '{}/{}.pt'.format(mkSavePath('model', params), model_name)
    torch.save(best_model, PATH)

    writer.close()
Exemplo n.º 10
0
def test_npg(args=get_args()):
    """Train (unless ``args.watch``) and evaluate an NPG agent on a MuJoCo task.

    Builds Gaussian actor / critic MLPs and wraps them in ``NPGPolicy``.
    Optionally resumes from ``args.resume_path`` and trains with the on-policy
    trainer, saving the best policy (plus observation-normalization stats) to
    ``<log_path>/policy.pth``. It finally runs ``args.test_num`` evaluation
    episodes and prints the mean reward and episode length.

    Note: the default ``args`` is evaluated once, when the function is defined.

    Raises:
        ValueError: if ``args.logger`` is neither "tensorboard" nor "wandb".
    """
    # Fail fast: an unsupported value previously left `logger` unbound and
    # crashed only after environments and networks had been built.
    if args.logger not in ("tensorboard", "wandb"):
        raise ValueError(
            f"Unsupported logger {args.logger!r}; expected 'tensorboard' or 'wandb'")

    env, train_envs, test_envs = make_mujoco_env(args.task,
                                                 args.seed,
                                                 args.training_num,
                                                 args.test_num,
                                                 obs_norm=True)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low),
          np.max(env.action_space.high))
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # model
    net_a = Net(
        args.state_shape,
        hidden_sizes=args.hidden_sizes,
        activation=nn.Tanh,
        device=args.device,
    )
    actor = ActorProb(
        net_a,
        args.action_shape,
        max_action=args.max_action,
        unbounded=True,
        device=args.device,
    ).to(args.device)
    net_c = Net(
        args.state_shape,
        hidden_sizes=args.hidden_sizes,
        activation=nn.Tanh,
        device=args.device,
    )
    critic = Critic(net_c, device=args.device).to(args.device)
    torch.nn.init.constant_(actor.sigma_param, -0.5)
    for m in list(actor.modules()) + list(critic.modules()):
        if isinstance(m, torch.nn.Linear):
            # orthogonal initialization
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)
    # do last policy layer scaling, this will make initial actions have (close to)
    # 0 mean and std, and will help boost performances,
    # see https://arxiv.org/abs/2006.05990, Fig.24 for details
    for m in actor.mu.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.zeros_(m.bias)
            m.weight.data.copy_(0.01 * m.weight.data)

    # Only the critic is optimized by Adam; the actor is updated by NPGPolicy
    # using args.actor_step_size.
    optim = torch.optim.Adam(critic.parameters(), lr=args.lr)
    lr_scheduler = None
    if args.lr_decay:
        # decay learning rate to 0 linearly
        max_update_num = np.ceil(
            args.step_per_epoch / args.step_per_collect) * args.epoch

        lr_scheduler = LambdaLR(
            optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

    def dist(*logits):
        # Diagonal Gaussian over the action dimensions.
        return Independent(Normal(*logits), 1)

    policy = NPGPolicy(
        actor,
        critic,
        optim,
        dist,
        discount_factor=args.gamma,
        gae_lambda=args.gae_lambda,
        reward_normalization=args.rew_norm,
        action_scaling=True,
        action_bound_method=args.bound_action_method,
        lr_scheduler=lr_scheduler,
        action_space=env.action_space,
        advantage_normalization=args.norm_adv,
        optim_critic_iters=args.optim_critic_iters,
        actor_step_size=args.actor_step_size,
    )

    # load a previous policy
    if args.resume_path:
        ckpt = torch.load(args.resume_path, map_location=args.device)
        policy.load_state_dict(ckpt["model"])
        train_envs.set_obs_rms(ckpt["obs_rms"])
        test_envs.set_obs_rms(ckpt["obs_rms"])
        print("Loaded agent from: ", args.resume_path)

    # collector
    if args.training_num > 1:
        buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    else:
        buffer = ReplayBuffer(args.buffer_size)
    train_collector = Collector(policy,
                                train_envs,
                                buffer,
                                exploration_noise=True)
    test_collector = Collector(policy, test_envs)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    args.algo_name = "npg"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # logger
    if args.logger == "wandb":
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=args.resume_id,
            config=args,
            project=args.wandb_project,
        )
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    else:  # wandb (validated at the top of the function)
        logger.load(writer)

    def save_best_fn(policy):
        # Persist observation-normalization stats with the weights so a resumed
        # or evaluated policy sees identically normalized inputs.
        state = {
            "model": policy.state_dict(),
            "obs_rms": train_envs.get_obs_rms()
        }
        torch.save(state, os.path.join(log_path, "policy.pth"))

    if not args.watch:
        # trainer
        result = onpolicy_trainer(
            policy,
            train_collector,
            test_collector,
            args.epoch,
            args.step_per_epoch,
            args.repeat_per_collect,
            args.test_num,
            args.batch_size,
            step_per_collect=args.step_per_collect,
            save_best_fn=save_best_fn,
            logger=logger,
            test_in_train=False,
        )
        pprint.pprint(result)

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num,
                                    render=args.render)
    print(
        f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}'
    )
Exemplo n.º 11
0
class PytorchTrainer:
    """Train a binary 3D-ResNet classifier on MRI volumes for one CV fold.

    Labels are binarized as ``y != 0``. The network is evaluated on the test set
    every ``test_evaluation_period`` training batches. The state dict with the
    lowest average test loss is saved to ``model_save_path``. Metrics and text
    logs go to TensorBoard under ``logdir``.
    """

    def __init__(self, args):
        # Paths
        self.args = args
        self.logdir = os.path.join(args.base, args.logdir)
        self.fold = args.fold
        if not os.path.exists(self.logdir):
            os.makedirs(self.logdir)
        self.model_save_path = os.path.join(args.base, args.model_path)

        self.train_data_path = os.path.join(args.base, args.train_datapath)
        self.test_data_path = os.path.join(args.base, args.test_datapath)

        # Flags selecting which MRI sequences are stacked as input channels.
        self.input_features = [args.axial_t2, args.coronal_t2, args.axial_pc]
        self.train_dataset = MRIDataset(self.train_data_path, True, args.record_shape, args.feature_shape, self.input_features)
        self.test_dataset = MRIDataset(self.test_data_path, False, args.record_shape, args.feature_shape, self.input_features, preprocess=True)

        # NOTE(review): SummaryWriter's 2nd positional argument is `comment`, which
        # is ignored when log_dir is given, so every fold logs into the same
        # directory. Confirm whether a per-fold sub-directory was intended.
        self.summary = SummaryWriter(self.logdir, f'fold{self.fold}')

        self.write_log(f'Fold: {self.fold}', 0)

        if USE_GPU and torch.cuda.is_available():
            print("USING GPU")
            self.device = torch.device('cuda')
        else:
            print("USING CPU")
            self.device = torch.device('cpu')

        # Data processing
        self.record_shape = args.record_shape

        # General parameters
        self.test_evaluation_period = 1
        self.num_batches = int(args.num_batches)

        # Network parameters
        self.attention = args.attention
        self.feature_shape = args.feature_shape
        self.batch_size = args.batch_size
        self.test_size = min(self.batch_size, len(self.test_dataset))

        # Hyperparameters
        self.dropout_train_prob = 0.5
        starter_learning_rate = 5e-6
        self.learning_rate = starter_learning_rate
        self.train_loader = DataLoader(self.train_dataset, self.batch_size, True, num_workers=4, persistent_workers=True)
        self.test_loader = DataLoader(self.test_dataset, self.test_size, False, num_workers=4, persistent_workers=True)

        # Best Test Results
        self.best = {'iteration': None,
                     'report': None,
                     'preds': None,
                     'labels': None,
                     'MaRIAs': None,
                     'loss': float("inf")}

    def write_log(self, line, train_step):
        """Write a text line to the 'Log' TensorBoard tag and flush."""
        self.summary.add_text('Log', line, train_step)
        self.summary.flush()

    def log_statistics(self, tag, loss, acc, f1, train_step):
        """Log loss, accuracy and F1 under '<metric>/<tag>' and flush."""
        self.summary.add_scalar('Loss/' + tag, loss, train_step)
        self.summary.add_scalar('Accuracy/' + tag, acc, train_step)
        self.summary.add_scalar('F1 Score/' + tag, f1, train_step)
        self.summary.flush()

    def evaluate_on_test(self, network, train_step):
        """Evaluate on the full test set; save the model if the test loss improves."""

        all_binary_labels, all_preds, all_losses, all_y = [], [], [], []

        network.eval()
        for (x, y) in self.test_loader:

            x = x.to(device=self.device)
            binary_y = torch.where(y == 0, 0, 1).to(device=self.device)

            with torch.no_grad():
                out = network(x)
            preds = out.argmax(dim=1).float()

            loss = F.cross_entropy(out, binary_y)

            all_binary_labels.append(binary_y)
            all_preds.append(preds)
            # Repeat the batch-mean loss once per sample so the overall mean is
            # weighted by batch size.
            all_losses += [loss] * len(y)
            all_y.append(y)

        all_binary_labels = torch.cat(all_binary_labels)
        all_preds = torch.cat(all_preds)
        all_losses = torch.stack(all_losses)
        all_y = torch.cat(all_y)

        # Convert back to cpu so can be converted to numpy for statistics
        all_preds = all_preds.cpu()
        all_binary_labels = all_binary_labels.cpu()

        test_avg_acc = (all_preds == all_binary_labels).float().mean()
        test_avg_loss = all_losses.mean()
        test_f1 = f1_score(all_binary_labels, all_preds, zero_division=0, average='weighted')
        test_report = report(all_binary_labels, all_preds)

        if test_avg_loss < self.best['loss']:

            self.best['iteration'] = train_step
            self.best['loss'] = test_avg_loss
            self.best['preds'] = all_preds
            self.best['labels'] = all_binary_labels
            self.best['MaRIAs'] = all_y
            self.best['report'] = test_report

            torch.save(network.state_dict(), self.model_save_path)
            print()
            print('===========================> Model saved!')
            print()

        print('Test statistics')
        print('Average Loss:       ', test_avg_loss)
        print('Prediction balance: ', all_preds.mean())
        print(test_report)
        print()

        self.log_statistics('test', test_avg_loss, test_avg_acc, test_f1, train_step)
        self.summary.flush()

    def train(self):
        """Run training steps 0..num_batches (inclusive), then log the best test result."""

        train_step = 0

        input_channels = sum(self.input_features)
        print('Input channels: ', input_channels)
        network = PytorchResNet3D(self.feature_shape, self.attention, self.dropout_train_prob, in_chan=input_channels)

        if torch.cuda.device_count() > 1:
            print("Using ", torch.cuda.device_count(), " GPUs")
            network = torch.nn.DataParallel(network)

        network = network.to(device=self.device)
        optimiser = Adam(network.parameters(), lr=self.learning_rate)

        train_accuracies = []
        while train_step <= self.num_batches:

            for (x, y) in self.train_loader:
                # The while-condition is only checked between epochs; stop here
                # as soon as the batch budget is spent instead of overshooting
                # by up to a full epoch.
                if train_step > self.num_batches:
                    break

                network.train()

                x = x.to(device=self.device)
                binary_y = torch.where(y == 0, 0, 1).to(device=self.device)

                out = network(x)
                preds = out.argmax(dim=1).float()

                loss = F.cross_entropy(out, binary_y)

                optimiser.zero_grad()
                loss.backward()
                optimiser.step()

                # Convert back to cpu so can be converted to numpy for statistics
                preds = preds.cpu()
                binary_y = binary_y.cpu()

                # Summaries and statistics
                print(f'-- Train Batch {train_step} --')
                print('Loss:               ', loss)
                print('Prediction balance: ', preds.mean())
                print(report(binary_y, preds))
                print()

                train_accuracies.append((preds == binary_y).float().mean())
                # Mean accuracy over the batches since the last test evaluation.
                running_accuracy = torch.mean(torch.stack(train_accuracies[-self.test_evaluation_period:]))
                train_f1 = f1_score(binary_y, preds, zero_division=0, average='weighted')

                self.log_statistics('train', loss.item(), running_accuracy, train_f1, train_step)

                if train_step % self.test_evaluation_period == 0:
                    self.evaluate_on_test(network, train_step)

                train_step += 1

        print('Training finished!')
        print(self.best["report"])

        self.write_log(f'Best loss (iteration {self.best["iteration"]}): {self.best["loss"]}', train_step)
        self.write_log(f'with predictions: {self.best["preds"]}', train_step)
        self.write_log(f'of labels:        {self.best["labels"]}', train_step)
        self.write_log(f'with MaRIA scores:{self.best["MaRIAs"]}', train_step)
        self.write_log(self.best["report"], train_step)

        self.summary.close()
Exemplo n.º 12
0
class Logger():
    """Accumulate running means of metrics and optionally mirror them to TensorBoard.

    Keys have the form ``'<tag>/<metric>'``. A TensorBoard ``SummaryWriter`` is
    only open between ``safe(True)`` and ``safe(False)``. Without a writer,
    ``write`` still prints a summary line.
    """

    def __init__(self, log_path):
        self.log_path = log_path
        self.writer = None                 # SummaryWriter while enabled via safe(True)
        self.tracker = defaultdict(int)    # last raw value appended per key
        self.counter = defaultdict(int)    # accumulated sample count per key
        self.mean = defaultdict(int)       # running sample-weighted mean per key
        self.history = defaultdict(list)   # means snapshotted on each safe(False)
        self.iterator = defaultdict(int)   # per-key TensorBoard step counter

    def safe(self, write):
        """Open a writer (``write=True``), or close it and snapshot the current means."""
        if write:
            # Close any writer that is still open so it is not leaked.
            if self.writer is not None:
                self.writer.close()
            self.writer = SummaryWriter(self.log_path)
        else:
            if self.writer is not None:
                self.writer.close()
                self.writer = None
            for name in self.mean:
                self.history[name].append(self.mean[name])
        return

    def reset(self):
        """Clear tracked values, counters and means (history and step counters persist)."""
        self.tracker = defaultdict(int)
        self.counter = defaultdict(int)
        self.mean = defaultdict(int)
        return

    def append(self, result, tag, n=1, mean=True):
        """Record ``result`` (metric -> value) under ``tag``, weighting each value by ``n``.

        Scalars and sequences (element-wise) are folded into the running mean
        when ``mean`` is True. Any other type raises ValueError.
        """
        for k in result:
            name = '{}/{}'.format(tag, k)
            self.tracker[name] = result[k]
            self.counter[name] += n
            if mean:
                if isinstance(result[k], Number):
                    self.mean[name] = (
                        (self.counter[name] - n) * self.mean[name] +
                        n * result[k]) / self.counter[name]
                elif isinstance(result[k], Iterable):
                    if name not in self.mean:
                        self.mean[name] = [0 for _ in range(len(result[k]))]
                    for i in range(len(result[k])):
                        self.mean[name][i] = ((self.counter[name] - n) * self.mean[name][i] + n * result[k][i]) \
                                             / self.counter[name]
                else:
                    raise ValueError('Not valid data type')
        return

    def write(self, tag, metric_names):
        """Print ``<tag>/info`` with the metric means spliced in at index 2, and log to TensorBoard.

        For sequence-valued metrics only the first element is sent to TensorBoard.
        """
        evaluation_info = []
        # Take the metric name directly instead of re-splitting the key, which
        # rebound `tag` and failed for tags/metrics containing '/'.
        for k in metric_names:
            name = '{}/{}'.format(tag, k)
            value = self.mean[name]
            if isinstance(value, Number):
                evaluation_info.append('{}: {:.4f}'.format(k, value))
                if self.writer is not None:
                    self.iterator[name] += 1
                    self.writer.add_scalar(name, value, self.iterator[name])
            elif isinstance(value, Iterable):
                s = tuple(value)
                evaluation_info.append('{}: {}'.format(k, s))
                if self.writer is not None:
                    self.iterator[name] += 1
                    self.writer.add_scalar(name, s[0], self.iterator[name])
            else:
                raise ValueError('Not valid data type')
        info_name = '{}/info'.format(tag)
        # Copy so the stored list is not mutated; otherwise each write() call
        # would splice the metrics into it again.
        info = list(self.tracker[info_name])
        info[2:2] = evaluation_info
        info = '  '.join(info)
        print(info)
        if self.writer is not None:
            self.iterator[info_name] += 1
            self.writer.add_text(info_name, info, self.iterator[info_name])
        return

    def flush(self):
        """Flush the writer if one is open (no-op otherwise)."""
        if self.writer is not None:
            self.writer.flush()
        return
Exemplo n.º 13
0
class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch,
    optimized for Transformers.

    The class-level annotations below describe the instance attributes set in
    ``__init__``. ``compute_metrics`` and ``tb_writer`` also default to ``None``
    at class level.
    """

    model: PreTrainedModel
    args: TrainingArguments
    data_collator: DataCollator
    train_dataset: Optional[Dataset]
    eval_dataset: Optional[Dataset]
    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
    prediction_loss_only: bool
    tb_writer: Optional["SummaryWriter"] = None

    def __init__(
        self,
        model: PreTrainedModel,
        args: TrainingArguments,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        prediction_loss_only=False,
        tb_writer: Optional["SummaryWriter"] = None,
    ):
        """
        Trainer is a simple but feature-complete training and eval loop for PyTorch,
        optimized for Transformers.

        Args:
            model:
                The model to train / evaluate. ``save_model`` requires a
                ``PreTrainedModel``.
            args:
                Training arguments (batch sizes, optimizer settings, output and
                logging directories, distributed settings, ...).
            data_collator:
                (Optional) object whose ``collate_batch`` builds batches.
                Defaults to ``DefaultDataCollator()``.
            train_dataset:
                (Optional) dataset used by ``train()``.
            eval_dataset:
                (Optional) default dataset used by ``evaluate()``.
            compute_metrics:
                (Optional) callable mapping an ``EvalPrediction`` to a metrics dict.
            prediction_loss_only:
                (Optional) in evaluation and prediction, only return the loss.
            tb_writer:
                (Optional) TensorBoard writer. When omitted and TensorBoard is
                available, one is created in ``args.logging_dir`` on the main
                process (local_rank -1 or 0).
        """
        self.model = model
        self.args = args
        if data_collator is not None:
            self.data_collator = data_collator
        else:
            self.data_collator = DefaultDataCollator()
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.compute_metrics = compute_metrics
        self.prediction_loss_only = prediction_loss_only
        if tb_writer is not None:
            self.tb_writer = tb_writer
        elif is_tensorboard_available() and self.args.local_rank in [-1, 0]:
            self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
        if not is_tensorboard_available():
            logger.warning(
                "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
            )
        if not is_wandb_available():
            logger.info(
                "You are instantiating a Trainer but wandb is not installed. Install it to use Weights & Biases logging."
            )
        set_seed(self.args.seed)
        # Create output directory if needed (only once, on the main process).
        if self.args.local_rank in [-1, 0]:
            os.makedirs(self.args.output_dir, exist_ok=True)

    def get_train_dataloader(self) -> DataLoader:
        """Return a shuffled (or distributed-sharded) training DataLoader.

        Raises:
            ValueError: if no ``train_dataset`` was provided.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_sampler = (RandomSampler(self.train_dataset)
                         if self.args.local_rank == -1 else DistributedSampler(
                             self.train_dataset))
        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator.collate_batch,
        )

    def get_eval_dataloader(self,
                            eval_dataset: Optional[Dataset] = None
                            ) -> DataLoader:
        """Return a sequential DataLoader over ``eval_dataset``.

        Falls back to ``self.eval_dataset`` when ``eval_dataset`` is None.

        Raises:
            ValueError: if neither dataset is available.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        return DataLoader(
            eval_dataset if eval_dataset is not None else self.eval_dataset,
            batch_size=self.args.eval_batch_size,
            shuffle=False,
            collate_fn=self.data_collator.collate_batch,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """Return a sequential DataLoader over ``test_dataset``."""
        # We use the same batch_size as for eval.
        return DataLoader(
            test_dataset,
            batch_size=self.args.eval_batch_size,
            shuffle=False,
            collate_fn=self.data_collator.collate_batch,
        )

    def get_optimizers(
        self, num_training_steps: int
    ) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
        """Build AdamW plus a linear warmup/decay schedule over ``num_training_steps``.

        Weight decay is disabled for parameters whose name contains ``bias`` or
        ``LayerNorm.weight``.
        """
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                self.args.weight_decay,
            },
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=self.args.learning_rate,
                          eps=self.args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.args.warmup_steps,
            num_training_steps=num_training_steps)
        return optimizer, scheduler

    def _setup_wandb(self):
        """
        Setup the optional Weights & Biases (`wandb`) integration.

        One can override this method to customize the setup if needed.
        """
        wandb.init(name=self.args.logging_dir, config=vars(self.args))
        # keep track of model topology and gradients
        wandb.watch(self.model)

    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Args:
            model_path:
                (Optional) Local path to model if model to train has been instantiated from a local path
                If present, we will try reloading the optimizer/scheduler states from there.
                The global step to resume from is parsed from a trailing ``-<step>``
                in the path (as in saved checkpoint directory names).

        Returns:
            ``TrainOutput(global_step, average_loss)``, where ``average_loss`` is
            the mean training loss per optimizer step performed in this call
            (0.0 if no step was performed).
        """
        train_dataloader = self.get_train_dataloader()

        # Optimizer updates per epoch. When the dataloader is shorter than the
        # accumulation window, the loop below still performs one update at the
        # end of each epoch, so never let this be 0 (it is also used as a divisor).
        num_update_steps_per_epoch = max(
            1,
            len(train_dataloader) // self.args.gradient_accumulation_steps)

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (self.args.max_steps //
                                num_update_steps_per_epoch + 1)
        else:
            t_total = int(num_update_steps_per_epoch *
                          self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (model_path is not None
                and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
                and os.path.isfile(os.path.join(model_path, "scheduler.pt"))):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt")))
            scheduler.load_state_dict(
                torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        model.to(self.args.device)
        if self.args.fp16:
            if not is_apex_available():
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(),
                                       metric_dict={})
        if is_wandb_available():
            self._setup_wandb()

        # Train!
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataloader.dataset))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per GPU = %d",
                    self.args.per_gpu_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            self.args.train_batch_size *
            self.args.gradient_accumulation_steps *
            (torch.distributed.get_world_size()
             if self.args.local_rank != -1 else 1),
        )
        logger.info("  Gradient Accumulation steps = %d",
                    self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = global_step // num_update_steps_per_epoch
                # The skip loop below counts dataloader batches, and every
                # optimizer update consumed `gradient_accumulation_steps` of them.
                steps_trained_in_current_epoch = (
                    (global_step % num_update_steps_per_epoch) *
                    self.args.gradient_accumulation_steps)

                logger.info(
                    "  Continuing training from checkpoint, will skip to saved global_step"
                )
                logger.info("  Continuing training from epoch %d",
                            epochs_trained)
                logger.info("  Continuing training from global step %d",
                            global_step)
                logger.info(
                    "  Will skip the first %d batches in the first epoch",
                    steps_trained_in_current_epoch)
            except ValueError:
                global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        logging_loss = 0.0
        # Steps performed in this call are counted from here, so that the
        # logged and returned averages cover only losses accumulated in `tr_loss`.
        start_global_step = global_step
        last_logged_step = global_step
        model.zero_grad()
        train_iterator = trange(
            epochs_trained,
            int(num_train_epochs),
            desc="Epoch",
            disable=self.args.local_rank not in [-1, 0],
        )
        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader,
                                  desc="Iteration",
                                  disable=self.args.local_rank not in [-1, 0])
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained batches if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                tr_loss += self._training_step(model, inputs, optimizer)

                # The epoch is shorter than one accumulation window: still update
                # on its last batch.
                is_last_step_of_short_epoch = (
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator))
                if ((step + 1) % self.args.gradient_accumulation_steps == 0
                        or is_last_step_of_short_epoch):
                    if self.args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer),
                            self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       self.args.max_grad_norm)

                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()
                    global_step += 1

                    if self.args.local_rank in [-1, 0]:
                        if (self.args.logging_steps > 0
                                and global_step % self.args.logging_steps
                                == 0) or (global_step == 1
                                          and self.args.logging_first_step):
                            logs = {}
                            if self.args.evaluate_during_training:
                                results = self.evaluate()
                                for key, value in results.items():
                                    eval_key = "eval_{}".format(key)
                                    logs[eval_key] = value

                            # Average over the optimizer steps actually taken since
                            # the last log (>= 1), not over `logging_steps`. Dividing
                            # by `logging_steps` mis-scales the first-step log and
                            # divides by zero when logging_steps == 0.
                            loss_scalar = (tr_loss - logging_loss) / (
                                global_step - last_logged_step)
                            learning_rate_scalar = scheduler.get_last_lr()[0]
                            logs["learning_rate"] = learning_rate_scalar
                            logs["loss"] = loss_scalar
                            logging_loss = tr_loss
                            last_logged_step = global_step

                            if self.tb_writer:
                                for k, v in logs.items():
                                    self.tb_writer.add_scalar(
                                        k, v, global_step)
                            if is_wandb_available():
                                wandb.log(logs, step=global_step)

                            epoch_iterator.write(
                                json.dumps({
                                    **logs,
                                    **{
                                        "step": global_step
                                    }
                                }))

                        if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
                            # In all cases (even distributed/parallel), self.model is always a reference
                            # to the model we want to save.
                            if hasattr(model, "module"):
                                assert model.module is self.model
                            else:
                                assert model is self.model
                            # Save model checkpoint
                            output_dir = os.path.join(
                                self.args.output_dir,
                                f"{PREFIX_CHECKPOINT_DIR}-{global_step}")
                            self.save_model(output_dir)
                            self._rotate_checkpoints()
                            torch.save(
                                optimizer.state_dict(),
                                os.path.join(output_dir, "optimizer.pt"))
                            torch.save(
                                scheduler.state_dict(),
                                os.path.join(output_dir, "scheduler.pt"))
                            logger.info(
                                "Saving optimizer and scheduler states to %s",
                                output_dir)

                # Stop once exactly `max_steps` optimizer updates have been done
                # (`>` would run one extra update past the schedule's end).
                if self.args.max_steps > 0 and global_step >= self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and global_step >= self.args.max_steps:
                train_iterator.close()
                break

        if self.tb_writer:
            self.tb_writer.close()

        logger.info(
            "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
        )
        steps_this_run = global_step - start_global_step
        # Guard against ZeroDivisionError, e.g. with an empty dataloader.
        average_loss = tr_loss / steps_this_run if steps_this_run > 0 else 0.0
        return TrainOutput(global_step, average_loss)

    def _training_step(self, model: nn.Module, inputs: Dict[str, torch.Tensor],
                       optimizer: torch.optim.Optimizer) -> float:
        """Run forward/backward on one batch and return its (scaled) loss.

        The loss is averaged across GPUs and divided by
        ``gradient_accumulation_steps`` before ``backward()``. The optimizer is
        only needed for apex loss scaling; it is not stepped here.
        """
        model.train()
        for k, v in inputs.items():
            inputs[k] = v.to(self.args.device)

        outputs = model(**inputs)
        loss = outputs[
            0]  # model outputs are always tuple in transformers (see doc)

        if self.args.n_gpu > 1:
            loss = loss.mean(
            )  # mean() to average on multi-gpu parallel training
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss.item()

    def is_world_master(self) -> bool:
        """
        This will be True only in one process, even in distributed mode,
        even when training on multiple machines.
        """
        return self.args.local_rank == -1 or torch.distributed.get_rank() == 0

    def save_model(self, output_dir: Optional[str] = None):
        """
        Saving best-practices: if you use default names for the model,
        you can reload it using from_pretrained().

        Will only save from the master process.
        """
        if self.is_world_master():
            self._save(output_dir)

    def _save(self, output_dir: Optional[str] = None):
        """Write the model and training args to ``output_dir`` (default: ``args.output_dir``).

        Raises:
            ValueError: if ``self.model`` is not a ``PreTrainedModel``.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            raise ValueError(
                "Trainer.model appears to not be a PreTrainedModel")
        self.model.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

    def _sorted_checkpoints(self,
                            checkpoint_prefix=PREFIX_CHECKPOINT_DIR,
                            use_mtime=False) -> List[str]:
        """Return checkpoint dirs in ``args.output_dir``, oldest first.

        Checkpoints are ordered by modification time when ``use_mtime`` is True.
        Otherwise they are ordered by the step number after ``<prefix>-``, and
        paths without a numeric suffix are ignored.
        """
        ordering_and_checkpoint_path = []

        glob_checkpoints = [
            str(x)
            for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")
        ]

        # Escape the prefix so regex metacharacters in it are matched literally.
        step_pattern = re.compile(f".*{re.escape(checkpoint_prefix)}-([0-9]+)")
        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append(
                    (os.path.getmtime(path), path))
            else:
                regex_match = step_pattern.match(path)
                if regex_match and regex_match.groups():
                    ordering_and_checkpoint_path.append(
                        (int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [
            checkpoint[1] for checkpoint in checkpoints_sorted
        ]
        return checkpoints_sorted

    def _rotate_checkpoints(self, use_mtime=False) -> None:
        """Delete the oldest checkpoints so at most ``args.save_total_limit`` remain.

        Does nothing when ``save_total_limit`` is None or <= 0.
        """
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        number_of_checkpoints_to_delete = max(
            0,
            len(checkpoints_sorted) - self.args.save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:
                                                       number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info(
                "Deleting older checkpoint [{}] due to args.save_total_limit".
                format(checkpoint))
            shutil.rmtree(checkpoint)

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        prediction_loss_only: Optional[bool] = None,
    ) -> Dict[str, float]:
        """
        Run evaluation and return metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are
        task-dependent.

        Args:
            eval_dataset: (Optional) Pass a dataset if you wish to override
            the one on the instance.
            prediction_loss_only: (Optional) override the instance's
            ``prediction_loss_only`` for this call.
        Returns:
            A dict containing:
                - the eval loss
                - the potential metrics computed from the predictions
        """
        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        # Forward the override; it was previously accepted but silently ignored.
        output = self._prediction_loop(
            eval_dataloader,
            description="Evaluation",
            prediction_loss_only=prediction_loss_only)
        return output.metrics

    def predict(self, test_dataset: Dataset) -> PredictionOutput:
        """
        Run prediction and return predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels.
        In that case, this method will also return metrics, like in evaluate().
        """
        test_dataloader = self.get_test_dataloader(test_dataset)
        return self._prediction_loop(test_dataloader, description="Prediction")

    def _prediction_loop(
            self,
            dataloader: DataLoader,
            description: str,
            prediction_loss_only: Optional[bool] = None) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.

        Works both with or without labels.

        Returns:
            ``PredictionOutput`` with the following fields:
            - ``predictions``: logits concatenated along axis 0, or None;
            - ``label_ids``: the ``labels`` inputs concatenated along axis 0, or None;
            - ``metrics``: the ``compute_metrics`` result plus the mean ``loss``
              when labels were present.
        """

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only

        # multi-gpu eval
        if self.args.n_gpu > 1 and not isinstance(self.model,
                                                  torch.nn.DataParallel):
            model = torch.nn.DataParallel(self.model)
        else:
            model = self.model
        model.to(self.args.device)

        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", len(dataloader.dataset))
        logger.info("  Batch size = %d", dataloader.batch_size)
        eval_losses: List[float] = []
        # Collect per-batch arrays and concatenate once at the end; repeated
        # np.append copies everything gathered so far on every batch (O(n^2)).
        preds_chunks: List[np.ndarray] = []
        label_chunks: List[np.ndarray] = []
        model.eval()

        for inputs in tqdm(dataloader, desc=description):
            has_labels = any(
                inputs.get(k) is not None
                for k in ["labels", "masked_lm_labels"])

            for k, v in inputs.items():
                inputs[k] = v.to(self.args.device)

            with torch.no_grad():
                outputs = model(**inputs)
                if has_labels:
                    step_eval_loss, logits = outputs[:2]
                    eval_losses += [step_eval_loss.mean().item()]
                else:
                    logits = outputs[0]

            if not prediction_loss_only:
                preds_chunks.append(logits.detach().cpu().numpy())
                if inputs.get("labels") is not None:
                    label_chunks.append(
                        inputs["labels"].detach().cpu().numpy())

        preds = (np.concatenate(preds_chunks, axis=0)
                 if preds_chunks else None)
        label_ids = (np.concatenate(label_chunks, axis=0)
                     if label_chunks else None)

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(
                EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            metrics["loss"] = np.mean(eval_losses)

        return PredictionOutput(predictions=preds,
                                label_ids=label_ids,
                                metrics=metrics)
Exemplo n.º 14
0
class Strategy(Logger):
    """
    A training Strategy describes the meaningful parts of a typical training loop.
    It also provides an optional logger.

    Subclasses implement ``optim_schedulers`` and ``tng_step``. The validation and
    test hooks are optional. Every ``nn.Module`` stored as an instance attribute
    is listed by ``modules`` and is therefore included by ``save``.

    Args:
        log_dir (str or Path): path to the logs directory
    """

    def __init__(self, log_dir):
        self.log_dir = log_dir
        self._logger = None
        # metric name -> values accumulated by `log` since the last flush
        self._log_metrics_cache = dict()
        # created lazily on first access of `optimizers` / `schedulers`
        self._optimizers = self._schedulers = None

    def _init_opt_sched(self):
        """Create and cache optimizers/schedulers via `optim_schedulers` if needed."""
        if self._optimizers is None or self._schedulers is None:
            self._optimizers, self._schedulers = self.opt_sched_unpack(self.optim_schedulers())

    @property
    def optimizers(self) -> List[Optimizer]:
        self._init_opt_sched()
        return self._optimizers

    @property
    def schedulers(self) -> List[_LRScheduler]:
        self._init_opt_sched()
        return self._schedulers

    @abstractmethod
    def optim_schedulers(self) -> OptimOrSched:
        """
        Creates the optimizers and schedulers

        Returns: [optimizer, ...], [scheduler, ...]
        """
        raise NotImplementedError

    @abstractmethod
    def tng_step(self, batch, batch_idx: int, optimizer_idx: int, epoch_idx: int, num_batches: int) -> dict:
        """
        Describe the training step. It should return a dict with at least the loss.

        Args:
            batch: data from a batch of the dataloader
            batch_idx: index of the batch
            optimizer_idx:
            epoch_idx:
            num_batches:

        Returns (dict): it must at least contain the loss: {
            'loss': tng_loss,
            'acc': tng_acc,
        }
        """
        raise NotImplementedError

    # @abstractmethod
    def val_step(self, batch, batch_idx: int, epoch_idx: int, num_batches: int) -> dict:
        """
        Describe the validation step. It should return a dict with at least the loss.
        The dicts will be aggregated over steps and provided as list to `val_agg_outputs`.
        Logging here might cause performance issue if a step is quickly processed.

        Args:
            batch:
            batch_idx:
            epoch_idx:
            num_batches:

        Returns (dict): for example: {
            'loss': val_loss,
            'acc': val_acc,
            'gt': y,
            'logits': y_hat,
        }
        """
        pass  # raise NotImplementedError

    # @abstractmethod
    def val_agg_outputs(self, outputs: List[dict], agg_fn: AggFn, epoch_idx: int) -> dict:
        """
        This is where you have the opportunity to aggregate the outputs of the validation steps
        and log any metrics you wish.

        Args:
            outputs: the dicts returned by `val_step`
            agg_fn:
            epoch_idx:

        Returns:

        """
        pass  # raise NotImplementedError

    # @abstractmethod
    def tst_step(self, batch, batch_idx: int, num_batches: int) -> dict:
        """
        Describe the testing step. It should return a dict with at least the loss.
        The dicts will be aggregated over steps and provided as list to `tst_agg_outputs`.

        Args:
            batch:
            batch_idx:
            num_batches:

        Returns (dict): {
            'loss': test_loss,
            'acc': test_acc,
            'gt': y,
            'logits': y_hat,
        }
        """
        pass  # raise NotImplementedError

    # @abstractmethod
    def tst_agg_outputs(self, outputs: List[dict], agg_fn: AggFn) -> dict:
        """
        This is where you have the opportunity to aggregate the outputs of the testing steps
        and log any metrics you wish.

        Args:
            outputs: the dicts returned by `tst_step`
            agg_fn:

        Returns:

        """
        pass  # raise NotImplementedError

    def add_graph(self) -> None:
        """
        [Optional] Log model(s) graph to tensorboard

        One can use `_add_graph` helper method
        """
        pass

    def load(self, path: Path):
        """
        Restore optimizers, schedulers and modules from a file written by `save`.

        Every key other than 'optimizers' / 'schedulers' names an attribute of
        `self` whose `load_state_dict` receives the stored state.
        """
        # NOTE(review): torch.load unpickles the file and can execute arbitrary
        # code; only load checkpoints from trusted sources.
        state_dicts = torch.load(path)

        for opt, state_dict in zip(self.optimizers, state_dicts['optimizers']):
            opt.load_state_dict(state_dict)

        for sched, state_dict in zip(self.schedulers, state_dicts['schedulers']):
            sched.load_state_dict(state_dict)

        for name in set(state_dicts.keys()) - {'optimizers', 'schedulers'}:
            module = getattr(self, name)
            module.load_state_dict(state_dicts[name])

    def save(self, path: Path):
        """Save the state dicts of all modules, optimizers and schedulers to `path`."""
        state_dicts = {name: module.state_dict() for name, module in self.modules}
        state_dicts.update({
            'optimizers': [opt.state_dict() for opt in self.optimizers],
            'schedulers': [sched.state_dict() for sched in self.schedulers],
        })
        torch.save(state_dicts, path)
        print(f'SAVED: {path}')

    @staticmethod
    def add_argz(parser: ArgumentParser) -> None:
        pass

    @property
    def modules(self) -> List[Tuple[str, nn.Module]]:
        """(attribute name, module) for every nn.Module stored on the instance."""
        return [(name, module) for name, module in self.__dict__.items() if isinstance(module, nn.Module)]

    @property
    def logger(self) -> SummaryWriter:
        """
        Provides a logger

        Creates the default TensorBoard writer on first access. For a
        TestTubeLogger, its `experiment` is returned instead.

        Returns: the logger
        """
        if self._logger is None:
            self.set_default_logger()
        else:
            try:
                if isinstance(self._logger, TestTubeLogger):
                    return self._logger.experiment
            except ImportError:
                pass
        return self._logger

    @logger.setter
    def logger(self, logger):
        self._logger = logger

    def set_default_logger(self, exp_name: str = '', version: int = None):
        """
        Create a SummaryWriter in `<log_dir>/<exp_name>/version_<n>/tf`.

        `self.log_dir` is updated to `<log_dir>/<exp_name>` as a Path.

        Args:
            exp_name: optional sub-directory appended to `self.log_dir`
            version: explicit version number; when None, the first
                `version_<n>` directory that does not exist yet is used

        Returns: the version number used
        """
        # `log_dir` may be a str (see class docstring); `str /= str` raises
        # TypeError, so convert to a Path first.
        self.log_dir = Path(self.log_dir) / exp_name
        if version is None:
            version = 0
            while (self.log_dir / f'version_{version}').exists():
                version += 1
        log_dir = self.log_dir / f'version_{version}'
        self.logger = SummaryWriter(log_dir / 'tf')
        return version

    def log_hyperparams(self, hparams):
        """Log `hparams` as a markdown table under the 'hparams' text tag."""
        params = f'''##### Hyperparameters\n'''
        row_header = '''parameter|value\n-|-\n'''

        mkdown_log = ''.join([
            params,
            row_header,
            *[f'''{k}|{v}\n''' for k, v in hparams.items()],
        ])
        self.logger.add_text(
            tag='hparams',
            text_string=mkdown_log,
        )

    def log(self, metrics_dict: dict, global_step: int, interval: int = 1) -> None:
        """
        Logs a dictionary of scalars, averaged over `interval` steps

        Values are cached on every call. When `global_step % interval == 0`, the
        mean of every cached metric is written and the cache is cleared.

        Args:
            metrics_dict: metric name -> scalar (anything with `.item()`, such as
                a 0-d tensor, or a plain Python number)
            global_step: step used as the x-axis value
            interval: flush every `interval` steps; values <= 1 flush on every call

        """
        for name, scalar in metrics_dict.items():
            value = scalar.item() if hasattr(scalar, 'item') else scalar
            self._log_metrics_cache.setdefault(name, []).append(value)
        # `interval <= 1` avoids a modulo by zero.
        if interval > 1 and global_step % interval != 0:
            return
        # Flush every cached metric, not only those present in this call;
        # otherwise values of metrics missing at the flush step are dropped.
        # Plain Python mean: `torch.tensor(ints).mean()` fails on integer values.
        averaged = {name: sum(values) / len(values)
                    for name, values in self._log_metrics_cache.items()}
        self._log_metrics_cache = dict()
        try:
            if isinstance(self._logger, TestTubeLogger):
                self._logger.log_metrics(averaged, step_num=global_step)
                return
        except ImportError:
            pass
        for k, v in averaged.items():
            self.logger.add_scalar(tag=k, scalar_value=v, global_step=global_step)

    @staticmethod
    def opt_sched_unpack(opt_sched):
        """
        Normalize the value returned by `optim_schedulers` into two lists.

        Accepts an `(optimizers, schedulers)` pair, in which each item may be a
        single object or a sequence. Anything that is not a 2-item pair is treated
        as optimizer(s) with no schedulers. Note that a 2-element list of optimizers
        is still read as a pair.
        """
        try:
            opt, sched = opt_sched
        except (TypeError, ValueError):
            # TypeError: a bare optimizer (not iterable);
            # ValueError: a sequence whose length is not 2.
            opt, sched = opt_sched, []
        if not isinstance(opt, Sequence):
            opt = [opt]
        if not isinstance(sched, Sequence):
            sched = [sched]
        return opt, sched
# Training data location. The commented-out S3 path is an alternative source;
# the active value reads a local copy.
#train_data_path = "s3a://tubi-playground-production/smistry/emb3/train-aug-28-phase1"
train_data_path = "data/train-aug-28-phase1"

# Send log records to logs/<model_alias>.log, truncating the previous run's
# file (filemode='w'). NOTE(review): basicConfig does not create the "logs/"
# directory; it presumably must already exist — confirm.
logging.basicConfig(filename="logs/" + model_alias + '.log',
                    filemode='w',
                    format='%(asctime)s - %(message)s',
                    level=logging.INFO)

# getLogger() with no name returns the root logger configured above. Its level
# is then set to DEBUG, overriding the INFO level passed to basicConfig, so
# debug records also reach the log file.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Load the training set into a DataFrame at import time.
original_train_data = pd.read_parquet(train_data_path)
logger.info("Data is loaded")
# One TensorBoard run directory per model alias; the alias is also recorded
# as text at step 0.
writer = SummaryWriter(
    log_dir='{}/{}'.format(tensorboard_base_dir, model_alias))
writer.add_text('alias', model_alias, 0)


def notify_loss_completion(epoch_id, batch_id, loss, net, model):
    """Callback: record the loss of one training batch.

    Writes ``loss`` to the module-level TensorBoard ``writer`` under
    ``Batch/loss`` (step = ``batch_id``) and logs a one-line summary.
    ``net`` and ``model`` are part of the callback signature but unused here.
    """
    writer.add_scalar("Batch/loss", loss, batch_id)
    summary = f'[Epoch {epoch_id}] Batch {batch_id}, Loss {loss}'
    logging.info(summary)


def notify_batch_eval_completion(epoch_id, batch_id, loss, net, model):
    """Callback: score the embeddings of ``net`` after a batch.

    Computes ``nn_pairs_ndcg_score(net)``, writes it to the module-level
    TensorBoard ``writer`` under ``Batch/pairs_ndcg`` (step = ``batch_id``)
    and logs it with four decimals. ``loss`` and ``model`` are unused.
    """
    ndcg = nn_pairs_ndcg_score(net)
    writer.add_scalar("Batch/pairs_ndcg", ndcg, batch_id)
    logging.info(f'[Epoch {epoch_id}] Batch {batch_id}, Embs NDCG = {ndcg:.4f}')
Exemplo n.º 16
0
def test_sac(args=None):
    """Train (unless ``args.watch``) and evaluate a SAC agent on a MuJoCo task.

    Builds the environments, the actor and twin critics with their optimizers,
    an optional auto-tuned entropy coefficient, the replay buffer and the
    collectors. It then runs ``offpolicy_trainer`` and finally rolls out
    ``args.test_num`` episodes, printing the mean reward and episode length.

    Args:
        args: parsed experiment arguments. ``None`` (the default) means "call
            ``get_args()`` now". The object is mutated: ``state_shape``,
            ``action_shape``, ``max_action``, ``algo_name`` and, when
            ``args.auto_alpha`` is set, ``alpha`` are assigned on it.

    Raises:
        ValueError: if ``args.logger`` is neither ``"tensorboard"`` nor ``"wandb"``.
    """
    # Parse arguments at call time. A ``get_args()`` default would run once at
    # import time, and the same namespace would be shared and mutated by
    # every call.
    if args is None:
        args = get_args()
    # Fail fast: any other value would leave ``logger`` unbound much later,
    # after the environments are built and warm-up data is collected.
    if args.logger not in ("tensorboard", "wandb"):
        raise ValueError(
            f"Unsupported logger {args.logger!r}; expected 'tensorboard' or 'wandb'.")

    env, train_envs, test_envs = make_mujoco_env(args.task,
                                                 args.seed,
                                                 args.training_num,
                                                 args.test_num,
                                                 obs_norm=False)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low),
          np.max(env.action_space.high))
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # model
    net_a = Net(args.state_shape,
                hidden_sizes=args.hidden_sizes,
                device=args.device)
    actor = ActorProb(
        net_a,
        args.action_shape,
        max_action=args.max_action,
        device=args.device,
        unbounded=True,
        conditioned_sigma=True,
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    # Two independent critics over concatenated (state, action) inputs.
    net_c1 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    net_c2 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic1 = Critic(net_c1, device=args.device).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = Critic(net_c2, device=args.device).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

    if args.auto_alpha:
        # Learn log(alpha) towards a target entropy of -|action dims|; the
        # policy receives the (target, parameter, optimizer) triple as alpha.
        target_entropy = -np.prod(env.action_space.shape)
        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
        args.alpha = (target_entropy, log_alpha, alpha_optim)

    policy = SACPolicy(
        actor,
        actor_optim,
        critic1,
        critic1_optim,
        critic2,
        critic2_optim,
        tau=args.tau,
        gamma=args.gamma,
        alpha=args.alpha,
        estimation_step=args.n_step,
        action_space=env.action_space,
    )

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(
            torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    if args.training_num > 1:
        buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    else:
        buffer = ReplayBuffer(args.buffer_size)
    train_collector = Collector(policy,
                                train_envs,
                                buffer,
                                exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    # Warm-up: fill the buffer with uniformly random actions.
    train_collector.collect(n_step=args.start_timesteps, random=True)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    args.algo_name = "sac"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # logger
    # NOTE(review): WandbLogger is created before the SummaryWriter; this is
    # presumably required so wandb can hook TensorBoard writes — keep the order.
    if args.logger == "wandb":
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=args.resume_id,
            config=args,
            project=args.wandb_project,
        )
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    else:  # wandb (validated at the top of the function)
        logger.load(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    if not args.watch:
        # trainer
        result = offpolicy_trainer(
            policy,
            train_collector,
            test_collector,
            args.epoch,
            args.step_per_epoch,
            args.step_per_collect,
            args.test_num,
            args.batch_size,
            save_best_fn=save_best_fn,
            logger=logger,
            update_per_step=args.update_per_step,
            test_in_train=False,
        )
        pprint.pprint(result)

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num,
                                    render=args.render)
    print(
        f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}'
    )
Exemplo n.º 17
0
class BaseTrainer:
    """Base class for distributed (DDP) speech-enhancement trainers.

    Wraps the model in ``DistributedDataParallel`` and sets up automatic mixed
    precision, STFT/iSTFT helpers and checkpoint resume/preload. It also
    handles TensorBoard logging (rank 0 only) and runs the epoch loop in
    :meth:`train`. Subclasses must implement :meth:`_train_epoch` and
    :meth:`_validation_epoch`.
    """

    def __init__(self, dist, rank, config, resume, only_validation, model,
                 loss_function, optimizer):
        """
        Args:
            dist: distributed handle (presumably ``torch.distributed``); only
                ``dist.barrier()`` is used in this class.
            rank: process rank; also used as the device id for DDP.
            config (dict): experiment configuration with "meta", "acoustics"
                and "trainer" sections.
            resume (bool): resume from "<checkpoints_dir>/latest_model.tar".
            only_validation (bool): skip training and only run validation.
            model: the network; it is moved to ``rank`` and wrapped in DDP.
            loss_function: stored for use by subclasses.
            optimizer: optimizer over ``model``'s parameters.
        """
        self.color_tool = colorful
        self.color_tool.use_style("solarized")

        model = DistributedDataParallel(model.to(rank), device_ids=[rank])
        self.model = model
        self.optimizer = optimizer
        self.loss_function = loss_function

        # DistributedDataParallel (DDP)
        self.rank = rank
        self.dist = dist

        # Automatic mixed precision (AMP)
        self.use_amp = config["meta"]["use_amp"]
        self.scaler = GradScaler(enabled=self.use_amp)

        # Acoustics
        self.acoustic_config = config["acoustics"]

        # Supported STFT
        n_fft = self.acoustic_config["n_fft"]
        hop_length = self.acoustic_config["hop_length"]
        win_length = self.acoustic_config["win_length"]

        self.torch_stft = partial(stft,
                                  n_fft=n_fft,
                                  hop_length=hop_length,
                                  win_length=win_length)
        self.torch_istft = partial(istft,
                                   n_fft=n_fft,
                                   hop_length=hop_length,
                                   win_length=win_length)
        self.librosa_stft = partial(librosa.stft,
                                    n_fft=n_fft,
                                    hop_length=hop_length,
                                    win_length=win_length)
        self.librosa_istft = partial(librosa.istft,
                                     hop_length=hop_length,
                                     win_length=win_length)

        # Trainer.train in the config
        self.train_config = config["trainer"]["train"]
        self.epochs = self.train_config["epochs"]
        self.save_checkpoint_interval = self.train_config[
            "save_checkpoint_interval"]
        self.clip_grad_norm_value = self.train_config["clip_grad_norm_value"]
        assert self.save_checkpoint_interval >= 1, "Check the 'save_checkpoint_interval' parameter in the config. It should be at least one."

        # Trainer.validation in the config
        self.validation_config = config["trainer"]["validation"]
        self.validation_interval = self.validation_config[
            "validation_interval"]
        self.save_max_metric_score = self.validation_config[
            "save_max_metric_score"]
        assert self.validation_interval >= 1, "Check the 'validation_interval' parameter in the config. It should be at least one."

        # Trainer.visualization in the config
        self.visualization_config = config["trainer"]["visualization"]

        # When 'resume' is True, _resume_checkpoint() overwrites the following:
        self.start_epoch = 1
        self.best_score = -np.inf if self.save_max_metric_score else np.inf
        self.save_dir = Path(config["meta"]["save_dir"]).expanduser().absolute(
        ) / config["meta"]["experiment_name"]
        self.checkpoints_dir = self.save_dir / "checkpoints"
        self.logs_dir = self.save_dir / "logs"

        if resume:
            self._resume_checkpoint()

        # Debug validation, which skips training
        self.only_validation = only_validation

        # The path lives under the "meta" section, where it is checked.
        preloaded_model_path = config["meta"]["preloaded_model_path"]
        if preloaded_model_path:
            self._preload_model(Path(preloaded_model_path))

        if self.rank == 0:
            prepare_empty_dir([self.checkpoints_dir, self.logs_dir],
                              resume=resume)

            self.writer = SummaryWriter(self.logs_dir.as_posix(),
                                        max_queue=5,
                                        flush_secs=30)
            self.writer.add_text(
                tag="Configuration",
                text_string=f"<pre>  \n{toml.dumps(config)}  \n</pre>",
                global_step=1)

            print(self.color_tool.cyan("The configurations are as follows: "))
            print(self.color_tool.cyan("=" * 40))
            print(self.color_tool.cyan(toml.dumps(config)[:-1]))  # except "\n"
            print(self.color_tool.cyan("=" * 40))

            # Keep a timestamped copy of the configuration next to the run.
            with open(
                (self.save_dir /
                 f"{time.strftime('%Y-%m-%d %H:%M:%S')}.toml").as_posix(),
                    "w") as handle:
                toml.dump(config, handle)

            self._print_networks([self.model])

    def _preload_model(self, model_path):
        """
        Preload model parameters (in "*.tar" format) at the start of experiment.

        The checkpoint's ``"model"`` entry is loaded non-strictly (missing or
        unexpected keys are ignored) into the underlying network.

        Args:
            model_path (Path): The file path of the *.tar file
        """
        model_path = model_path.expanduser().absolute()
        assert model_path.exists(
        ), f"The file {model_path.as_posix()} is not exist. please check path."

        model_checkpoint = torch.load(model_path.as_posix(),
                                      map_location="cpu")
        # Checkpoints are saved from the unwrapped module (see _save_checkpoint),
        # so their keys carry no "module." prefix. Loading them into the DDP
        # wrapper with strict=False would silently match no parameters.
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            self.model.module.load_state_dict(model_checkpoint["model"],
                                              strict=False)
        else:
            self.model.load_state_dict(model_checkpoint["model"], strict=False)
        self.model.to(self.rank)

        if self.rank == 0:
            print(
                f"Model preloaded successfully from {model_path.as_posix()}.")

    def _resume_checkpoint(self):
        """
        Resume the experiment from the latest checkpoint.

        Restores the start epoch, best score, optimizer, AMP scaler and model
        weights from "<checkpoints_dir>/latest_model.tar".
        """
        latest_model_path = self.checkpoints_dir.expanduser().absolute(
        ) / "latest_model.tar"
        assert latest_model_path.exists(
        ), f"{latest_model_path} does not exist, can not load latest checkpoint."

        # Make sure all processes (GPUs) do not start loading before the saving is finished.
        # see https://stackoverflow.com/questions/59760328/how-does-torch-distributed-barrier-work
        self.dist.barrier()

        # Load it on the CPU and later use .to(device) on the model.
        # May be slightly slower than map_location="cuda:<...>".
        # https://stackoverflow.com/questions/61642619/pytorch-distributed-data-parallel-confusion
        checkpoint = torch.load(latest_model_path.as_posix(),
                                map_location="cpu")

        self.start_epoch = checkpoint["epoch"] + 1
        self.best_score = checkpoint["best_score"]
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.scaler.load_state_dict(checkpoint["scaler"])

        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            self.model.module.load_state_dict(checkpoint["model"])
        else:
            self.model.load_state_dict(checkpoint["model"])

        if self.rank == 0:
            print(
                f"Model checkpoint loaded. Training will begin at {self.start_epoch} epoch."
            )

    def _save_checkpoint(self, epoch, is_best_epoch=False):
        """
        Save checkpoint to "<save_dir>/<config name>/checkpoints" directory, which consists of:
            - epoch
            - best metric score in historical epochs
            - optimizer parameters
            - AMP scaler state
            - model parameters

        Args:
            epoch (int): The epoch number stored in the checkpoint and used in file names.
            is_best_epoch (bool): In the current epoch, if the model get a best metric score (is_best_epoch=True),
                                the checkpoint of model will be saved as "<save_dir>/checkpoints/best_model.tar".
        """
        print(f"\t Saving {epoch} epoch model checkpoint...")

        state_dict = {
            "epoch": epoch,
            "best_score": self.best_score,
            "optimizer": self.optimizer.state_dict(),
            "scaler": self.scaler.state_dict()
        }

        # Always store the unwrapped module's weights (no "module." prefix).
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            state_dict["model"] = self.model.module.state_dict()
        else:
            state_dict["model"] = self.model.state_dict()

        # Saved in "latest_model.tar"
        # Contains all checkpoint information, including the optimizer parameters, the model parameters, etc.
        # New checkpoint will overwrite the older one.
        torch.save(state_dict,
                   (self.checkpoints_dir / "latest_model.tar").as_posix())

        # "model_{epoch_number}.pth"
        # Contains only model.
        torch.save(state_dict["model"],
                   (self.checkpoints_dir /
                    f"model_{str(epoch).zfill(4)}.pth").as_posix())

        # If the model get a best metric score (means "is_best_epoch=True") in the current epoch,
        # the model checkpoint will be saved as "best_model.tar"
        # The newer best-scored checkpoint will overwrite the older one.
        if is_best_epoch:
            print(
                self.color_tool.red(
                    f"\t Found a best score in the {epoch} epoch, saving..."))
            torch.save(state_dict,
                       (self.checkpoints_dir / "best_model.tar").as_posix())

    def _is_best_epoch(self, score, save_max_metric_score=True):
        """
        Check if the current model got the best metric score.

        Updates ``self.best_score`` when ``score`` ties or beats it (higher is
        better if ``save_max_metric_score``, lower otherwise).

        Returns:
            bool: True if ``score`` is the new best.
        """
        if save_max_metric_score and score >= self.best_score:
            self.best_score = score
            return True
        elif not save_max_metric_score and score <= self.best_score:
            self.best_score = score
            return True
        else:
            return False

    @staticmethod
    def _print_networks(models: list):
        """Print the parameter count (in millions) of each model and their total."""
        print(
            f"This project contains {len(models)} models, the number of the parameters is: "
        )

        params_of_all_networks = 0
        for idx, model in enumerate(models, start=1):
            params_of_network = 0
            for param in model.parameters():
                params_of_network += param.numel()

            print(f"\tNetwork {idx}: {params_of_network / 1e6} million.")
            params_of_all_networks += params_of_network

        print(
            f"The amount of parameters in the project is {params_of_all_networks / 1e6} million."
        )

    def _set_models_to_train_mode(self):
        self.model.train()

    def _set_models_to_eval_mode(self):
        self.model.eval()

    def spec_audio_visualization(self,
                                 noisy,
                                 enhanced,
                                 clean,
                                 name,
                                 epoch,
                                 mark=""):
        """
        Log noisy/enhanced/clean audio and their magnitude spectrograms to TensorBoard.

        Audio is logged at a 16 kHz sample rate. Spectrograms are computed
        with n_fft=320, hop_length=160, win_length=320, which override the
        acoustic-config STFT settings.
        """
        self.writer.add_audio(f"{mark}_Speech/{name}_Noisy",
                              noisy,
                              epoch,
                              sample_rate=16000)
        self.writer.add_audio(f"{mark}_Speech/{name}_Enhanced",
                              enhanced,
                              epoch,
                              sample_rate=16000)
        self.writer.add_audio(f"{mark}_Speech/{name}_Clean",
                              clean,
                              epoch,
                              sample_rate=16000)

        # Visualize the spectrogram of noisy speech, clean speech, and enhanced speech
        noisy_mag, _ = librosa.magphase(
            self.librosa_stft(noisy, n_fft=320, hop_length=160,
                              win_length=320))
        enhanced_mag, _ = librosa.magphase(
            self.librosa_stft(enhanced,
                              n_fft=320,
                              hop_length=160,
                              win_length=320))
        clean_mag, _ = librosa.magphase(
            self.librosa_stft(clean, n_fft=320, hop_length=160,
                              win_length=320))
        fig, axes = plt.subplots(3, 1, figsize=(6, 6))
        for k, mag in enumerate([noisy_mag, enhanced_mag, clean_mag]):
            axes[k].set_title(f"mean: {np.mean(mag):.3f}, "
                              f"std: {np.std(mag):.3f}, "
                              f"max: {np.max(mag):.3f}, "
                              f"min: {np.min(mag):.3f}")
            librosa.display.specshow(librosa.amplitude_to_db(mag),
                                     cmap="magma",
                                     y_axis="linear",
                                     ax=axes[k],
                                     sr=16000)
        plt.tight_layout()
        self.writer.add_figure(f"{mark}_Spectrogram/{name}", fig, epoch)

    def metrics_visualization(self,
                              noisy_list,
                              clean_list,
                              enhanced_list,
                              metrics_list,
                              epoch,
                              num_workers=10,
                              mark=""):
        """
        Get metrics on validation dataset by paralleling.

        For each metric, the mean score on noisy and on enhanced signals (both
        against clean references) is logged to TensorBoard.

        Notes:
            1. You can register other metrics, but STOI and WB_PESQ metrics must be existence. These two metrics are
             used for checking if the current epoch is a "best epoch."
            2. If you want to use a new metric, you must register it in "util.metrics" file.

        Returns:
            The average of the enhanced-signal mean STOI and the range-transformed mean WB_PESQ.
        """
        assert "STOI" in metrics_list and "WB_PESQ" in metrics_list, "'STOI' and 'WB_PESQ' must be existence."

        # Check if the metric is registered in "util.metrics" file.
        for i in metrics_list:
            assert i in metrics.REGISTERED_METRICS.keys(
            ), f"{i} is not registered, please check 'util.metrics' file."

        stoi_mean = 0.0
        wb_pesq_mean = 0.0
        for metric_name in metrics_list:
            score_on_noisy = Parallel(n_jobs=num_workers)(
                delayed(metrics.REGISTERED_METRICS[metric_name])(ref, est)
                for ref, est in zip(clean_list, noisy_list))
            score_on_enhanced = Parallel(n_jobs=num_workers)(
                delayed(metrics.REGISTERED_METRICS[metric_name])(ref, est)
                for ref, est in zip(clean_list, enhanced_list))

            # Add the mean value of the metric to tensorboard
            mean_score_on_noisy = np.mean(score_on_noisy)
            mean_score_on_enhanced = np.mean(score_on_enhanced)
            self.writer.add_scalars(f"{mark}_Validation/{metric_name}", {
                "Noisy": mean_score_on_noisy,
                "Enhanced": mean_score_on_enhanced
            }, epoch)

            if metric_name == "STOI":
                stoi_mean = mean_score_on_enhanced

            if metric_name == "WB_PESQ":
                wb_pesq_mean = transform_pesq_range(mean_score_on_enhanced)

        return (stoi_mean + wb_pesq_mean) / 2

    def train(self):
        """Run epochs ``self.start_epoch`` .. ``self.epochs`` (inclusive).

        In ``only_validation`` mode every rank skips training. Only rank 0
        validates and saves a best-model checkpoint. Otherwise each epoch
        trains on all ranks. Rank 0 then saves a checkpoint every
        ``save_checkpoint_interval`` epochs and validates every
        ``validation_interval`` epochs.
        """
        for epoch in range(self.start_epoch, self.epochs + 1):
            if self.rank == 0:
                print(
                    self.color_tool.yellow(
                        f"{'=' * 15} {epoch} epoch {'=' * 15}"))
                print("[0 seconds] Begin training...")

            # [debug validation] Only run validation (only use the first GPU (process))
            # inference + calculating metrics + saving checkpoints.
            # All ranks skip training here. Otherwise non-zero ranks would start
            # a DDP training epoch that rank 0 never joins (presumably blocking
            # on gradient synchronisation).
            if self.only_validation:
                if self.rank == 0:
                    self._set_models_to_eval_mode()
                    metric_score = self._validation_epoch(epoch)

                    if self._is_best_epoch(
                            metric_score,
                            save_max_metric_score=self.save_max_metric_score):
                        self._save_checkpoint(epoch, is_best_epoch=True)

                # Skip the following regular training, saving checkpoints, and validation
                continue

            # Regular training
            timer = ExecutionTime()
            self._set_models_to_train_mode()
            self._train_epoch(epoch)

            #  Regular save checkpoints
            if self.rank == 0 and self.save_checkpoint_interval != 0 and (
                    epoch % self.save_checkpoint_interval == 0):
                self._save_checkpoint(epoch)

            # Regular validation
            if self.rank == 0 and (epoch % self.validation_interval == 0):
                print(
                    f"[{timer.duration()} seconds] Training has finished, validation is in progress..."
                )

                self._set_models_to_eval_mode()
                metric_score = self._validation_epoch(epoch)

                if self._is_best_epoch(
                        metric_score,
                        save_max_metric_score=self.save_max_metric_score):
                    self._save_checkpoint(epoch, is_best_epoch=True)

            # Only rank 0 prints, like the other progress messages.
            if self.rank == 0:
                print(f"[{timer.duration()} seconds] This epoch is finished.")

    def _train_epoch(self, epoch):
        """Train for one epoch; must be implemented by subclasses."""
        raise NotImplementedError

    def _validation_epoch(self, epoch):
        """Validate for one epoch and return a metric score; must be implemented by subclasses."""
        raise NotImplementedError
Exemplo n.º 18
0
def train(**kwargs):
    """Train ``opt.model`` (optionally with distiller pruning), validating each epoch.

    ``kwargs`` override fields of the global ``opt`` config via ``opt._parse``.
    When ``opt.vis`` is set, training and validation curves go to separate
    TensorBoard runs under ``./runs``; both writers are closed on exit. After
    every epoch the model is validated. The performance history is then
    ranked by (sparsity, top1, top5, epoch), and a checkpoint is saved when
    the current epoch ranks first.
    """
    opt._parse(kwargs)
    train_writer = None
    value_writer = None
    if opt.vis:
        train_writer = SummaryWriter(
            log_dir='./runs/train_' +
            datetime.now().strftime('%y%m%d-%H-%M-%S'))
        value_writer = SummaryWriter(
            log_dir='./runs/val_' + datetime.now().strftime('%y%m%d-%H-%M-%S'))
    best_precision = 0  # best top-1 precision so far
    start_epoch = 0
    lr = opt.lr
    perf_scores_history = []  # one performance record per epoch
    # step1: criterion and optimizer
    # Common loss choices, for reference:
    # 1. Hinge loss: mainly used by support vector machines (SVM);
    # 2. Cross-entropy (softmax) loss: logistic regression and softmax classification;
    # 3. Square loss: ordinary least squares (OLS);
    # 4. Exponential loss: mainly used by the AdaBoost ensemble algorithm;
    # 5. Others (e.g. 0-1 loss, absolute loss).
    criterion = t.nn.CrossEntropyLoss().to(opt.device)  # loss function
    # step2: meters
    train_losses = AverageMeter()  # running loss
    train_top1 = AverageMeter()  # running top-1 accuracy
    train_top5 = AverageMeter()  # running top-5 accuracy
    pylogger = PythonLogger(msglogger)
    # step3: configure model
    model = getattr(models, opt.model)()  # build the network by name
    compression_scheduler = distiller.CompressionScheduler(model)
    optimizer = model.get_optimizer(lr, opt.weight_decay)  # optimizer
    if opt.load_model_path:
        # Tensors are restored to the device they were saved from; pass
        # map_location to t.load (e.g. "cpu" or {'cuda:1': 'cuda:0'}) to remap.
        checkpoint = t.load(opt.load_model_path)
        start_epoch = checkpoint["epoch"]
        best_precision = checkpoint["best_precision"]
        model.load_state_dict(checkpoint["state_dict"])
        # The checkpoint stores the optimizer object itself (see model.save below).
        optimizer = checkpoint['optimizer']
    model.to(opt.device)  # move the model to the target device

    if opt.compress:
        # Load the pruning schedule described by the compression config.
        compression_scheduler = distiller.file_config(
            model, optimizer, opt.compress, compression_scheduler)
        model.to(opt.device)
    lr_scheduler = get_scheduler(optimizer, opt)  # learning-rate scheduler
    # step4: data
    train_data = DatasetFromFilename(opt.data_root, flag='train')  # training set
    val_data = DatasetFromFilename(opt.data_root, flag='test')  # validation set
    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers)
    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=True,
                                num_workers=opt.num_workers)
    try:
        for epoch in range(start_epoch, opt.max_epoch):
            model.train()
            if opt.pruning:
                compression_scheduler.on_epoch_begin(epoch)  # epoch-level pruning hook
            # Reset every meter so each epoch reports its own averages
            # (top-5 previously accumulated across epochs).
            train_losses.reset()
            train_top1.reset()
            train_top5.reset()
            total_samples = len(train_dataloader.sampler)
            steps_per_epoch = math.ceil(total_samples / opt.batch_size)
            batches_per_epoch = len(train_dataloader)
            train_progressor = ProgressBar(mode="Train  ",
                                           epoch=epoch,
                                           total_epoch=opt.max_epoch,
                                           model_name=opt.model,
                                           lr=lr,
                                           total=batches_per_epoch)
            lr = lr_scheduler.get_lr()
            for ii, (data, labels, img_path, tag) in enumerate(train_dataloader):
                if not check_date(img_path, tag, msglogger):
                    return
                # Monotonic, collision-free TensorBoard step. The previous
                # ``ii * (epoch + 1)`` mapped different batches to the same
                # step (e.g. every ii == 0) and went backwards between epochs.
                global_step = epoch * batches_per_epoch + ii
                if opt.pruning:
                    compression_scheduler.on_minibatch_begin(
                        epoch, ii, steps_per_epoch, optimizer)  # batch-level pruning hook
                train_progressor.current = ii + 1  # progress within the epoch
                inputs = data.to(opt.device)
                target = labels.to(opt.device)
                if train_writer:
                    # Undo the input normalisation approximately for display.
                    grid = make_grid(
                        (inputs.data.cpu() * 0.225 + 0.45).clamp(min=0, max=1))
                    train_writer.add_image('train_images', grid, global_step)
                score = model(inputs)  # network output
                loss = criterion(score, target)
                if opt.pruning:
                    # Before running the backward phase, we allow the scheduler to modify the loss
                    # (e.g. add regularization loss)
                    agg_loss = compression_scheduler.before_backward_pass(
                        epoch,
                        ii,
                        steps_per_epoch,
                        loss,
                        optimizer=optimizer,
                        return_loss_components=True)
                    loss = agg_loss.overall_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if opt.pruning:
                    compression_scheduler.on_minibatch_end(epoch, ii,
                                                           steps_per_epoch,
                                                           optimizer)  # batch-level pruning hook

                precision1_train, precision5_train = accuracy(
                    score, target, topk=(1, 5))  # top-1 and top-5 accuracy

                # Update each meter exactly once per batch (the loss meter was
                # previously updated twice, doubling its sample count).
                batch_size = inputs.size(0)
                train_losses.update(loss.item(), batch_size)
                train_top1.update(precision1_train[0].item(), batch_size)
                train_top5.update(precision5_train[0].item(), batch_size)
                train_progressor.current_loss = train_losses.avg
                train_progressor.current_top1 = train_top1.avg
                train_progressor.current_top5 = train_top5.avg
                train_progressor()  # print progress
                if ii % opt.print_freq == 0 and train_writer:
                    train_writer.add_scalar('loss', train_losses.avg,
                                            global_step)
                    train_writer.add_text(
                        'top1', 'train accuracy top1 %s' % train_top1.avg,
                        global_step)
                    train_writer.add_scalars(
                        'accuracy', {
                            'top1': train_top1.avg,
                            'top5': train_top5.avg,
                            'loss': train_losses.avg
                        }, global_step)
            # validate and visualize
            if opt.pruning:
                distiller.log_weights_sparsity(model, epoch,
                                               loggers=[pylogger])  # log pruning result
                compression_scheduler.on_epoch_end(epoch, optimizer)  # epoch-level pruning hook
            val_loss, val_top1, val_top5 = val(model, criterion, val_dataloader,
                                               epoch, value_writer, lr)
            sparsity = distiller.model_sparsity(model)
            perf_scores_history.append(
                distiller.MutableNamedTuple(
                    {
                        'sparsity': sparsity,
                        'top1': val_top1,
                        'top5': val_top5,
                        'epoch': epoch + 1,
                        'lr': lr,
                        'loss': val_loss
                    }, ))
            # Keep the history sorted best-first: sparsity is the primary key,
            # then top1, top5 and epoch.
            perf_scores_history.sort(key=operator.attrgetter(
                'sparsity', 'top1', 'top5', 'epoch'),
                                     reverse=True)
            best = perf_scores_history[0]
            msglogger.info(
                '==> Best [Top1: %.3f   Top5: %.3f   Sparsity: %.2f on epoch: %d   Lr: %f   Loss: %f]',
                best.top1, best.top5, best.sparsity, best.epoch, lr,
                best.loss)

            best_precision = max(best.top1, best_precision)  # best top-1 accuracy
            is_best = epoch + 1 == best.epoch  # the current epoch ranks first
            if is_best:
                model.save({
                    "epoch":
                    epoch + 1,
                    "model_name":
                    opt.model,
                    "state_dict":
                    model.state_dict(),
                    "best_precision":
                    best_precision,
                    "optimizer":
                    optimizer,
                    "valid_loss": [val_loss, val_top1, val_top5],
                    'compression_scheduler':
                    compression_scheduler.state_dict(),
                })  # save the model
            # update learning rate
            lr_scheduler.step(epoch)
            t.cuda.empty_cache()  # release cached, unused GPU memory
    finally:
        # Flush and close the TensorBoard writers, including on early return.
        for summary_writer in (train_writer, value_writer):
            if summary_writer is not None:
                summary_writer.close()
Exemplo n.º 19
0
def main(cfg):
    """Configure an experiment from ``cfg`` and launch TCMR training.

    In order: optional seeding, logger creation plus GPU/config logging,
    cuDNN flags, a TensorBoard writer, the data loaders, the TCMR loss, the
    generator and motion discriminator with their optimizers and
    ``ReduceLROnPlateau`` schedulers, and finally ``Trainer(...).fit()``.

    Args:
        cfg: experiment configuration node; every setting used here is read
            from it (``cfg.SEED_VALUE``, ``cfg.LOGDIR``, ``cfg.TRAIN.*``, ...).
    """
    seed = cfg.SEED_VALUE
    if seed >= 0:
        # A negative seed value disables seeding altogether.
        print(f'Seed value for the experiment {seed}')
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)

    logger = create_logger(cfg.LOGDIR, phase='train')
    logger.info(f'GPU name -> {torch.cuda.get_device_name()}')
    logger.info(f'GPU feat -> {torch.cuda.get_device_properties("cuda")}')
    logger.info(pprint.pformat(cfg))

    # cuDNN behaviour is driven entirely by the config.
    cudnn_cfg = cfg.CUDNN
    cudnn.benchmark = cudnn_cfg.BENCHMARK
    torch.backends.cudnn.deterministic = cudnn_cfg.DETERMINISTIC
    torch.backends.cudnn.enabled = cudnn_cfg.ENABLED

    writer = SummaryWriter(log_dir=cfg.LOGDIR)
    writer.add_text('config', pprint.pformat(cfg), 0)

    # ---- data ----
    data_loaders = get_data_loaders(cfg)

    # ---- loss ----
    weights = cfg.LOSS
    loss = TCMRLoss(
        e_loss_weight=weights.KP_2D_W,
        e_3d_loss_weight=weights.KP_3D_W,
        e_pose_loss_weight=weights.POSE_W,
        e_shape_loss_weight=weights.SHAPE_W,
        d_motion_loss_weight=weights.D_MOTION_LOSS_W,
    )

    # ---- networks, optimizers, lr schedulers ----
    train_cfg = cfg.TRAIN
    device = cfg.DEVICE
    tgru_cfg = cfg.MODEL.TGRU

    generator = TCMR(
        n_layers=tgru_cfg.NUM_LAYERS,
        batch_size=train_cfg.BATCH_SIZE,
        seqlen=cfg.DATASET.SEQLEN,
        hidden_size=tgru_cfg.HIDDEN_SIZE,
        pretrained=train_cfg.PRETRAINED_REGRESSOR,
    ).to(device)

    gen_optimizer = get_optimizer(
        model=generator,
        optim_type=train_cfg.GEN_OPTIM,
        lr=train_cfg.GEN_LR,
        weight_decay=train_cfg.GEN_WD,
        momentum=train_cfg.GEN_MOMENTUM,
    )

    disc_cfg = train_cfg.MOT_DISCR
    # The attention hyper-parameters are forwarded only for attention pooling;
    # otherwise they are passed as None (and the ATT.* values are never read).
    use_attention = disc_cfg.FEATURE_POOL == 'attention'
    motion_discriminator = MotionDiscriminator(
        rnn_size=disc_cfg.HIDDEN_SIZE,
        input_size=69,
        num_layers=disc_cfg.NUM_LAYERS,
        output_size=1,
        feature_pool=disc_cfg.FEATURE_POOL,
        attention_size=disc_cfg.ATT.SIZE if use_attention else None,
        attention_layers=disc_cfg.ATT.LAYERS if use_attention else None,
        attention_dropout=disc_cfg.ATT.DROPOUT if use_attention else None,
    ).to(device)

    dis_motion_optimizer = get_optimizer(
        model=motion_discriminator,
        optim_type=disc_cfg.OPTIM,
        lr=disc_cfg.LR,
        weight_decay=disc_cfg.WD,
        momentum=disc_cfg.MOMENTUM,
    )

    def build_plateau_scheduler(optimizer):
        # Both schedulers share the same policy: cut the LR 10x on a plateau.
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.1,
            patience=train_cfg.LR_PATIENCE,
            verbose=True,
        )

    motion_lr_scheduler = build_plateau_scheduler(dis_motion_optimizer)
    lr_scheduler = build_plateau_scheduler(gen_optimizer)

    # ---- training ----
    trainer = Trainer(
        data_loaders=data_loaders,
        generator=generator,
        motion_discriminator=motion_discriminator,
        criterion=loss,
        dis_motion_optimizer=dis_motion_optimizer,
        dis_motion_update_steps=disc_cfg.UPDATE_STEPS,
        gen_optimizer=gen_optimizer,
        start_epoch=train_cfg.START_EPOCH,
        end_epoch=train_cfg.END_EPOCH,
        device=device,
        writer=writer,
        debug=cfg.DEBUG,
        logdir=cfg.LOGDIR,
        lr_scheduler=lr_scheduler,
        motion_lr_scheduler=motion_lr_scheduler,
        resume=train_cfg.RESUME,
        num_iters_per_epoch=train_cfg.NUM_ITERS_PER_EPOCH,
        debug_freq=cfg.DEBUG_FREQ,
    )
    trainer.fit()
Exemplo n.º 20
0
def train(model, train_loader, val_loader, epochs, save_iter=10, vis_iter=4,
          optimization_args=None, log_dir=None, args_to_log=None, metrics=None,
          callbacks=None, stopper=None, device_ids=None, num_accumulation_steps=1,
          grad_clip_norm=None):
    """ Trains the model. Validation loader can be None.
    Assumptions:
    1. loaders return (batch_inputs, batch_labels), where both can be lists or torch.Tensors
    2. models are inheriting from method_utils.Method.
    3. callback and metrics are inheriting from their abstract classes described in callbacks.py and metrics.py

    :param model: the network to train; checkpoints are written with ``utils.save``.
    :param train_loader: loader for the training partition.
    :param val_loader: loader for the validation partition, or None to skip validation.
    :param epochs: maximum number of epochs (``stopper`` may end training earlier).
    :param save_iter: write a checkpoint every ``save_iter`` epochs.
    :param vis_iter: when the model has a ``visualize`` method, call it every ``vis_iter`` epochs.
    :param optimization_args: forwarded to ``build_optimizer`` and ``build_scheduler``.
    :param log_dir: TensorBoard/checkpoint directory; when None, a new numbered
                    directory under ``logs/`` is created.
    :param args_to_log: object accepted by ``vars()`` (e.g. an argparse Namespace); logged
                        to TensorBoard as a table and pickled to ``<log_dir>/args.pkl``.
    :param metrics: list/tuple of metrics forwarded to ``run_partition``.
    :param callbacks: list/tuple of objects whose ``call(...)`` runs after every epoch.
    :param stopper: optional object; training ends once ``stopper.call(epoch=epoch)`` is truthy.
    :param device_ids: with two or more ids the model is wrapped in ``torch.nn.DataParallel``.
    :param num_accumulation_steps: an integer that tells how many step gradients should be averaged before
                                   updating the parameters.
    :param grad_clip_norm: forwarded to ``run_partition`` (presumably a gradient-norm clip threshold).
    """

    # print the architecture of the model, helps to notice mistakes
    print(model)

    # if there are at least two devices, we use distributed data training using torch.nn.DataParallel
    # note that PyTorch requires and we rely on the fact that the first device should match with model.device
    data_parallel_model = None
    if (device_ids is not None) and len(device_ids) >= 2:
        print(f"Using multiple GPUs: {device_ids}")
        data_parallel_model = torch.nn.DataParallel(model, device_ids=device_ids)

    # if log_dir is not given, logging will be done a new directory in 'logs/' directory
    if log_dir is None:
        log_root = 'logs/'
        utils.make_path(log_root)
        last_run = max([0] + [int(k) for k in os.listdir(log_root) if k.isdigit()])
        log_dir = os.path.join(log_root, '{0:04d}'.format(last_run + 1))
        utils.make_path(log_dir)

    tensorboard = SummaryWriter(log_dir)
    print("Visualize logs using: tensorboard --logdir={0}".format(log_dir))

    # Everything below writes to ``tensorboard``; the finally clause guarantees the
    # writer is flushed and its event file closed even if training raises.
    try:
        # add args_to_log to tensorboard, but also store it separately for easier access
        if args_to_log is not None:
            tensorboard.add_text('script arguments table', make_markdown_table_from_dict(vars(args_to_log)))
            with open(os.path.join(log_dir, 'args.pkl'), 'wb') as f:
                pickle.dump(args_to_log, f)

        optimizer = build_optimizer(model.named_parameters(), optimization_args)
        scheduler = build_scheduler(optimizer, optimization_args)

        # convert metrics to list
        if metrics is None:
            metrics = []
        assert isinstance(metrics, (list, tuple))

        # convert callbacks to list
        if callbacks is None:
            callbacks = []
        assert isinstance(callbacks, (list, tuple))

        for epoch in range(epochs):
            t0 = time.time()

            model.train()
            if data_parallel_model is not None:
                data_parallel_model.train()
            train_losses = run_partition(model=model, epoch=epoch, tensorboard=tensorboard, optimizer=optimizer,
                                         loader=train_loader, partition='train', training=True, metrics=metrics,
                                         data_parallel_model=data_parallel_model,
                                         num_accumulation_steps=num_accumulation_steps,
                                         grad_clip_norm=grad_clip_norm)

            val_losses = {}
            if val_loader is not None:
                model.eval()
                if data_parallel_model is not None:
                    data_parallel_model.eval()
                # no gradient accumulation during validation
                val_losses = run_partition(model=model, epoch=epoch, tensorboard=tensorboard, optimizer=optimizer,
                                           loader=val_loader, partition='val', training=False, metrics=metrics,
                                           data_parallel_model=data_parallel_model,
                                           num_accumulation_steps=1,
                                           grad_clip_norm=grad_clip_norm)

            # log some statistics
            t = time.time()
            log_string = 'Epoch: {}/{}'.format(epoch, epochs)
            for k, v in list(train_losses.items()) + list(val_losses.items()):
                log_string += ', {}: {:0.6f}'.format(k, v)
            log_string += ', Time: {:0.1f}s'.format(t - t0)
            print(log_string)

            # add visualizations
            if (epoch + 1) % vis_iter == 0 and hasattr(model, 'visualize'):
                visualizations = model.visualize(train_loader, val_loader, tensorboard=tensorboard, epoch=epoch)
                # visualizations is a dictionary containing figures in (name, fig) format.
                # there are visualizations created using matplotlib rather than tensorboard
                for (name, fig) in visualizations.items():
                    tensorboard.add_figure(name, fig, epoch)

            # save the model according to our schedule
            if (epoch + 1) % save_iter == 0:
                utils.save(model=model, optimizer=optimizer, scheduler=scheduler,
                           path=os.path.join(log_dir, 'checkpoints', 'epoch{}.mdl'.format(epoch)))

            # Call callbacks. These can be used to save the best model so far or initiate testing.
            for callback in callbacks:
                callback.call(epoch=epoch, model=model, optimizer=optimizer, scheduler=scheduler, log_dir=log_dir)

            # check whether the training should be ended
            if (stopper is not None) and stopper.call(epoch=epoch):
                print(f"Finishing the training at epoch {epoch}...")
                break

            # log the learning rate
            last_lr = scheduler.get_last_lr()
            if isinstance(last_lr, list):  # this happens when parameters are divided into groups
                last_lr = last_lr[0]
            tensorboard.add_scalar('hyper-parameters/lr', last_lr, epoch)

            # update the learning rate
            scheduler.step()

        # enable testing mode
        model.eval()

        # save the final version of the network
        utils.save(model=model, optimizer=optimizer, scheduler=scheduler,
                   path=os.path.join(log_dir, 'checkpoints', 'final.mdl'))

        # do final visualizations
        if hasattr(model, 'visualize'):
            visualizations = model.visualize(train_loader, val_loader, tensorboard=tensorboard, epoch=epochs)
            for (name, fig) in visualizations.items():
                tensorboard.add_figure(name, fig, epochs)
                vis.savefig(fig, os.path.join(log_dir, name, 'final.png'))
    finally:
        # flush pending events and release the event file
        tensorboard.close()
Exemplo n.º 21
0
def test_discrete_crr(args=get_args()):
    """Train a Discrete CRR policy offline on a saved Atari buffer, then watch it.

    ``args`` supplies every setting (task, seeds, network sizes, CRR
    hyper-parameters, buffer path, logging). Note that the default value is
    evaluated once, when this function is defined, and the same object is
    mutated here (``state_shape`` / ``action_shape`` are written onto it).

    The buffer at ``args.load_buffer_name`` must be a ``.pkl`` or ``.hdf5``
    file; any other suffix prints a message and exits. With ``args.watch``
    set, only the evaluation run is performed.
    """
    # envs
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    test_envs = ShmemVectorEnv(
        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)
    # model: the actor reuses a feature-only DQN trunk, the critic is a full DQN
    feature_net = DQN(*args.state_shape,
                      args.action_shape,
                      device=args.device,
                      features_only=True).to(args.device)
    actor = Actor(feature_net,
                  args.action_shape,
                  device=args.device,
                  hidden_sizes=args.hidden_sizes,
                  softmax_output=False).to(args.device)
    critic = DQN(*args.state_shape, args.action_shape,
                 device=args.device).to(args.device)
    # one optimizer over both actor and critic parameters
    optim = torch.optim.Adam(list(actor.parameters()) +
                             list(critic.parameters()),
                             lr=args.lr)
    # define policy
    policy = DiscreteCRRPolicy(
        actor,
        critic,
        optim,
        args.gamma,
        policy_improvement_mode=args.policy_improvement_mode,
        ratio_upper_bound=args.ratio_upper_bound,
        beta=args.beta,
        min_q_weight=args.min_q_weight,
        target_update_freq=args.target_update_freq).to(args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(
            torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
    # buffer
    assert os.path.exists(args.load_buffer_name), \
        "Please run atari_qrdqn.py first to get expert's data buffer."
    if args.load_buffer_name.endswith('.pkl'):
        # NOTE(review): unpickling can execute arbitrary code -- only load
        # buffers produced by a trusted run. The context manager closes the file.
        with open(args.load_buffer_name, "rb") as buffer_file:
            buffer = pickle.load(buffer_file)
    elif args.load_buffer_name.endswith('.hdf5'):
        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
    else:
        print(f"Unknown buffer format: {args.load_buffer_name}")
        exit(0)

    # collector
    test_collector = Collector(policy, test_envs, exploration_noise=True)

    # log
    log_path = os.path.join(
        args.logdir, args.task, 'crr',
        f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer, update_interval=args.log_interval)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        # never stop early: run all configured epochs
        return False

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        test_envs.seed(args.seed)
        print("Testing agent ...")
        test_collector.reset()
        result = test_collector.collect(n_episode=args.test_num,
                                        render=args.render)
        pprint.pprint(result)
        rew = result["rews"].mean()
        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')

    if args.watch:
        watch()
        exit(0)

    result = offline_trainer(policy,
                             buffer,
                             test_collector,
                             args.epoch,
                             args.update_per_epoch,
                             args.test_num,
                             args.batch_size,
                             stop_fn=stop_fn,
                             save_fn=save_fn,
                             logger=logger)

    pprint.pprint(result)
    watch()
Exemplo n.º 22
0
def test_ppo(args=get_args()):
    """Train a PPO (optionally PPO+ICM) agent on a ViZDoom map, then watch it.

    ``args`` supplies every setting; the map files are read from
    ``maps/<task>.cfg`` / ``maps/<task>.wad``. The default value is evaluated
    once, when this function is defined, and is mutated here (``cfg_path``,
    ``wad_path``, ``res``, ``state_shape``, ``action_shape``).

    With ``args.icm_lr_scale > 0`` the PPO policy is wrapped in an
    ``ICMPolicy``. With ``args.watch`` set, only the evaluation run (or
    buffer generation, when ``args.save_buffer_name`` is given) is performed.
    """
    args.cfg_path = f"maps/{args.task}.cfg"
    args.wad_path = f"maps/{args.task}.wad"
    args.res = (args.skip_num, 84, 84)
    env = Env(args.cfg_path, args.frames_stack, args.res)
    args.state_shape = args.res
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    train_envs = ShmemVectorEnv([
        lambda: Env(args.cfg_path, args.frames_stack, args.res)
        for _ in range(args.training_num)
    ])
    # Leave one core free for the main process, but os.cpu_count() may return
    # None and a single-core machine would otherwise get zero test envs.
    num_test_envs = min(max(1, (os.cpu_count() or 1) - 1), args.test_num)
    test_envs = ShmemVectorEnv([
        lambda: Env(args.cfg_path, args.frames_stack, args.res, args.save_lmp)
        for _ in range(num_test_envs)
    ])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # define model: actor and critic share one feature-only DQN trunk
    net = DQN(*args.state_shape,
              args.action_shape,
              device=args.device,
              features_only=True,
              output_dim=args.hidden_size)
    actor = Actor(net,
                  args.action_shape,
                  device=args.device,
                  softmax_output=False)
    critic = Critic(net, device=args.device)
    optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(),
                             lr=args.lr)

    lr_scheduler = None
    if args.lr_decay:
        # decay learning rate to 0 linearly over the total number of updates
        max_update_num = np.ceil(
            args.step_per_epoch / args.step_per_collect) * args.epoch

        lr_scheduler = LambdaLR(
            optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

    # define policy
    def dist(p):
        return torch.distributions.Categorical(logits=p)

    policy = PPOPolicy(actor,
                       critic,
                       optim,
                       dist,
                       discount_factor=args.gamma,
                       gae_lambda=args.gae_lambda,
                       max_grad_norm=args.max_grad_norm,
                       vf_coef=args.vf_coef,
                       ent_coef=args.ent_coef,
                       reward_normalization=args.rew_norm,
                       action_scaling=False,
                       lr_scheduler=lr_scheduler,
                       action_space=env.action_space,
                       eps_clip=args.eps_clip,
                       value_clip=args.value_clip,
                       dual_clip=args.dual_clip,
                       advantage_normalization=args.norm_adv,
                       recompute_advantage=args.recompute_adv).to(args.device)
    if args.icm_lr_scale > 0:
        # the curiosity module gets its own feature trunk and optimizer
        feature_net = DQN(*args.state_shape,
                          args.action_shape,
                          device=args.device,
                          features_only=True,
                          output_dim=args.hidden_size)
        action_dim = np.prod(args.action_shape)
        feature_dim = feature_net.output_dim
        icm_net = IntrinsicCuriosityModule(feature_net.net,
                                           feature_dim,
                                           action_dim,
                                           device=args.device)
        icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
        policy = ICMPolicy(policy, icm_net, icm_optim, args.icm_lr_scale,
                           args.icm_reward_scale,
                           args.icm_forward_loss_weight).to(args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(
            torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
    # replay buffer: `save_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer = VectorReplayBuffer(args.buffer_size,
                                buffer_num=len(train_envs),
                                ignore_obs_next=True,
                                save_only_last_obs=True,
                                stack_num=args.frames_stack)
    # collector
    train_collector = Collector(policy,
                                train_envs,
                                buffer,
                                exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # log
    log_name = 'ppo_icm' if args.icm_lr_scale > 0 else 'ppo'
    log_path = os.path.join(args.logdir, args.task, log_name)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        # The ViZDoom Env may not carry a gym ``spec`` (it can be None or
        # absent), so look the threshold up defensively instead of crashing.
        spec = getattr(env, "spec", None)
        reward_threshold = getattr(spec, "reward_threshold", None)
        if reward_threshold:
            return mean_rewards >= reward_threshold
        elif 'Pong' in args.task:
            return mean_rewards >= 20
        else:
            return False

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        test_envs.seed(args.seed)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            buffer = VectorReplayBuffer(args.buffer_size,
                                        buffer_num=len(test_envs),
                                        ignore_obs_next=True,
                                        save_only_last_obs=True,
                                        stack_num=args.frames_stack)
            collector = Collector(policy,
                                  test_envs,
                                  buffer,
                                  exploration_noise=True)
            result = collector.collect(n_step=args.buffer_size)
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(n_episode=args.test_num,
                                            render=args.render)
        rew = result["rews"].mean()
        # episode lengths are scaled by skip_num to report raw frames
        lens = result["lens"].mean() * args.skip_num
        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
        print(f'Mean length (over {result["n/ep"]} episodes): {lens}')

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # trainer
    result = onpolicy_trainer(policy,
                              train_collector,
                              test_collector,
                              args.epoch,
                              args.step_per_epoch,
                              args.repeat_per_collect,
                              args.test_num,
                              args.batch_size,
                              step_per_collect=args.step_per_collect,
                              stop_fn=stop_fn,
                              save_best_fn=save_best_fn,
                              logger=logger,
                              test_in_train=False)

    pprint.pprint(result)
    watch()
Exemplo n.º 23
0
    def logging_loop(self, num_gpus):
        """
        Keep track of the training performance.

        Launches a test ``SelfPlay`` Ray worker (``num_cpus=0``, ``num_gpus``
        GPUs), writes the hyper-parameters and model summary to TensorBoard
        under ``self.config.results_path``, then polls the shared-storage
        worker every 0.5 s and logs its metrics until ``info["training_step"]``
        reaches ``self.config.training_steps`` or Ctrl-C is pressed.
        Afterwards the workers are terminated and, when
        ``self.config.save_model`` is set, ``self.replay_buffer`` is pickled to
        ``<results_path>/replay_buffer.pkl``.

        :param num_gpus: number of GPUs requested for the test worker.
        """
        # Launch the test worker to get performance metrics
        self.test_worker = self_play.SelfPlay.options(
            num_cpus=0,
            num_gpus=num_gpus,
        ).remote(
            self.checkpoint,
            self.Game,
            self.config,
            self.config.seed + self.config.num_workers,
        )
        # NOTE(review): the (None, True) arguments presumably mean "no replay
        # buffer, test mode" -- confirm against SelfPlay.continuous_self_play.
        self.test_worker.continuous_self_play.remote(
            self.shared_storage_worker, None, True
        )

        # Write everything in TensorBoard
        writer = SummaryWriter(self.config.results_path)

        print(
            "\nTraining...\nRun tensorboard --logdir ./results and go to http://localhost:6006/ to see in real time the training performance.\n"
        )

        try:
            # Save hyperparameters to TensorBoard as a markdown table
            hp_table = [
                f"| {key} | {value} |" for key, value in self.config.__dict__.items()
            ]
            writer.add_text(
                "Hyperparameters",
                "| Parameter | Value |\n|-------|-------|\n" + "\n".join(hp_table),
            )
            # Save model representation
            writer.add_text(
                "Model summary",
                self.summary,
            )
            # Loop for updating the training performance
            counter = 0
            keys = [
                "total_reward",
                "muzero_reward",
                "opponent_reward",
                "episode_length",
                "mean_value",
                "training_step",
                "lr",
                "total_loss",
                "value_loss",
                "reward_loss",
                "policy_loss",
                "num_played_games",
                "num_played_steps",
                "num_reanalysed_games",
            ]
            info = ray.get(self.shared_storage_worker.get_info.remote(keys))
            try:
                while info["training_step"] < self.config.training_steps:
                    info = ray.get(
                        self.shared_storage_worker.get_info.remote(keys))
                    # (tag, value) pairs, written in this order at step `counter`
                    scalars = (
                        ("1.Total_reward/1.Total_reward", info["total_reward"]),
                        ("1.Total_reward/2.Mean_value", info["mean_value"]),
                        ("1.Total_reward/3.Episode_length", info["episode_length"]),
                        ("1.Total_reward/4.MuZero_reward", info["muzero_reward"]),
                        ("1.Total_reward/5.Opponent_reward", info["opponent_reward"]),
                        ("2.Workers/1.Self_played_games", info["num_played_games"]),
                        ("2.Workers/2.Training_steps", info["training_step"]),
                        ("2.Workers/3.Self_played_steps", info["num_played_steps"]),
                        ("2.Workers/4.Reanalysed_games", info["num_reanalysed_games"]),
                        # max(1, ...) avoids division by zero before any step is played
                        ("2.Workers/5.Training_steps_per_self_played_step_ratio",
                         info["training_step"] / max(1, info["num_played_steps"])),
                        ("2.Workers/6.Learning_rate", info["lr"]),
                        ("3.Loss/1.Total_weighted_loss", info["total_loss"]),
                        ("3.Loss/Value_loss", info["value_loss"]),
                        ("3.Loss/Reward_loss", info["reward_loss"]),
                        ("3.Loss/Policy_loss", info["policy_loss"]),
                    )
                    for tag, value in scalars:
                        writer.add_scalar(tag, value, counter)
                    print(
                        f'Last test reward: {info["total_reward"]:.2f}. Training step: {info["training_step"]}/{self.config.training_steps}. Played games: {info["num_played_games"]}. Loss: {info["total_loss"]:.2f}',
                        end="\r",
                    )
                    counter += 1
                    time.sleep(0.5)
            except KeyboardInterrupt:
                # Ctrl-C ends monitoring early; shutdown below still runs.
                pass
        finally:
            # Flush pending TensorBoard events and release the event file.
            writer.close()

        self.terminate_workers()

        if self.config.save_model:
            # Persist replay buffer to disk; the context manager closes the file.
            print("\n\nPersisting replay buffer games to disk...")
            buffer_path = os.path.join(self.config.results_path,
                                       "replay_buffer.pkl")
            with open(buffer_path, "wb") as buffer_file:
                pickle.dump(self.replay_buffer, buffer_file)
Exemplo n.º 24
0
class AMCPruner(Pruner):
    """
    A pytorch implementation of AMC: AutoML for Model Compression and Acceleration on Mobile Devices.
    (https://arxiv.org/pdf/1802.03494.pdf)

    Parameters:
        model: nn.Module
            The model to be pruned.
        config_list: list
            Configuration list to configure layer pruning.
            Supported keys:
            - op_types: operation type to be pruned
            - op_names: operation name to be pruned
        evaluator: function
            function to evaluate the pruned model.
            The prototype of the function:
            >>> def evaluator(val_loader, model):
            >>>     ...
            >>>     return acc
        val_loader: torch.utils.data.DataLoader
            Data loader of validation dataset.
        suffix: str
            suffix appended to the output folder name to help you remember what experiment you ran.
            Default: None.
        job: str
            train_export: search best pruned model and export after search.
            export_only: export a searched model, searched_model_path and export_path must be specified.
        export_path: str
            path for exporting models. Default: None, which means
            '<output folder>/<model_type>_r<flops_ratio>_exported.pth'.
        searched_model_path: str
            when job == export_only, use searched_model_path to specify the path of the searched model.
            Default: None, which means '<output folder>/best_wrapped_model.pth'.

        # parameters for pruning environment
        model_type: str
            model type to prune, currently 'mobilenet' and 'mobilenetv2' are supported. Default: mobilenet
        dataset: str
            dataset name; only used to build the output folder name. Default: cifar10
        flops_ratio: float
            preserve flops ratio. Default: 0.5
        lbound: float
            minimum weight preserve ratio for each layer. Default: 0.2
        rbound: float
            maximum weight preserve ratio for each layer. Default: 1.0
        reward: str
            reward function type:
            - acc_reward: accuracy * 0.01
            - acc_flops_reward: - (100 - accuracy) * 0.01 * np.log(flops)
            Default: acc_reward
        # parameters for channel pruning
        n_calibration_batches: int
            number of batches to extract layer information. Default: 60
        n_points_per_layer: int
            number of feature points per layer. Default: 10
        channel_round: int
            round channel to multiple of channel_round. Default: 8

        # parameters for ddpg agent
        hidden1: int
            hidden num of first fully connect layer. Default: 300
        hidden2: int
            hidden num of second fully connect layer. Default: 300
        lr_c: float
            learning rate for critic. Default: 1e-3
        lr_a: float
            learning rate for actor. Default: 1e-4
        warmup: int
            number of episodes without training but only filling the replay memory. During warmup episodes,
            random actions are used for pruning. Default: 100
        discount: float
            next Q value discount for deep Q value target. Default: 1.0
        bsize: int
            minibatch size for training DDPG agent. Default: 64
        rmsize: int
            memory size for each layer; the actual replay buffer size is
            rmsize * number of prunable layers. Default: 100
        window_length: int
            replay buffer window length. Default: 1
        tau: float
            moving average for target network being used by soft_update. Default: 0.01
        # noise
        init_delta: float
            initial variance of truncated normal distribution. Default: 0.5
        delta_decay: float
            delta decay during exploration. Default: 0.99

        # parameters for training ddpg agent
        max_episode_length: int
            maximum episode length (forwarded to the agent; episodes end when the
            environment reports done). Default: 1e9
        output_dir: str
            output directory to save log files and model files. Default: ./logs
        debug: boolean
            debug mode. Default: False
        train_episode: int
            number of search episodes used to train the DDPG agent. Default: 800
        epsilon: int
            linear decay of exploration policy. Default: 50000
        seed: int
            random seed to set for reproduce experiment. Default: None
    """

    def __init__(
            self,
            model,
            config_list,
            evaluator,
            val_loader,
            suffix=None,
            job='train_export',
            export_path=None,
            searched_model_path=None,
            model_type='mobilenet',
            dataset='cifar10',
            flops_ratio=0.5,
            lbound=0.2,
            rbound=1.,
            reward='acc_reward',
            n_calibration_batches=60,
            n_points_per_layer=10,
            channel_round=8,
            hidden1=300,
            hidden2=300,
            lr_c=1e-3,
            lr_a=1e-4,
            warmup=100,
            discount=1.,
            bsize=64,
            rmsize=100,
            window_length=1,
            tau=0.01,
            init_delta=0.5,
            delta_decay=0.99,
            max_episode_length=1e9,
            output_dir='./logs',
            debug=False,
            train_episode=800,
            epsilon=50000,
            seed=None):

        self.job = job
        self.searched_model_path = searched_model_path
        self.export_path = export_path

        if seed is not None:
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)

        # Snapshot the original weights before super().__init__(), which presumably
        # wraps the prunable modules (see _unwrap_model in export_pruned_model).
        checkpoint = deepcopy(model.state_dict())

        super().__init__(model, config_list, optimizer=None)

        # build folder and logs
        base_folder_name = '{}_{}_r{}_search'.format(model_type, dataset, flops_ratio)
        if suffix is not None:
            base_folder_name = base_folder_name + '_' + suffix
        self.output_dir = get_output_folder(output_dir, base_folder_name)

        if self.export_path is None:
            self.export_path = os.path.join(self.output_dir, '{}_r{}_exported.pth'.format(model_type, flops_ratio))

        self.env_args = Namespace(
            model_type=model_type,
            preserve_ratio=flops_ratio,
            lbound=lbound,
            rbound=rbound,
            reward=reward,
            n_calibration_batches=n_calibration_batches,
            n_points_per_layer=n_points_per_layer,
            channel_round=channel_round,
            output=self.output_dir
        )

        self.env = ChannelPruningEnv(
            self, evaluator, val_loader, checkpoint, args=self.env_args)

        # The DDPG agent and the log writers are only needed for searching.
        if self.job == 'train_export':
            print('=> Saving logs to {}'.format(self.output_dir))
            self.tfwriter = SummaryWriter(log_dir=self.output_dir)
            self.text_writer = open(os.path.join(self.output_dir, 'log.txt'), 'w')
            print('=> Output path: {}...'.format(self.output_dir))

            nb_states = self.env.layer_embedding.shape[1]
            nb_actions = 1  # just 1 action here: the preserve ratio of the current layer

            rmsize = rmsize * len(self.env.prunable_idx)  # for each layer
            print('** Actual replay buffer size: {}'.format(rmsize))

            self.ddpg_args = Namespace(
                hidden1=hidden1,
                hidden2=hidden2,
                lr_c=lr_c,
                lr_a=lr_a,
                warmup=warmup,
                discount=discount,
                bsize=bsize,
                rmsize=rmsize,
                window_length=window_length,
                tau=tau,
                init_delta=init_delta,
                delta_decay=delta_decay,
                max_episode_length=max_episode_length,
                debug=debug,
                train_episode=train_episode,
                epsilon=epsilon
            )
            self.agent = DDPG(nb_states, nb_actions, self.ddpg_args)

    def compress(self):
        """Run the search (only when job == 'train_export'), then export the pruned model."""
        if self.job == 'train_export':
            self.train(self.ddpg_args.train_episode, self.agent, self.env, self.output_dir)
        self.export_pruned_model()

    def train(self, num_episode, agent, env, output_dir):
        """
        Run the DDPG search loop for ``num_episode`` episodes.

        At each step the agent picks an action for the current layer (a random action
        during warmup episodes), the environment applies it, and the transition is
        recorded. When an episode ends, every recorded transition is passed to the
        agent with the episode's final reward and, after warmup, the policy is updated
        once per transition. Progress goes to ``log.txt`` and TensorBoard; both writers
        are closed when the loop finishes or raises.

        Parameters:
            num_episode: int
                number of episodes to run.
            agent: DDPG
                the agent being trained.
            env: ChannelPruningEnv
                the pruning environment.
            output_dir: str
                directory where intermediate agent checkpoints are saved.
        """
        agent.is_training = True
        episode = episode_steps = 0
        episode_reward = 0.
        observation = None
        trajectory = []
        # Save the agent about three times over the run (every episode if num_episode <= 3).
        save_interval = 1 if num_episode / 3 <= 1 else int(num_episode / 3)
        try:
            while episode < num_episode:  # counting based on episode
                # reset if it is the start of episode
                if observation is None:
                    observation = deepcopy(env.reset())
                    agent.reset(observation)

                if episode <= self.ddpg_args.warmup:
                    action = agent.random_action()
                else:
                    action = agent.select_action(observation, episode=episode)

                # env response with next_observation, reward, terminate_info
                observation2, reward, done, info = env.step(action)
                trajectory.append([reward, deepcopy(observation), deepcopy(observation2), action, done])
                # max_episode_length is not enforced here: episodes end when env reports done.

                # [optional] save intermediate model. The policy is only updated in the
                # end-of-episode block below (after this save), so saving once at the first
                # step of a qualifying episode writes the same weights as saving every step.
                if episode_steps == 0 and episode % save_interval == 0:
                    agent.save_model(output_dir)

                episode_steps += 1
                episode_reward += reward
                observation = deepcopy(observation2)

                if done:  # end of episode
                    summary = '#{}: episode_reward:{:.4f} acc: {:.4f}, ratio: {:.4f}'.format(
                        episode, episode_reward,
                        info['accuracy'],
                        info['compress_ratio']
                    )
                    print(summary)
                    self.text_writer.write(summary + '\n')

                    # Every transition of the episode is credited with the final reward.
                    final_reward = trajectory[-1][0]
                    for _, s_t, s_t1, a_t, d_t in trajectory:
                        agent.observe(final_reward, s_t, s_t1, a_t, d_t)
                        if episode > self.ddpg_args.warmup:
                            agent.update_policy()

                    # reset
                    observation = None
                    episode_steps = 0
                    episode_reward = 0.
                    episode += 1
                    trajectory = []

                    self.tfwriter.add_scalar('reward/last', final_reward, episode)
                    self.tfwriter.add_scalar('reward/best', env.best_reward, episode)
                    self.tfwriter.add_scalar('info/accuracy', info['accuracy'], episode)
                    self.tfwriter.add_scalar('info/compress_ratio', info['compress_ratio'], episode)
                    self.tfwriter.add_text('info/best_policy', str(env.best_strategy), episode)
                    # record the preserve rate for each layer
                    for i, preserve_rate in enumerate(env.strategy):
                        self.tfwriter.add_scalar('preserve_rate/{}'.format(i), preserve_rate, episode)

                    self.text_writer.write('best reward: {}\n'.format(env.best_reward))
                    self.text_writer.write('best policy: {}\n'.format(env.best_strategy))
        finally:
            # Close the writers even if the search fails, so logs are flushed to disk.
            self.text_writer.close()
            self.tfwriter.close()

    def export_pruned_model(self):
        """
        Load the searched (wrapped) model, validate it, export the pruned model and save it.

        The checkpoint loaded is ``searched_model_path`` when given, otherwise
        ``<output_dir>/best_wrapped_model.pth``. Accuracy is printed before and after
        export, and the exported model is saved with ``torch.save`` to ``export_path``.
        """
        wrapper_model_ckpt = self.searched_model_path
        if wrapper_model_ckpt is None:
            wrapper_model_ckpt = os.path.join(self.output_dir, 'best_wrapped_model.pth')
        self.env.reset()
        self.bound_model.load_state_dict(torch.load(wrapper_model_ckpt))

        print('validate searched model:', self.env._validate(self.env._val_loader, self.env.model))
        self.env.export_model()
        self._unwrap_model()
        print('validate exported model:', self.env._validate(self.env._val_loader, self.env.model))

        torch.save(self.bound_model, self.export_path)
        print('exported model saved to: {}'.format(self.export_path))
# Exemplo n.º 25
# 0
class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch,
    optimized for 🤗 Transformers.

    Args:
        model (:class:`~transformers.PreTrainedModel`):
            The model to train, evaluate or use for predictions.
        args (:class:`~transformers.TrainingArguments`):
            The arguments to tweak training.
        data_collator (:obj:`DataCollator`, `optional`, defaults to :func:`~transformers.default_data_collator`):
            The function to use to from a batch from a list of elements of :obj:`train_dataset` or
            :obj:`eval_dataset`.
        train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
            The dataset to use for training.
        eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
            The dataset to use for evaluation.
        compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
            The function that will be used to compute metrics at evaluation. Must take a
            :class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
        prediction_loss_only (:obj:`bool`, `optional`, defaults to `False`):
            When performing evaluation and predictions, only returns the loss.
        tb_writer (:obj:`SummaryWriter`, `optional`):
            Object to write to TensorBoard.
        optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR`, `optional`):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of
            :class:`~transformers.AdamW` on your model and a scheduler given by
            :func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
    """

    model: PreTrainedModel
    args: TrainingArguments
    data_collator: DataCollator
    train_dataset: Optional[Dataset]
    eval_dataset: Optional[Dataset]
    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
    prediction_loss_only: bool
    tb_writer: Optional["SummaryWriter"] = None
    optimizers: Tuple[torch.optim.Optimizer,
                      torch.optim.lr_scheduler.LambdaLR] = None
    global_step: Optional[int] = None
    epoch: Optional[float] = None

    def __init__(
        self,
        model: PreTrainedModel,
        args: TrainingArguments,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        prediction_loss_only=False,
        tb_writer: Optional["SummaryWriter"] = None,
        optimizers: Tuple[torch.optim.Optimizer,
                          torch.optim.lr_scheduler.LambdaLR] = None,
    ):
        """
        Store the training configuration and set up logging side integrations.

        The model is moved to ``args.device``. ``data_collator`` falls back to
        ``default_data_collator``. A TensorBoard writer is taken from ``tb_writer`` or,
        when TensorBoard is available and this is the world-master process, created at
        ``args.logging_dir``. W&B is set up when available. The seed is set, the output
        directory is created on the world master, and a legacy collator object exposing
        only ``collate_batch`` is replaced by that method (with a ``FutureWarning``).
        """
        self.model = model.to(args.device)
        self.args = args
        self.data_collator = default_data_collator if data_collator is None else data_collator
        self.train_dataset, self.eval_dataset = train_dataset, eval_dataset
        self.compute_metrics = compute_metrics
        self.prediction_loss_only = prediction_loss_only
        self.optimizers = optimizers

        tensorboard_ok = is_tensorboard_available()
        if tb_writer is not None:
            self.tb_writer = tb_writer
        elif tensorboard_ok and self.is_world_master():
            self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
        # Otherwise the class-level default ``tb_writer = None`` applies.
        if not tensorboard_ok:
            logger.warning(
                "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
            )

        if not is_wandb_available():
            logger.info(
                "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
                "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
            )
        else:
            self._setup_wandb()

        set_seed(self.args.seed)
        # Only the world master creates the output directory.
        if self.is_world_master():
            os.makedirs(self.args.output_dir, exist_ok=True)
        if is_torch_tpu_available():
            # Flag the model config as running on an XLA device.
            self.model.config.xla_device = True

        # Backward compatibility: collator objects exposing only ``collate_batch``.
        if not callable(self.data_collator):
            legacy_collate = getattr(self.data_collator, "collate_batch", None)
            if callable(legacy_collate):
                self.data_collator = legacy_collate
                warnings.warn(
                    "The `data_collator` should now be a simple callable (function, class with `__call__`), classes "
                    "with a `collate_batch` are deprecated and won't be supported in a future version.",
                    FutureWarning,
                )

    def get_train_dataloader(self) -> DataLoader:
        """
        Build the training :class:`~torch.utils.data.DataLoader`.

        Sampler choice: none for iterable datasets, the TPU sampler on TPU, a
        ``RandomSampler`` for single-process runs and a ``DistributedSampler`` when
        ``args.local_rank != -1``.

        Raises:
            ValueError: if no ``train_dataset`` was provided.
        """
        dataset = self.train_dataset
        if isinstance(dataset, torch.utils.data.IterableDataset):
            sampler = None
        elif dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        elif is_torch_tpu_available():
            sampler = get_tpu_sampler(dataset)
        elif self.args.local_rank == -1:
            sampler = RandomSampler(dataset)
        else:
            sampler = DistributedSampler(dataset)

        return DataLoader(
            dataset,
            batch_size=self.args.train_batch_size,
            sampler=sampler,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

    def get_eval_dataloader(self,
                            eval_dataset: Optional[Dataset] = None
                            ) -> DataLoader:
        """
        Build the evaluation :class:`~torch.utils.data.DataLoader`.

        Args:
            eval_dataset (:obj:`Dataset`, `optional`):
                If provided, will override `self.eval_dataset`.

        Raises:
            ValueError: if neither ``eval_dataset`` nor ``self.eval_dataset`` is set.
        """
        dataset = self.eval_dataset if eval_dataset is None else eval_dataset
        if dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        # Evaluation always iterates in order; only the sharding strategy varies.
        if isinstance(dataset, torch.utils.data.IterableDataset):
            sampler = None
        elif is_torch_tpu_available():
            sampler = SequentialDistributedSampler(
                dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal())
        elif self.args.local_rank == -1:
            sampler = SequentialSampler(dataset)
        else:
            sampler = SequentialDistributedSampler(dataset)

        return DataLoader(
            dataset,
            sampler=sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test :class:`~torch.utils.data.DataLoader`.

        Uses the same strategy and batch size as evaluation: no sampler for iterable
        datasets, a ``SequentialDistributedSampler`` on TPU or under distributed
        training, and a ``SequentialSampler`` otherwise.

        Args:
            test_dataset (obj:`Dataset`): The test dataset to use.
        """
        # Inspect the argument itself: the Trainer never sets a ``test_dataset``
        # attribute, so checking ``self.test_dataset`` raised AttributeError.
        if isinstance(test_dataset, torch.utils.data.IterableDataset):
            sampler = None
        elif is_torch_tpu_available():
            sampler = SequentialDistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal())
        elif self.args.local_rank != -1:
            sampler = SequentialDistributedSampler(test_dataset)
        else:
            sampler = SequentialSampler(test_dataset)

        # We use the same batch_size as for eval.
        data_loader = DataLoader(
            test_dataset,
            sampler=sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )
        return data_loader

    def get_optimizers(
        self, num_training_steps: int
    ) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
        """
        Return the optimizer and learning-rate scheduler to train with.

        If a ``(optimizer, scheduler)`` pair was passed to the Trainer via ``optimizers``,
        it is returned unchanged. Otherwise builds ``AdamW`` with ``args.weight_decay``
        applied to every parameter except biases and ``LayerNorm.weight``, paired with a
        linear warmup-then-decay schedule over ``num_training_steps``. Override this
        method in a subclass for a different setup.
        """
        if self.optimizers is not None:
            return self.optimizers

        no_decay = ["bias", "LayerNorm.weight"]
        decay_params, no_decay_params = [], []
        for name, param in self.model.named_parameters():
            if any(marker in name for marker in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optimizer = AdamW(
            [
                {"params": decay_params, "weight_decay": self.args.weight_decay},
                {"params": no_decay_params, "weight_decay": 0.0},
            ],
            lr=self.args.learning_rate,
            eps=self.args.adam_epsilon,
        )
        lr_schedule = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.args.warmup_steps,
            num_training_steps=num_training_steps,
        )
        return optimizer, lr_schedule

    def _setup_wandb(self):
        """
        Initialise the optional Weights & Biases (`wandb`) integration on the world master.

        One can override this method to customize the setup if needed. Find more information at https://docs.wandb.com/huggingface

        Environment:
            WANDB_PROJECT:
                (Optional): str - project name passed to ``wandb.init``; "huggingface" by default.
            WANDB_WATCH:
                (Optional, ["gradients", "all", "false"]) "gradients" by default; "false" skips
                ``wandb.watch`` entirely, "all" logs gradients and parameters.
            WANDB_DISABLED:
                (Optional): boolean - mentioned in the log message; not read by this method
                (presumably honoured by ``is_wandb_available`` — verify).
        """
        if not self.is_world_master():
            return

        logger.info(
            'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
        )
        wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"),
                   config=vars(self.args))

        # Model topology / gradient tracking is unsupported on TPU or when explicitly disabled.
        if is_torch_tpu_available() or os.getenv("WANDB_WATCH") == "false":
            return
        wandb.watch(self.model,
                    log=os.getenv("WANDB_WATCH", "gradients"),
                    log_freq=max(100, self.args.logging_steps))

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its Dataset.
        """
        return len(dataloader.dataset)

    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Args:
            model_path (:obj:`str`, `optional`):
                Local path to the model if the model to train has been instantiated from a local path. If present,
                training will resume from the optimizer/scheduler states loaded here.
        """
        train_dataloader = self.get_train_dataloader()
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (self.args.max_steps //
                                (len(train_dataloader) //
                                 self.args.gradient_accumulation_steps) + 1)
        else:
            t_total = int(
                len(train_dataloader) //
                self.args.gradient_accumulation_steps *
                self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (model_path is not None
                and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
                and os.path.isfile(os.path.join(model_path, "scheduler.pt"))):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"),
                           map_location=self.args.device))
            scheduler.load_state_dict(
                torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        if self.args.fp16:
            if not is_apex_available():
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(),
                                       metric_dict={})

        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size(
            )
        else:
            total_train_batch_size = (self.args.train_batch_size *
                                      self.args.gradient_accumulation_steps *
                                      (torch.distributed.get_world_size()
                                       if self.args.local_rank != -1 else 1))
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d",
                    self.args.per_device_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d",
                    self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = self.global_step // (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = self.global_step % (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)

                logger.info(
                    "  Continuing training from checkpoint, will skip to saved global_step"
                )
                logger.info("  Continuing training from epoch %d",
                            epochs_trained)
                logger.info("  Continuing training from global step %d",
                            self.global_step)
                logger.info(
                    "  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        logging_loss = 0.0
        model.zero_grad()
        train_iterator = trange(epochs_trained,
                                int(num_train_epochs),
                                desc="Epoch",
                                disable=not self.is_local_master())
        for epoch in train_iterator:
            if isinstance(train_dataloader, DataLoader) and isinstance(
                    train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(
                    train_dataloader,
                    [self.args.device]).per_device_loader(self.args.device)
                epoch_iterator = tqdm(parallel_loader,
                                      desc="Iteration",
                                      disable=not self.is_local_master())
            else:
                epoch_iterator = tqdm(train_dataloader,
                                      desc="Iteration",
                                      disable=not self.is_local_master())

            # Reset the past mems state at the beginning of each epoch if necessary.
            if self.args.past_index >= 0:
                self._past = None

            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                tr_loss += self._training_step(model, inputs, optimizer)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                        # last step in epoch but step is always smaller than gradient_accumulation_steps
                        len(epoch_iterator) <=
                        self.args.gradient_accumulation_steps and
                    (step + 1) == len(epoch_iterator)):
                    if self.args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer),
                            self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       self.args.max_grad_norm)

                    if is_torch_tpu_available():
                        xm.optimizer_step(optimizer)
                    else:
                        optimizer.step()

                    scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0
                            and self.global_step % self.args.logging_steps
                            == 0) or (self.global_step == 1
                                      and self.args.logging_first_step):
                        logs: Dict[str, float] = {}
                        logs["loss"] = (tr_loss -
                                        logging_loss) / self.args.logging_steps
                        # backward compatibility for pytorch schedulers
                        logs["learning_rate"] = (
                            scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >=
                            version.parse("1.4") else scheduler.get_lr()[0])
                        logging_loss = tr_loss

                        self._log(logs)

                    if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
                        self.evaluate()

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
                        # to the model we want to save.
                        if hasattr(model, "module"):
                            assert model.module is self.model
                        else:
                            assert model is self.model
                        # Save model checkpoint
                        output_dir = os.path.join(
                            self.args.output_dir,
                            f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")

                        self.save_model(output_dir)

                        if self.is_world_master():
                            self._rotate_checkpoints()

                        if is_torch_tpu_available():
                            xm.rendezvous("saving_optimizer_states")
                            xm.save(optimizer.state_dict(),
                                    os.path.join(output_dir, "optimizer.pt"))
                            xm.save(scheduler.state_dict(),
                                    os.path.join(output_dir, "scheduler.pt"))
                        elif self.is_world_master():
                            torch.save(
                                optimizer.state_dict(),
                                os.path.join(output_dir, "optimizer.pt"))
                            torch.save(
                                scheduler.state_dict(),
                                os.path.join(output_dir, "scheduler.pt"))

                if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                train_iterator.close()
                break
            if self.args.tpu_metrics_debug or self.args.debug:
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())

        if self.tb_writer:
            self.tb_writer.close()
        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info(
            "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
        )
        return TrainOutput(self.global_step, tr_loss / self.global_step)

    def _log(self,
             logs: Dict[str, float],
             iterator: Optional[tqdm] = None) -> None:
        """Record ``logs`` for the current global step.

        Adds ``epoch`` to ``logs`` (mutating the caller's dict) when an epoch is
        known, writes int/float values to TensorBoard, sends the dict to
        Weights & Biases on the world-master process, and finally emits the
        entry (with ``step`` added) through ``iterator`` or the module logger.

        Args:
            logs: metric name -> value. Values that are not int/float are
                skipped for TensorBoard with a warning.
            iterator: optional tqdm progress bar whose ``write`` is used
                instead of the logger.
        """
        if self.epoch is not None:
            logs["epoch"] = self.epoch
        if self.global_step is None:
            # when logging evaluation metrics without training
            self.global_step = 0
        if self.tb_writer:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, self.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        '"%s" of type %s for key "%s" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )
            self.tb_writer.flush()
        if is_wandb_available():
            if self.is_world_master():
                wandb.log(logs, step=self.global_step)
        output = {**logs, "step": self.global_step}
        if iterator is not None:
            # tqdm.write() forwards its argument to a text stream's write(),
            # which only accepts str; passing the dict itself raised TypeError.
            iterator.write(str(output))
        else:
            logger.info(output)

    def _training_step(self, model: nn.Module,
                       inputs: Dict[str, Union[torch.Tensor, Any]],
                       optimizer: torch.optim.Optimizer) -> float:
        model.train()
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.args.device)

        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past
        # Our model outputs do not work with DataParallel, so forcing return tuple.
        if isinstance(model, nn.DataParallel):
            inputs["return_tuple"] = True

        outputs = model(**inputs)
        loss = outputs[
            0]  # model outputs are always tuple in transformers (see doc)

        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if self.args.n_gpu > 1:
            loss = loss.mean(
            )  # mean() to average on multi-gpu parallel training
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss.item()

    def is_local_master(self) -> bool:
        """Return True for the main process on this machine.

        On TPU this is the local master ordinal; otherwise it is any process
        whose ``local_rank`` is -1 (not distributed) or 0.
        """
        if not is_torch_tpu_available():
            return self.args.local_rank in (-1, 0)
        return xm.is_master_ordinal(local=True)

    def is_world_master(self) -> bool:
        """
        This will be True only in one process, even in distributed mode,
        even when training on multiple machines.

        On TPU the global master ordinal is used; otherwise the process is the
        world master when not distributed (``local_rank == -1``) or when its
        ``torch.distributed`` rank is 0.
        """
        if not is_torch_tpu_available():
            not_distributed = self.args.local_rank == -1
            return not_distributed or torch.distributed.get_rank() == 0
        return xm.is_master_ordinal(local=False)

    def save_model(self, output_dir: Optional[str] = None):
        """
        Will save the model, so you can reload it using :obj:`from_pretrained()`.

        Will only save from the world_master process (unless in TPUs).

        Args:
            output_dir: target directory, forwarded to the TPU or regular
                save helper.
        """
        if is_torch_tpu_available():
            # Every TPU process takes part in the save.
            self._save_tpu(output_dir)
            return
        if self.is_world_master():
            self._save(output_dir)

    def _save_tpu(self, output_dir: Optional[str] = None):
        """Save the model and training arguments from a TPU process.

        The master ordinal creates ``output_dir`` (default
        ``self.args.output_dir``) and stores ``training_args.bin``. Every
        process then calls ``xm.rendezvous`` before ``save_pretrained``.

        Raises:
            ValueError: if ``self.model`` is not a ``PreTrainedModel``.
        """
        target = self.args.output_dir if output_dir is None else output_dir
        logger.info("Saving model checkpoint to %s", target)

        if xm.is_master_ordinal():
            os.makedirs(target, exist_ok=True)
            torch.save(self.args, os.path.join(target, "training_args.bin"))

        # Save a trained model and configuration using `save_pretrained()`,
        # so it can be reloaded with `from_pretrained()`.
        model = self.model
        if not isinstance(model, PreTrainedModel):
            raise ValueError(
                "Trainer.model appears to not be a PreTrainedModel")

        xm.rendezvous("saving_checkpoint")
        model.save_pretrained(target)

    def _save(self, output_dir: Optional[str] = None):
        """Write the model with ``save_pretrained`` plus ``training_args.bin``.

        Args:
            output_dir: destination directory. Defaults to
                ``self.args.output_dir`` and is created if missing.

        Raises:
            ValueError: if ``self.model`` is not a ``PreTrainedModel``.
        """
        target = self.args.output_dir if output_dir is None else output_dir
        os.makedirs(target, exist_ok=True)
        logger.info("Saving model checkpoint to %s", target)

        # Only save_pretrained() output can be reloaded with from_pretrained().
        model = self.model
        if not isinstance(model, PreTrainedModel):
            raise ValueError(
                "Trainer.model appears to not be a PreTrainedModel")
        model.save_pretrained(target)

        # Keep the training arguments next to the weights they produced.
        args_file = os.path.join(target, "training_args.bin")
        torch.save(self.args, args_file)

    def _sorted_checkpoints(self,
                            checkpoint_prefix=PREFIX_CHECKPOINT_DIR,
                            use_mtime=False) -> List[str]:
        """List ``<output_dir>/<checkpoint_prefix>-*`` paths, oldest first.

        Args:
            checkpoint_prefix: name prefix of checkpoint entries.
            use_mtime: order by modification time instead of by the integer
                step number that follows ``<checkpoint_prefix>-``.

        Returns:
            Paths as strings in ascending order. When ordering by step, paths
            without a numeric suffix are dropped.
        """
        # Escape the prefix so characters such as "." or "+" in it are matched
        # literally rather than as regex operators.
        step_pattern = re.compile(f".*{re.escape(checkpoint_prefix)}-([0-9]+)")

        ordering_and_checkpoint_path = []
        candidates = (
            str(p)
            for p in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*"))
        for path in candidates:
            if use_mtime:
                ordering_and_checkpoint_path.append(
                    (os.path.getmtime(path), path))
                continue
            match = step_pattern.match(path)
            if match:
                ordering_and_checkpoint_path.append(
                    (int(match.group(1)), path))

        return [path for _, path in sorted(ordering_and_checkpoint_path)]

    def _rotate_checkpoints(self, use_mtime=False) -> None:
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        number_of_checkpoints_to_delete = max(
            0,
            len(checkpoints_sorted) - self.args.save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:
                                                       number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info(
                "Deleting older checkpoint [{}] due to args.save_total_limit".
                format(checkpoint))
            shutil.rmtree(checkpoint)

    def evaluate(self,
                 eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
        """
        Run evaluation and return the resulting metrics.

        Metrics are task-dependent and come from the :obj:`compute_metrics`
        callback given at init, together with the evaluation loss. They are
        also passed to ``_log`` before being returned.

        Args:
            eval_dataset (:obj:`Dataset`, `optional`):
                Dataset to evaluate instead of :obj:`self.eval_dataset`.
        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
        """
        loader = self.get_eval_dataloader(eval_dataset)
        metrics = self._prediction_loop(loader,
                                        description="Evaluation").metrics
        self._log(metrics)

        debug_requested = self.args.tpu_metrics_debug or self.args.debug
        if debug_requested:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        return metrics

    def predict(self, test_dataset: Dataset) -> PredictionOutput:
        """
        Run prediction and return predictions and potential metrics.

        If ``test_dataset`` carries labels, metrics are computed as in
        :obj:`evaluate()`.

        Args:
            test_dataset (:obj:`Dataset`):
                Dataset to run the predictions on.
        Returns:
            `NamedTuple`:
            predictions (:obj:`np.ndarray`):
                The predictions on :obj:`test_dataset`.
            label_ids (:obj:`np.ndarray`, `optional`):
                The labels (if the dataset contained some).
            metrics (:obj:`Dict[str, float]`, `optional`):
                The potential dictionary of metrics (if the dataset contained labels).
        """
        loader = self.get_test_dataloader(test_dataset)
        prediction = self._prediction_loop(loader, description="Prediction")
        return prediction

    def _prediction_loop(
            self,
            dataloader: DataLoader,
            description: str,
            prediction_loss_only: Optional[bool] = None) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.

        Works both with or without labels.

        Every batch is moved to ``self.args.device`` and run under
        ``torch.no_grad()``. When a batch has a label key, the mean of its
        loss is recorded. Unless ``prediction_loss_only`` is true, the logits
        and any ``labels`` entry are concatenated along dim 0. In distributed
        mode (``local_rank != -1``) or on TPU, the concatenated
        predictions/labels are gathered from all processes before being
        converted to numpy.

        Args:
            dataloader: yields input dictionaries for the model.
            description: name used in the log header and progress bar.
            prediction_loss_only: overrides ``self.prediction_loss_only`` when
                not None. If true, logits and labels are not accumulated, so
                ``compute_metrics`` is skipped and only ``eval_loss`` can appear.

        Returns:
            PredictionOutput with ``predictions`` and ``label_ids`` as numpy
            arrays (or None) and ``metrics``, whose keys all start with
            ``eval_``.
        """

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        else:
            model = self.model
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        batch_size = dataloader.batch_size
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", self.num_examples(dataloader))
        logger.info("  Batch size = %d", batch_size)
        eval_losses: List[float] = []
        preds: Optional[torch.Tensor] = None
        label_ids: Optional[torch.Tensor] = None
        model.eval()

        if is_torch_tpu_available():
            # TPU: iterate through torch_xla's per-device loader for self.args.device.
            dataloader = pl.ParallelLoader(
                dataloader,
                [self.args.device]).per_device_loader(self.args.device)

        if self.args.past_index >= 0:
            # Past-state output of the previous batch, passed back in as `mems`.
            past = None

        for inputs in tqdm(dataloader, desc=description):
            # Outputs are unpacked as (loss, logits, ...) only when one of these
            # label keys is present; otherwise the first output is the logits.
            has_labels = any(
                inputs.get(k) is not None
                for k in ["labels", "lm_labels", "masked_lm_labels"])

            for k, v in inputs.items():
                if isinstance(v, torch.Tensor):
                    inputs[k] = v.to(self.args.device)
            if self.args.past_index >= 0:
                inputs["mems"] = past
            # Our model outputs do not work with DataParallel, so forcing return tuple.
            if isinstance(model, nn.DataParallel):
                inputs["return_tuple"] = True

            with torch.no_grad():
                outputs = model(**inputs)
                if has_labels:
                    step_eval_loss, logits = outputs[:2]
                    # .mean() collapses a non-scalar loss (e.g. one value per
                    # replica under DataParallel) to a single number.
                    eval_losses += [step_eval_loss.mean().item()]
                else:
                    logits = outputs[0]
                if self.args.past_index >= 0:
                    # Without a leading loss element, the past output sits one
                    # index earlier in the output tuple.
                    past = outputs[self.args.past_index if has_labels else self
                                   .args.past_index - 1]

            if not prediction_loss_only:
                # Accumulate along dim 0 (presumably the batch dimension).
                if preds is None:
                    preds = logits.detach()
                else:
                    preds = torch.cat((preds, logits.detach()), dim=0)
                if inputs.get("labels") is not None:
                    if label_ids is None:
                        label_ids = inputs["labels"].detach()
                    else:
                        label_ids = torch.cat(
                            (label_ids, inputs["labels"].detach()), dim=0)

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            # NOTE(review): eval_losses is not gathered here, so eval_loss below
            # only covers this process's batches — confirm this is intended.
            if preds is not None:
                preds = self.distributed_concat(
                    preds, num_total_examples=self.num_examples(dataloader))
            if label_ids is not None:
                label_ids = self.distributed_concat(
                    label_ids,
                    num_total_examples=self.num_examples(dataloader))
        elif is_torch_tpu_available():
            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
            if preds is not None:
                preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
            if label_ids is not None:
                label_ids = xm.mesh_reduce("eval_label_ids", label_ids,
                                           torch.cat)

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = preds.cpu().numpy()
        if label_ids is not None:
            label_ids = label_ids.cpu().numpy()

        # compute_metrics needs both predictions and labels.
        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(
                EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            metrics["eval_loss"] = np.mean(eval_losses)

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds,
                                label_ids=label_ids,
                                metrics=metrics)

    def distributed_concat(self, tensor: torch.Tensor,
                           num_total_examples: int) -> torch.Tensor:
        """Gather ``tensor`` from every process and concatenate along dim 0.

        Only valid in distributed mode (``local_rank != -1``). The result is
        cut to ``num_total_examples`` rows, dropping the dummy elements added
        by SequentialDistributedSampler.
        """
        assert self.args.local_rank != -1

        world_size = torch.distributed.get_world_size()
        gathered = [tensor.clone() for _ in range(world_size)]
        torch.distributed.all_gather(gathered, tensor)

        # truncate the dummy elements added by SequentialDistributedSampler
        return torch.cat(gathered, dim=0)[:num_total_examples]
Exemplo n.º 26
0
def test_psrl(args=None):
    """Train a PSRL policy on ``args.task`` and check its reward.

    When run as a script, the trained policy is also rolled out on the test
    environments and the final reward/length are printed. Otherwise the best
    reward reached during training is asserted to meet the environment's
    reward threshold (if it has one).

    Args:
        args: parsed options namespace. Defaults to ``get_args()``, evaluated
            at call time. This function writes ``state_shape`` and
            ``action_shape`` onto it, so a namespace created once at import time
            as a default would be shared and mutated across calls.
    """
    if args is None:
        args = get_args()
    env = gym.make(args.task)
    if args.task == "NChain-v0":
        env.spec.reward_threshold = 3647  # described in PSRL paper
    print("reward threshold:", env.spec.reward_threshold)
    # Fall back to the space's `.n` when its `.shape` is falsy (discrete spaces).
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model: tabular priors over (state, action[, next_state])
    n_action = args.action_shape
    n_state = args.state_shape
    trans_count_prior = np.ones((n_state, n_action, n_state))
    rew_mean_prior = np.full((n_state, n_action), args.rew_mean_prior)
    rew_std_prior = np.full((n_state, n_action), args.rew_std_prior)
    policy = PSRLPolicy(trans_count_prior, rew_mean_prior, rew_std_prior,
                        args.gamma, args.eps, args.add_done_loop)
    # collector
    train_collector = Collector(policy,
                                train_envs,
                                VectorReplayBuffer(args.buffer_size,
                                                   len(train_envs)),
                                exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    # log
    log_path = os.path.join(args.logdir, args.task, 'psrl')
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))

    def stop_fn(mean_rewards):
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        else:
            return False

    train_collector.collect(n_step=args.buffer_size, random=True)
    # trainer, test it without logger
    result = onpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        1,
        args.test_num,
        0,
        episode_per_collect=args.episode_per_collect,
        stop_fn=stop_fn,
        test_in_train=False)

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        policy.eval()
        test_envs.seed(args.seed)
        test_collector.reset()
        result = test_collector.collect(n_episode=args.test_num,
                                        render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
    elif env.spec.reward_threshold:
        assert result["best_reward"] >= env.spec.reward_threshold
class OrtTrainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch,
    optimized for Transformers.
    """

    model: PreTrainedModel
    args: TrainingArguments
    data_collator: DataCollator
    train_dataset: Optional[Dataset]
    eval_dataset: Optional[Dataset]
    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
    prediction_loss_only: bool
    tb_writer: Optional["SummaryWriter"] = None

    def __init__(
        self,
        model: PreTrainedModel,
        args: TrainingArguments,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        prediction_loss_only=False,
    ):
        """
        Set up an onnxruntime-backed trainer for ``model``.

        Args:
            model: the transformers model; converted to an ``ORTTrainer``
                (stored as ``self.ort_model``) at the end of construction.
            args: training arguments (seed, output/logging dirs, local_rank, ...).
            data_collator: batch collator; ``DefaultDataCollator()`` if omitted.
            train_dataset: dataset used by ``get_train_dataloader``.
            eval_dataset: default dataset used by ``get_eval_dataloader``.
            compute_metrics: optional callback mapping an ``EvalPrediction``
                to a metrics dict.
            prediction_loss_only:
                (Optional) in evaluation and prediction, only return the loss
        """
        self.model = model
        self.args = args
        self.data_collator = (data_collator if data_collator is not None
                              else DefaultDataCollator())
        self.train_dataset, self.eval_dataset = train_dataset, eval_dataset
        self.compute_metrics = compute_metrics
        self.prediction_loss_only = prediction_loss_only

        # The TensorBoard writer is only created on the world-master process.
        if is_tensorboard_available() and self.is_world_master():
            self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
        if not is_tensorboard_available():
            logger.warning(
                "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
            )

        seed = self.args.seed
        set_seed(seed)
        onnxruntime.set_seed(seed)

        # Create output directory if needed
        if self.is_world_master():
            os.makedirs(self.args.output_dir, exist_ok=True)

        torch.cuda.set_device(self.args.local_rank)

        self.ort_model = self.to_ort_model(model, model.config, args)

    def update_torch_model(self):
        """Copy the ORT model's weights back into ``self.model``.

        Uses ``load_state_dict(..., strict=False)``. If there is no ORT model,
        only a warning is logged.
        """
        ort_model = self.ort_model
        if not ort_model:
            logger.warning(
                "No ORT model found to update weights from, assuming torch model is up to date."
            )
            return
        logger.info("Updating weights of torch model from ORT model.")
        self.model.load_state_dict(ort_model.state_dict(), strict=False)

    def gpt2_model_description(self, n_head, vocab_size, n_hidden, n_layer,
                               n_ctx, batch_size):
        """Describe the GPT-2 training graph's inputs and outputs for ORTTrainer.

        Args:
            n_head: number of attention heads (logged only).
            vocab_size: vocabulary size; ``num_classes`` of both inputs.
            n_hidden: hidden width (``to_ort_model`` passes ``config.n_embd``);
                logged only.
            n_layer: number of layers (logged only).
            n_ctx: sequence length; second dimension of both inputs.
            batch_size: first dimension of both inputs.

        Returns:
            ModelDescription with int64 ``input_ids`` and ``labels`` of shape
            ``[batch_size, n_ctx]`` and a float32 ``loss`` of shape ``[]``.
        """
        logger.info("****num of head is: %s", n_head)
        logger.info("****vocab size is: %s", vocab_size)
        # n_hidden is a width, not a layer count; it used to be logged as
        # "num of hidden layer", which was misleading next to n_layer.
        logger.info("****hidden size is: %s", n_hidden)
        logger.info("****num of layer is: %s", n_layer)
        logger.info("****seq length is: %s", n_ctx)

        input_ids_desc = IODescription('input_ids', [batch_size, n_ctx],
                                       torch.int64,
                                       num_classes=vocab_size)
        labels_desc = IODescription('labels', [batch_size, n_ctx],
                                    torch.int64,
                                    num_classes=vocab_size)

        loss_desc = IODescription('loss', [], torch.float32)

        return ModelDescription([input_ids_desc, labels_desc], [loss_desc])

    def ort_trainer_learning_rate_description(self):
        """Describe the learning-rate input: a float32 tensor of shape ``[1]``."""
        lr_shape = [1]
        return IODescription('Learning_Rate', lr_shape, torch.float32)

    def to_ort_model(self, model, config, args):
        """Wrap ``model`` in an onnxruntime ``ORTTrainer`` set up for GPT-2.

        The graph description is built from ``config`` (``n_head``,
        ``vocab_size``, ``n_embd``, ``n_layer``, ``n_ctx``) and
        ``args.per_gpu_train_batch_size``. Adam attributes are assigned per
        parameter name: no weight decay when the name contains bias / gamma /
        beta / LayerNorm. The onnxruntime CUDA device is set to
        ``self.args.local_rank``.

        Returns:
            The constructed ``ORTTrainer``.
        """
        model_desc = self.gpt2_model_description(
            config.n_head, config.vocab_size, config.n_embd, config.n_layer,
            config.n_ctx, args.per_gpu_train_batch_size)
        lr_desc = self.ort_trainer_learning_rate_description()

        no_decay_keys = ("bias", "gamma", "beta", "LayerNorm")

        def map_optimizer_attributes(name):
            skip_decay = any(key in name for key in no_decay_keys)
            return {
                "alpha": 0.9,
                "beta": 0.999,
                "lambda": 0.0 if skip_decay else args.weight_decay,
                "epsilon": args.adam_epsilon
            }

        from onnxruntime.capi._pybind_state import set_cuda_device_id, set_arena_extend_strategy, ArenaExtendStrategy
        set_arena_extend_strategy(ArenaExtendStrategy.kSameAsRequested)
        set_cuda_device_id(self.args.local_rank)

        ort_model = ORTTrainer(
            model,
            None,
            model_desc,
            "AdamOptimizer",
            map_optimizer_attributes,
            lr_desc,
            args.device,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            world_rank=self.args.world_rank,
            world_size=self.args.world_size,
            use_mixed_precision=self.args.fp16,
            allreduce_post_accumulation=True,
            _opset_version=12)

        logger.info("****************************Model converted to ORT")
        return ort_model

    def get_train_dataloader(self) -> DataLoader:
        """Build the training DataLoader.

        Samples randomly in single-process mode (``local_rank == -1``) and with
        a DistributedSampler otherwise. Batches hold
        ``args.per_gpu_train_batch_size`` examples and are collated by
        ``data_collator.collate_batch``.

        Raises:
            ValueError: if no ``train_dataset`` was provided.
        """
        dataset = self.train_dataset
        if dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        if self.args.local_rank == -1:
            sampler = RandomSampler(dataset)
        else:
            sampler = DistributedSampler(dataset)
        return DataLoader(
            dataset,
            batch_size=self.args.per_gpu_train_batch_size,
            sampler=sampler,
            collate_fn=self.data_collator.collate_batch,
        )

    def get_eval_dataloader(self,
                            eval_dataset: Optional[Dataset] = None
                            ) -> DataLoader:
        """Build an ordered (non-shuffled) DataLoader for evaluation.

        Args:
            eval_dataset: dataset to use instead of ``self.eval_dataset``.

        Raises:
            ValueError: if neither dataset is available.
        """
        dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        if dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        return DataLoader(
            dataset,
            batch_size=self.args.eval_batch_size,
            shuffle=False,
            collate_fn=self.data_collator.collate_batch,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """Build an ordered DataLoader for ``test_dataset``.

        The evaluation batch size (``args.eval_batch_size``) is reused.
        """
        loader_kwargs = dict(
            batch_size=self.args.eval_batch_size,
            shuffle=False,
            collate_fn=self.data_collator.collate_batch,
        )
        return DataLoader(test_dataset, **loader_kwargs)

    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Args:
            model_path:
                (Optional) Local path to model if model to train has been instantiated from a local path
                If present, we will try reloading the optimizer/scheduler states from there.
        """
        train_dataloader = self.get_train_dataloader()

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (self.args.max_steps //
                                (len(train_dataloader) //
                                 self.args.gradient_accumulation_steps) + 1)
        else:
            t_total = int(
                len(train_dataloader) //
                self.args.gradient_accumulation_steps *
                self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        scheduler = linear_schedule_with_warmup(
            num_warmup_steps=self.args.warmup_steps,
            num_training_steps=t_total)

        loss_scaler = LossScaler(
            self.ort_model.loss_scale_input_name,
            True,
            up_scale_window=2000,
            loss_scale=float(1 << 20)) if self.args.fp16 else 1

        model = self.ort_model

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())

        # Train!
        if self.is_world_master():
            logger.info("***** Running training *****")
            logger.info("  Num examples = %d", len(train_dataloader.dataset))
            logger.info("  Num Epochs = %d", num_train_epochs)
            logger.info("  Instantaneous batch size per GPU = %d",
                        self.args.per_gpu_train_batch_size)
            logger.info(
                "  Total train batch size (w. parallel, distributed & accumulation) = %d",
                self.args.train_batch_size *
                self.args.gradient_accumulation_steps *
                (self.args.world_size if self.args.local_rank != -1 else 1),
            )
            logger.info("  Gradient Accumulation steps = %d",
                        self.args.gradient_accumulation_steps)
            logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = global_step // (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = global_step % (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)

                logger.info(
                    "  Continuing training from checkpoint, will skip to saved global_step"
                )
                logger.info("  Continuing training from epoch %d",
                            epochs_trained)
                logger.info("  Continuing training from global step %d",
                            global_step)
                logger.info(
                    "  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)
            except ValueError:
                global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        logging_loss = 0.0
        global_batch_train_start = time.time()

        train_iterator = trange(
            epochs_trained,
            int(num_train_epochs),
            desc="Epoch",
            disable=self.args.local_rank not in [-1, 0],
        )
        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader,
                                  desc="Iteration",
                                  disable=self.args.local_rank not in [-1, 0])
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                if len(inputs['input_ids']
                       ) < self.args.per_gpu_train_batch_size:
                    #skip incomplete batch
                    logger.info('Skipping incomplete batch...')
                    continue

                learning_rate = torch.tensor([
                    scheduler.get_lr_this_step(global_step,
                                               base_lr=self.args.learning_rate)
                ])
                loss, all_finite = self._training_step(model, inputs,
                                                       learning_rate,
                                                       loss_scaler)
                tr_loss += loss
                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                        # last step in epoch but step is always smaller than gradient_accumulation_steps
                        len(epoch_iterator) <=
                        self.args.gradient_accumulation_steps and
                    (step + 1) == len(epoch_iterator)):

                    if self.args.fp16:
                        loss_scaler.update_loss_scale(all_finite.item())

                    global_step += 1
                    global_batch_train_duration = time.time(
                    ) - global_batch_train_start
                    global_batch_train_start = time.time()

                    if self.args.local_rank in [-1, 0]:
                        if (self.args.logging_steps > 0
                                and global_step % self.args.logging_steps
                                == 0) or (global_step == 1
                                          and self.args.logging_first_step):
                            logs = {}
                            loss_avg = (tr_loss - logging_loss) / (
                                self.args.logging_steps *
                                self.args.gradient_accumulation_steps)
                            logs["learning_rate"] = learning_rate.item()
                            logs["loss"] = loss_avg
                            logs["global_step"] = global_step
                            logs[
                                "global_step_time"] = global_batch_train_duration
                            logging_loss = tr_loss

                            if self.tb_writer:
                                for k, v in logs.items():
                                    self.tb_writer.add_scalar(
                                        k, v, global_step)
                                    run.log(k, v)
                            epoch_iterator.write(
                                json.dumps({
                                    **logs,
                                    **{
                                        "step": global_step
                                    }
                                }))

                        if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
                            # In all cases (even distributed/parallel), self.model is always a reference
                            # to the model we want to save.
                            if hasattr(model, "module"):
                                assert model.module is self.ort_model
                            else:
                                assert model is self.ort_model
                            # Save model checkpoint
                            output_dir = os.path.join(
                                self.args.output_dir,
                                f"{PREFIX_CHECKPOINT_DIR}-{global_step}")
                            self.save_model(output_dir)
                            # self._rotate_checkpoints()

                if self.args.max_steps > 0 and global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and global_step > self.args.max_steps:
                train_iterator.close()
                break

        if self.tb_writer:
            self.tb_writer.close()
        self.update_torch_model()
        del (self.ort_model)
        self.ort_model = None

        logger.info(
            "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
        )
        return TrainOutput(global_step, tr_loss / global_step)

    def _training_step(self, model: nn.Module, inputs: Dict[str, torch.Tensor],
                       learning_rate, loss_scaler) -> float:
        model.train()

        if self.args.fp16:
            loss_scale = torch.tensor([loss_scaler.loss_scale_])
            result = model(inputs['input_ids'], inputs['labels'],
                           learning_rate, loss_scale)
        else:
            result = model(inputs['input_ids'], inputs['labels'],
                           learning_rate)

        all_finite = None
        if isinstance(result, (list, tuple)):
            loss = result[0]
            all_finite = result[-1]
        else:
            loss = result

        return loss.item(), all_finite

    def is_world_master(self) -> bool:
        """
        Return True in exactly one process of the job, whether it runs on a
        single process, distributed, or across multiple machines.

        A non-distributed run (``local_rank == -1``) is always the master;
        otherwise only the process whose global rank is 0 is.
        """
        if self.args.local_rank == -1:
            return True
        return torch.distributed.get_rank() == 0

    def save_model(self, output_dir: Optional[str] = None):
        """
        Save the model so it can be reloaded with ``from_pretrained()``
        (when default model names are used).

        Only the master process writes; all other ranks return immediately.
        ``output_dir`` is passed through unchanged; ``_save`` substitutes
        ``self.args.output_dir`` when it is None.
        """
        if not self.is_world_master():
            return
        self._save(output_dir)

    def _save(self, output_dir: Optional[str] = None):
        """Write the model and the training arguments to ``output_dir``.

        ``output_dir`` defaults to ``self.args.output_dir`` and is created if
        missing. ``self.update_torch_model()`` is called first (presumably to
        sync weights back from the ORT model — confirm), then the model is
        written with ``save_pretrained()`` so ``from_pretrained()`` can reload it.
        The args are written next to it as ``training_args.bin``.

        Raises:
            ValueError: if ``self.model`` is not a ``PreTrainedModel``.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)

        self.update_torch_model()
        model_to_save = self.model
        if not isinstance(model_to_save, PreTrainedModel):
            raise ValueError(
                "Trainer.model appears to not be a PreTrainedModel")
        model_to_save.save_pretrained(output_dir)

        # Keep the exact training arguments alongside the weights.
        args_path = os.path.join(output_dir, "training_args.bin")
        torch.save(self.args, args_path)

    def _sorted_checkpoints(self,
                            checkpoint_prefix=PREFIX_CHECKPOINT_DIR,
                            use_mtime=False) -> List[str]:
        """Return checkpoint paths under ``self.args.output_dir``, oldest first.

        Candidates are entries matching the glob ``f"{checkpoint_prefix}-*"``.

        Args:
            checkpoint_prefix: directory-name prefix of checkpoints.
            use_mtime: when True, order by file modification time. Otherwise
                order by the integer step number that follows
                ``f"{checkpoint_prefix}-"``; entries without such a number are
                dropped.

        Returns:
            List of path strings, in ascending order of the chosen key.
        """
        # Escape the prefix so any regex metacharacters it contains (e.g. '.',
        # '+', '(') are matched literally instead of being parsed as a pattern.
        step_pattern = re.compile(
            rf".*{re.escape(checkpoint_prefix)}-([0-9]+)")

        keyed_paths = []
        for checkpoint_path in Path(
                self.args.output_dir).glob(f"{checkpoint_prefix}-*"):
            path = str(checkpoint_path)
            if use_mtime:
                keyed_paths.append((os.path.getmtime(path), path))
                continue
            match = step_pattern.match(path)
            if match:
                keyed_paths.append((int(match.group(1)), path))

        return [path for _, path in sorted(keyed_paths)]

    def _rotate_checkpoints(self, use_mtime=False) -> None:
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        number_of_checkpoints_to_delete = max(
            0,
            len(checkpoints_sorted) - self.args.save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:
                                                       number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info(
                "Deleting older checkpoint [{}] due to args.save_total_limit".
                format(checkpoint))
            shutil.rmtree(checkpoint)

    def evaluate(
            self,
            eval_dataset: Optional[Dataset] = None,
            prediction_loss_only: Optional[bool] = None) -> Dict[str, float]:
        """
        Run evaluation and return metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are
        task-dependent.

        Args:
            eval_dataset: (Optional) Pass a dataset if you wish to override
                the one on the instance.
            prediction_loss_only: (Optional) Per-call override of
                ``self.prediction_loss_only``. When truthy, only losses are
                gathered (no predictions/labels are collected). ``None`` keeps
                the instance setting.
        Returns:
            A dict containing:
                - the eval loss (when the batches carry labels)
                - the potential metrics computed from the predictions
        """
        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        # Forward the per-call override so it actually takes effect; passing
        # None falls back to self.prediction_loss_only inside the loop.
        output = self._prediction_loop(eval_dataloader,
                                       description="Evaluation",
                                       prediction_loss_only=prediction_loss_only)
        return output.metrics

    def predict(self, test_dataset: Dataset) -> PredictionOutput:
        """
        Run prediction on ``test_dataset`` and return predictions and any metrics.

        This shares the loop used by ``evaluate()``. If the test dataset carries
        labels, the returned output may therefore also include metrics (loss
        and/or ``compute_metrics`` results).
        """
        return self._prediction_loop(self.get_test_dataloader(test_dataset),
                                     description="Prediction")

    def _prediction_loop(
            self,
            dataloader: DataLoader,
            description: str,
            prediction_loss_only: Optional[bool] = None) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.

        Works both with or without labels.

        Args:
            dataloader: yields dict batches that are moved to
                ``self.args.device`` and fed as ``model(**inputs)``.
            description: label for log messages and the progress bar.
            prediction_loss_only: when truthy, only losses are gathered and the
                returned ``predictions``/``label_ids`` stay ``None``. ``None``
                falls back to ``self.prediction_loss_only``.

        Returns:
            ``PredictionOutput`` with logits concatenated along axis 0, the
            ``labels`` concatenated along axis 0 (if present), and a metrics
            dict. The dict holds ``compute_metrics`` results when predictions
            and labels exist, plus ``"loss"`` (mean of per-batch mean losses)
            when any batch had labels.
        """
        if prediction_loss_only is None:
            prediction_loss_only = self.prediction_loss_only
        # NOTE(review): presumably refreshes self.model from the ORT model.
        self.update_torch_model()
        # multi-gpu eval
        if self.args.n_gpu > 1 and not isinstance(self.model,
                                                  torch.nn.DataParallel):
            model = torch.nn.DataParallel(self.model)
        else:
            model = self.model

        model.to(self.args.device)
        if self.is_world_master():
            logger.info("***** Running %s *****", description)
            logger.info("  Num examples = %d", len(dataloader.dataset))
            logger.info("  Batch size = %d", dataloader.batch_size)
        eval_losses: List[float] = []
        # Collect per-batch arrays and concatenate once at the end: calling
        # np.append per batch re-copies everything gathered so far (O(n^2)).
        pred_batches: List[np.ndarray] = []
        label_batches: List[np.ndarray] = []
        model.eval()

        for inputs in tqdm(dataloader, desc=description):
            has_labels = any(
                inputs.get(k) is not None
                for k in ["labels", "masked_lm_labels"])

            for k, v in inputs.items():
                inputs[k] = v.to(self.args.device)

            with torch.no_grad():
                outputs = model(**inputs)
                if has_labels:
                    # With labels the model returns (loss, logits, ...).
                    step_eval_loss, logits = outputs[:2]
                    eval_losses += [step_eval_loss.mean().item()]
                else:
                    logits = outputs[0]

            if not prediction_loss_only:
                pred_batches.append(logits.detach().cpu().numpy())
                if inputs.get("labels") is not None:
                    label_batches.append(
                        inputs["labels"].detach().cpu().numpy())

        preds: Optional[np.ndarray] = (np.concatenate(pred_batches, axis=0)
                                       if pred_batches else None)
        label_ids: Optional[np.ndarray] = (np.concatenate(label_batches,
                                                          axis=0)
                                           if label_batches else None)

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(
                EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            metrics["loss"] = np.mean(eval_losses)

        return PredictionOutput(predictions=preds,
                                label_ids=label_ids,
                                metrics=metrics)
Exemplo n.º 28
0
class TrainingProcessHandler(object):
    """Bookkeeping for a training run.

    Drives tqdm progress bars, writes TensorBoard scalars/images/figures/
    audio/text, keeps per-metric train/validation history, checkpoints the
    model when the tracked validation metric improves, and optionally reports
    to MLflow.

    Expected call order (inferred from how the callbacks use state; confirm
    with callers): ``setup_handler`` -> ``start_callback`` -> repeated
    ``iteration_callback`` / ``epoch_callback`` -> ``finish_callback``.
    """

    def __init__(self,
                 data_folder="logs",
                 model_folder="model",
                 enable_iteration_progress_bar=False,
                 model_save_key="loss",
                 mlflow_tags=None,
                 mlflow_parameters=None,
                 enable_mlflow=True,
                 mlflow_experiment_name="undeepvo"):
        """
        Args:
            data_folder: directory for TensorBoard runs and saved figure PNGs.
            model_folder: directory the checkpoint file is written to.
            enable_iteration_progress_bar: also display a per-iteration bar.
            model_save_key: validation metric consulted by
                ``should_save_model`` (lower is treated as better).
            mlflow_tags: tags forwarded to ``MlFlowHandler`` (default ``{}``).
            mlflow_parameters: parameters forwarded to ``MlFlowHandler``
                (default ``{}``).
            enable_mlflow: create an ``MlFlowHandler`` when True.
            mlflow_experiment_name: MLflow experiment name.
        """
        if mlflow_tags is None:
            mlflow_tags = {}
        if mlflow_parameters is None:
            mlflow_parameters = {}
        self._name = None
        self._epoch_count = 0
        self._iteration_count = 0
        self._current_epoch = 0
        self._current_iteration = 0
        self._log_folder = data_folder
        self._iteration_progress_bar = None
        self._enable_iteration_progress_bar = enable_iteration_progress_bar
        self._epoch_progress_bar = None
        self._writer = None
        self._train_metrics = {}
        self._model = None
        self._model_folder = model_folder
        self._run_name = ""
        self.train_history = {}
        self.validation_history = {}
        self._model_save_key = model_save_key
        self._previous_model_save_metric = None
        # makedirs(exist_ok=True) also creates missing parent directories and
        # does not fail if another process creates the folder concurrently.
        os.makedirs(self._model_folder, exist_ok=True)
        os.makedirs(self._log_folder, exist_ok=True)
        self._audio_configs = {}
        self._global_epoch_step = 0
        self._global_iteration_step = 0
        if enable_mlflow:
            self._mlflow_handler = MlFlowHandler(
                experiment_name=mlflow_experiment_name,
                mlflow_tags=mlflow_tags,
                mlflow_parameters=mlflow_parameters)
        else:
            self._mlflow_handler = None
        # Figure files written during the current epoch, handed to MLflow.
        self._artifacts = []

    def setup_handler(self, name, model):
        """Bind ``model`` and open a new TensorBoard run ``<name>_<timestamp>``.

        Also clears histories, the best-metric record and global step counters.
        """
        self._name = name
        self._run_name = name + "_" + datetime.datetime.now().strftime(
            '%Y-%m-%d-%H-%M-%S')
        self._writer = SummaryWriter(
            os.path.join(self._log_folder, self._run_name))
        self._model = model
        self._previous_model_save_metric = None
        self.train_history = {}
        self.validation_history = {}
        self._global_epoch_step = 0
        self._global_iteration_step = 0

    def start_callback(self, epoch_count, iteration_count, parameters=None):
        """Start training: store planned counts, create bars, notify MLflow.

        Args:
            epoch_count: number of epochs planned.
            iteration_count: total iterations across all epochs; the
                per-iteration bar's length is ``iteration_count // epoch_count``.
            parameters: dict forwarded to ``MlFlowHandler.start_callback``.
        """
        if parameters is None:
            parameters = {}
        self._epoch_count = epoch_count
        self._iteration_count = iteration_count
        self._current_epoch = 0
        self._current_iteration = 0
        self._epoch_progress_bar = tqdm(total=self._epoch_count)
        if self._enable_iteration_progress_bar:
            # Iterations per epoch; max(..., 1) avoids ZeroDivisionError
            # when epoch_count is 0.
            iterations_per_epoch = self._iteration_count // max(
                self._epoch_count, 1)
            self._iteration_progress_bar = tqdm(total=iterations_per_epoch)
        if self._mlflow_handler is not None:
            self._mlflow_handler.start_callback(parameters)

    def epoch_callback(self,
                       metrics,
                       image_batches=None,
                       figures=None,
                       audios=None,
                       texts=None):
        """Record end-of-epoch validation results.

        Appends ``metrics`` to ``validation_history``, logs scalars and any
        optional media to TensorBoard, advances the progress bars, saves a
        checkpoint if ``should_save_model`` says so, and notifies MLflow.
        """
        self._artifacts = []
        for key, value in metrics.items():
            self.validation_history.setdefault(key, []).append(value)
        self._write_epoch_metrics(metrics)
        if image_batches is not None:
            self._write_image_batches(image_batches)
        if figures is not None:
            self._write_figures(figures)
        if audios is not None:
            self._write_audios(audios)
        if texts is not None:
            self._write_texts(texts)
        # Reset the per-iteration bar for the next epoch, but leave it filled
        # after the final one. _current_epoch is still the 0-based index of
        # the epoch that just finished (it is incremented further below).
        is_last_epoch = self._current_epoch == self._epoch_count - 1
        if self._enable_iteration_progress_bar and not is_last_epoch:
            self._iteration_progress_bar.reset()
        self._epoch_progress_bar.update()
        self._epoch_progress_bar.set_postfix_str(
            self.metric_string("valid", metrics))
        if self.should_save_model(metrics) and self._model is not None:
            torch.save(
                self._model.state_dict(),
                os.path.join(self._model_folder,
                             f"{self._run_name}_checkpoint.pth"))
        self._current_epoch += 1
        self._global_epoch_step += 1
        if self._mlflow_handler is not None:
            self._mlflow_handler.epoch_callback(metrics, self._current_epoch,
                                                self._artifacts)

    def iteration_callback(self, metrics):
        """Record one training iteration's metrics and advance the bar."""
        for key, value in metrics.items():
            self.train_history.setdefault(key, []).append(value)
        self._train_metrics = metrics
        self._write_iteration_metrics(metrics)
        if self._enable_iteration_progress_bar:
            self._iteration_progress_bar.set_postfix_str(
                self.metric_string("train", metrics))
            self._iteration_progress_bar.update()
        self._current_iteration += 1
        self._global_iteration_step += 1

    def finish_callback(self, metrics):
        """Print final (test) metrics and close the writer, bars and MLflow run."""
        print(self.metric_string("test", metrics))
        self._writer.close()
        if self._enable_iteration_progress_bar:
            self._iteration_progress_bar.close()
        self._epoch_progress_bar.close()
        if self._mlflow_handler is not None:
            self._mlflow_handler.finish_callback()

    @staticmethod
    def metric_string(prefix, metrics):
        """Format metrics as ``"<prefix> <key> = <value:.3f>, ..."`` ("" if empty)."""
        result = ""
        for key, value in metrics.items():
            result += "{} {} = {:>3.3f}, ".format(prefix, key, value)
        # Drop the trailing ", ".
        return result[:-2]

    def _write_epoch_metrics(self, validation_metrics):
        """Log each validation metric as ``epoch/<key>`` at the epoch step."""
        for key, value in validation_metrics.items():
            self._writer.add_scalar(f"epoch/{key}",
                                    value,
                                    global_step=self._global_epoch_step)

    def _write_iteration_metrics(self, train_metrics):
        """Log each train metric as ``iteration/<key>`` at the iteration step."""
        for key, value in train_metrics.items():
            self._writer.add_scalar(f"iteration/{key}",
                                    value,
                                    global_step=self._global_iteration_step)

    def should_save_model(self, metrics):
        """Return True if a checkpoint should be written for ``metrics``.

        Always True when ``model_save_key`` is absent from ``metrics``.
        Otherwise True on the first call and whenever the value is strictly
        lower than the best seen so far (which is then updated).
        """
        if self._model_save_key not in metrics.keys():
            return True
        if self._previous_model_save_metric is None:
            self._previous_model_save_metric = metrics[self._model_save_key]
            return True
        if self._previous_model_save_metric > metrics[self._model_save_key]:
            self._previous_model_save_metric = metrics[self._model_save_key]
            return True
        return False

    def _write_image_batches(self, image_batches):
        """Log image batches (given in NHWC layout) at the epoch step."""
        for key, value in image_batches.items():
            self._writer.add_images(key,
                                    value,
                                    self._global_epoch_step,
                                    dataformats="NHWC")

    def _write_figures(self, figures):
        """Log figures and save each as a PNG artifact in the log folder."""
        for key, value in figures.items():
            self._writer.add_figure(key, value, self._global_epoch_step)
            artifact_name = f"{self._log_folder}/{key}_{self._global_epoch_step:04d}.png"
            value.savefig(artifact_name)
            self._artifacts.append(artifact_name)

    def _write_audios(self, audios):
        """Log audio clips using the keyword args set via ``set_audio_configs``."""
        for key, value in audios.items():
            self._writer.add_audio(key, value, self._global_epoch_step,
                                   **self._audio_configs)

    def set_audio_configs(self, configs):
        """Set extra keyword arguments passed to ``SummaryWriter.add_audio``."""
        self._audio_configs = configs

    def _write_texts(self, texts):
        """Log text entries at the epoch step."""
        for key, value in texts.items():
            self._writer.add_text(key, value, self._global_epoch_step)
Exemplo n.º 29
0
            self.raw_rewards[i] += [infos[i]["raw_rewards"]] 
        newinfos = list(infos[:])
        for i in range(len(dones)):
            if dones[i]:
                info = infos[i].copy()
                raw_rewards = np.array(self.raw_rewards[i]).sum(0)
                raw_names = [str(rf) for rf in self.rfs]
                info['microrts_stats'] = dict(zip(raw_names, raw_rewards))
                self.raw_rewards[i] = []
                newinfos[i] = info
        return obs, rews, dones, newinfos

# TRY NOT TO MODIFY: setup the environment
# Unique run name: "<gym_id>__<exp_name>__<seed>__<unix time in seconds>".
experiment_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
writer = SummaryWriter(f"runs/{experiment_name}")
# Log every parsed argument as a markdown table with one "|param|value|" row each.
writer.add_text('hyperparameters', "|param|value|\n|-|-|\n%s" % (
        '\n'.join([f"|{key}|{value}|" for key, value in vars(args).items()])))
if args.prod_mode:
    # Production runs also report to Weights & Biases (args are passed as
    # config). sync_tensorboard=True presumably makes wandb ingest the
    # TensorBoard event files — confirm against wandb docs.
    import wandb
    run = wandb.init(project=args.wandb_project_name, entity=args.wandb_entity, sync_tensorboard=True, config=vars(args), name=experiment_name, monitor_gym=True, save_code=True)
    # NOTE(review): this rebinds `writer` to a new /tmp log dir. The
    # hyperparameters text above went to the earlier runs/ writer, which is
    # never closed here; confirm whether that text is meant to reach wandb.
    writer = SummaryWriter(f"/tmp/{experiment_name}")

# TRY NOT TO MODIFY: seeding
# Use CUDA only when it is both available and enabled through args.cuda.
device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')
# Seed the Python, NumPy and PyTorch RNGs from the same seed.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
envs = MicroRTSGridModeVecEnv(
    num_selfplay_envs=args.num_selfplay_envs,
    num_bot_envs=args.num_bot_envs,
    max_steps=2000,
Exemplo n.º 30
0
def main():

    from config import config_enhanced
    writer = SummaryWriter(os.path.join('runs', name_dir(config_enhanced)))

    torch.multiprocessing.freeze_support()

    print("Current config_enhanced is:")
    pprint(config_enhanced)
    writer.add_text("config", str(config_enhanced))

    save_path = str(writer.get_logdir())
    try:
        os.makedirs(save_path)
    except OSError:
        pass

    # with open(os.path.join(save_path, "config.json"), 'w') as outfile:
    #     json.dump(config_enhanced, outfile)

    torch.manual_seed(config_enhanced['seed'])
    torch.cuda.manual_seed_all(config_enhanced['seed'])

    use_cuda = torch.cuda.is_available()
    if torch.cuda.is_available() and config_enhanced['cuda_deterministic']:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # torch.set_num_threads(1)
    if use_cuda:
        device = torch.device('cuda')
        print("using GPU")
    else:
        device = torch.device('cpu')
        print("using CPU")

    if config_enhanced['num_processes'] == "num_cpu":
        num_processes = multiprocessing.cpu_count() - 1
    else:
        num_processes = config_enhanced['num_processes']

    # if torch.cuda.device_count() > 1:
    #     print("Let's use", torch.cuda.device_count(), "GPUs!")
    #     # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
    #     model = torch.nn.DataParallel(model)

    env = CholeskyTaskGraph(**config_enhanced['env_settings'])
    envs = VectorEnv(env, num_processes)
    envs.reset()

    model = SimpleNet(**config_enhanced["network_parameters"])
    if config_enhanced["model_path"]:
        model.load_state_dict(torch.load(config_enhanced['model_path']))

    actor_critic = Policy(model, envs.action_space, config_enhanced)
    actor_critic = actor_critic.to(device)

    if config_enhanced['agent'] == 'PPO':
        print("using PPO")
        agent_settings = config_enhanced['PPO_settings']
        agent = PPO(
            actor_critic,
            **agent_settings)

    elif config_enhanced['agent'] == 'A2C':
        print("using A2C")
        agent_settings = config_enhanced['A2C_settings']
        agent = A2C_ACKTR(
            actor_critic,
            **agent_settings)

    rollouts = RolloutStorage(config_enhanced['trajectory_length'], num_processes,
                              env_example.observation_space.shape, env_example.action_space)



    obs = envs.reset()
    obs = torch.tensor(obs, device=device)
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        config_enhanced['num_env_steps']) // config_enhanced['trajectory_length'] // num_processes
    for j in range(num_updates):

        if config_enhanced['use_linear_lr_decay']:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates, config_enhanced['network']['lr'])

        for step in tqdm(range(config_enhanced['trajectory_length'])):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob = actor_critic.act(
                    rollouts.obs[step])
            actions = action.squeeze(-1).detach().cpu().numpy()

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(actions)
            obs = torch.tensor(obs, device=device)
            reward = torch.tensor(reward, device=device).unsqueeze(-1)
            done = torch.tensor(done, device=device)

            n_step = (j * config_enhanced['trajectory_length'] + step) * num_processes
            for info in infos:
                if 'episode' in info.keys():
                    reward_episode = info['episode']['r']
                    episode_rewards.append(reward_episode)
                    writer.add_scalar('reward', reward_episode, n_step)
                    writer.add_scalar('solved', int(info['episode']['length'] == envs.envs[0].max_steps))

            # If done then clean the history of observations.
            masks = torch.FloatTensor(
                [[0.0] if done_ else [1.0] for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1]).detach()

        rollouts.compute_returns(next_value, config_enhanced["use_gae"], config_enhanced["gamma"],
                                 config_enhanced['gae_lambda'], config_enhanced['use_proper_time_limits'])

        value_loss, action_loss, dist_entropy = agent.update(rollouts)
        writer.add_scalar('value loss', value_loss, n_step)
        writer.add_scalar('action loss', action_loss, n_step)
        writer.add_scalar('dist_entropy', dist_entropy, n_step)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % config_enhanced['save_interval'] == 0
                or j == num_updates - 1):
            save_path = str(writer.get_logdir())
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save(actor_critic, os.path.join(save_path, "model.pth"))

        if j % config_enhanced['log_interval'] == 0 and len(episode_rewards) > 1:
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                    .format(j, n_step,
                            int(n_step / (end - start)),
                            len(episode_rewards), np.mean(episode_rewards),
                            np.median(episode_rewards), np.min(episode_rewards),
                            np.max(episode_rewards), dist_entropy, value_loss,
                            action_loss))

        if (config_enhanced['evaluate_every'] is not None and len(episode_rewards) > 1
                and j % config_enhanced['evaluate_every'] == 0):
            eval_reward = evaluate(actor_critic, boxworld, config_enhanced, device)
            writer.add_scalar("eval reward", eval_reward, n_step)