Ejemplo n.º 1
0
def main(args):
    """Train a segmentation model, optionally refine it with SWA, then benchmark it.

    Only works for single node (be it single or multi-GPU).

    Stages, in order:

    1. validate the optimizer name and check that at least one GPU is visible;
    2. build the experiment name, create ``./runs/<exp_name>`` and its ``segs``
       sub-folder, call ``save_args`` and open a ``SummaryWriter`` on the folder;
    3. build the model (plus a second copy used as running average when
       ``args.swa`` is set), the loss/metric, the optimizer and the data loaders;
    4. run ``args.warm`` warm-up epochs with a linearly increasing learning-rate
       multiplier (skipped when resuming, forced to 0 for the ranger optimizer);
    5. run the main loop with a cosine schedule, validating every ``args.val``
       epochs and calling ``save_checkpoint`` whenever the validation loss improves;
    6. when ``args.swa`` is set, run ``args.swa_repeat`` extra cycles, folding the
       model into the SWA average every ``k`` epochs;
    7. save a final checkpoint, write the collected ``patients_perf`` records to
       ``patients_indiv_perf.csv``, reload ``model_best.pth.tar`` and generate the
       benchmark segmentations.

    Parameters
    ----------
    args :
        Parsed arguments. This function mutates it: ``exp_name``, ``save_folder``
        and ``seg_folder`` are set, ``warm`` is zeroed for ranger, and in debug
        mode ``epochs``, ``warm`` and ``val`` are overridden.

    Raises
    ------
    ValueError
        If ``args.optim`` is not one of adam, sgd, adamw or ranger
        (case-insensitive).
    RuntimeWarning
        If no CUDA device is visible.
    """
    # Validate the optimizer before anything else: previously an unknown (or
    # differently-cased) name only surfaced as a NameError on `optimizer` after
    # the run folder and the model had already been created.
    optim_name = args.optim.lower()
    supported_optims = ("adam", "sgd", "adamw", "ranger")
    if optim_name not in supported_optims:
        raise ValueError(
            f"Unsupported optimizer {args.optim!r}; expected one of {supported_optims}")

    # setup
    ngpus = torch.cuda.device_count()
    if ngpus == 0:
        raise RuntimeWarning("This will not be able to run on CPU only")

    print(f"Working with {ngpus} GPUs")
    if optim_name == "ranger":
        # No warm up if ranger optimizer
        args.warm = 0

    # NOTE(review): the optimizer name appears twice in the run name
    # ("_optim<name>" and "_<name>"); kept as is so existing run folders and any
    # tooling parsing them keep working.
    current_experiment_time = datetime.now().strftime('%Y%m%d_%T').replace(":", "")
    args.exp_name = f"{'debug_' if args.debug else ''}{current_experiment_time}_" \
                    f"_fold{args.fold if not args.full else 'FULL'}" \
                    f"_{args.arch}_{args.width}" \
                    f"_batch{args.batch_size}" \
                    f"_optim{args.optim}" \
                    f"_{args.optim}" \
                    f"_lr{args.lr}-wd{args.weight_decay}_epochs{args.epochs}_deepsup{args.deep_sup}" \
                    f"_{'fp16' if not args.no_fp16 else 'fp32'}" \
                    f"_warm{args.warm}_" \
                    f"_norm{args.norm_layer}{'_swa' + str(args.swa_repeat) if args.swa else ''}" \
                    f"_dropout{args.dropout}" \
                    f"_warm_restart{args.warm_restart}" \
                    f"{'_' + args.com.replace(' ', '_') if args.com else ''}"
    args.save_folder = pathlib.Path(f"./runs/{args.exp_name}")
    args.save_folder.mkdir(parents=True, exist_ok=True)
    args.seg_folder = args.save_folder / "segs"
    args.seg_folder.mkdir(parents=True, exist_ok=True)
    args.save_folder = args.save_folder.resolve()
    save_args(args)
    t_writer = SummaryWriter(str(args.save_folder))

    # Create model
    print(f"Creating {args.arch}")

    model_maker = getattr(models, args.arch)

    model = model_maker(
        4, 3,
        width=args.width, deep_supervision=args.deep_sup,
        norm_layer=get_norm_layer(args.norm_layer), dropout=args.dropout)

    print(f"total number of trainable parameters {count_parameters(model)}")

    if args.swa:
        # Create the average model. It is built without the dropout argument and
        # its parameters are detached: it is only ever updated through WeightSWA.
        swa_model = model_maker(
            4, 3,
            width=args.width, deep_supervision=args.deep_sup,
            norm_layer=get_norm_layer(args.norm_layer))
        for param in swa_model.parameters():
            param.detach_()
        swa_model = swa_model.cuda()
        swa_model_optim = WeightSWA(swa_model)

    if ngpus > 1:
        model = torch.nn.DataParallel(model).cuda()
    else:
        model = model.cuda()
    print(model)
    model_file = args.save_folder / "model.txt"
    with model_file.open("w") as f:
        print(model, file=f)

    criterion = EDiceLoss().cuda()
    metric = criterion.metric
    print(metric)

    rangered = False  # needed because LR scheduling scheme is different for this optimizer
    if optim_name == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay, eps=1e-4)

    elif optim_name == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9,
                                    nesterov=True)

    elif optim_name == "adamw":
        # AdamW is built with its own default weight decay, args.weight_decay is ignored.
        print("weight decay argument will not be used. Default is 1e-2")
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    elif optim_name == "ranger":
        optimizer = Ranger(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        rangered = True

    # optionally resume from a checkpoint
    if args.resume:
        reload_ckpt(args, model, optimizer)

    if args.debug:
        # Shortest possible run: two epochs, no warm-up, validate every epoch.
        args.epochs = 2
        args.warm = 0
        args.val = 1

    if args.full:
        # Training on the full dataset: there is no validation loader in this mode.
        train_dataset, bench_dataset = get_datasets(args.seed, args.debug, full=True)

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=False, drop_last=True)

        bench_loader = torch.utils.data.DataLoader(
            bench_dataset, batch_size=1, num_workers=args.workers)

    else:

        train_dataset, val_dataset, bench_dataset = get_datasets(args.seed, args.debug, fold_number=args.fold)

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=False, drop_last=True)

        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=max(1, args.batch_size // 2), shuffle=False,
            pin_memory=False, num_workers=args.workers, collate_fn=determinist_collate)

        bench_loader = torch.utils.data.DataLoader(
            bench_dataset, batch_size=1, num_workers=args.workers)
        print("Val dataset number of batch:", len(val_loader))

    print("Train dataset number of batch:", len(train_loader))

    # create grad scaler
    scaler = GradScaler()

    # Actual Train loop

    best = np.inf
    print("start warm-up now!")
    if args.warm != 0:
        # Linear warm-up: the LR multiplier grows from ~0 to 1 over all warm-up iterations.
        tot_iter_train = len(train_loader)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      lambda cur_iter: (1 + cur_iter) / (tot_iter_train * args.warm))

    patients_perf = []

    if not args.resume:
        # `scheduler` is only defined when args.warm != 0, which is exactly when
        # this loop runs at least once.
        for epoch in range(args.warm):
            ts = time.perf_counter()
            model.train()
            training_loss = step(train_loader, model, criterion, metric, args.deep_sup, optimizer, epoch, t_writer,
                                 scaler, scheduler, save_folder=args.save_folder,
                                 no_fp16=args.no_fp16, patients_perf=patients_perf)
            te = time.perf_counter()
            print(f"Train Epoch done in {te - ts} s")

            # Validate at the end of epoch every val step
            if (epoch + 1) % args.val == 0 and not args.full:
                model.eval()
                with torch.no_grad():
                    validation_loss = step(val_loader, model, criterion, metric, args.deep_sup, optimizer, epoch,
                                           t_writer, save_folder=args.save_folder,
                                           no_fp16=args.no_fp16)

                t_writer.add_scalar(f"SummaryLoss/overfit", validation_loss - training_loss, epoch)

    # The warm-up scheduler (if any) is replaced by the main cosine schedule.
    if args.warm_restart:
        print('Total number of epochs should be divisible by 30, else it will do odd things')
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 30, eta_min=1e-7)
    else:
        # With ranger the schedule only spans the second half of training (see
        # the stepping logic at the end of the main loop).
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                               args.epochs + 30 if not rangered else round(
                                                                   args.epochs * 0.5))
    print("start training now!")
    if args.swa:
        # c: epochs per SWA cycle, k: SWA update period, repeat: number of cycles
        # c = 15, k=3, repeat = 5
        c, k, repeat = 30, 3, args.swa_repeat
        epochs_done = args.epochs
        # NOTE(review): reboot_lr is recorded below but never read in this function.
        reboot_lr = 0
        if args.debug:
            c, k, repeat = 2, 1, 2

    for epoch in range(args.start_epoch + args.warm, args.epochs + args.warm):
        try:
            # do_epoch for one epoch
            ts = time.perf_counter()
            model.train()
            training_loss = step(train_loader, model, criterion, metric, args.deep_sup, optimizer, epoch, t_writer,
                                 scaler, save_folder=args.save_folder,
                                 no_fp16=args.no_fp16, patients_perf=patients_perf)
            te = time.perf_counter()
            print(f"Train Epoch done in {te - ts} s")

            # Validate at the end of epoch every val step
            if (epoch + 1) % args.val == 0 and not args.full:
                model.eval()
                with torch.no_grad():
                    validation_loss = step(val_loader, model, criterion, metric, args.deep_sup, optimizer,
                                           epoch,
                                           t_writer,
                                           save_folder=args.save_folder,
                                           no_fp16=args.no_fp16, patients_perf=patients_perf)

                t_writer.add_scalar(f"SummaryLoss/overfit", validation_loss - training_loss, epoch)

                if validation_loss < best:
                    # New best validation loss: checkpoint model, optimizer and scheduler.
                    best = validation_loss
                    model_dict = model.state_dict()
                    save_checkpoint(
                        dict(
                            epoch=epoch, arch=args.arch,
                            state_dict=model_dict,
                            optimizer=optimizer.state_dict(),
                            scheduler=scheduler.state_dict(),
                        ),
                        save_folder=args.save_folder, )

                ts = time.perf_counter()
                print(f"Val epoch done in {ts - te} s")

            if args.swa:
                if (args.epochs - epoch - c) == 0:
                    reboot_lr = optimizer.param_groups[0]['lr']

            if not rangered:
                scheduler.step()
                print("scheduler stepped!")
            else:
                # Ranger: keep the LR flat for the first half, anneal afterwards.
                if epoch / args.epochs > 0.5:
                    scheduler.step()
                    print("scheduler stepped!")

        except KeyboardInterrupt:
            # Ctrl-C stops training but still falls through to SWA / benchmark below.
            print("Stopping training loop, doing benchmark")
            break

    if args.swa:
        swa_model_optim.update(model)
        print("SWA Model initialised!")
        for i in range(repeat):
            # Each cycle restarts from a fresh Adam optimizer at half the base LR.
            optimizer = torch.optim.Adam(model.parameters(), args.lr / 2, weight_decay=args.weight_decay)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, c + 10)
            for swa_epoch in range(c):
                # do_epoch for one epoch
                ts = time.perf_counter()
                model.train()
                swa_model.train()
                # Global epoch index used for logging, continuing after the main loop.
                current_epoch = epochs_done + i * c + swa_epoch
                training_loss = step(train_loader, model, criterion, metric, args.deep_sup, optimizer,
                                     current_epoch, t_writer,
                                     scaler, no_fp16=args.no_fp16, patients_perf=patients_perf)
                te = time.perf_counter()
                print(f"Train Epoch done in {te - ts} s")

                t_writer.add_scalar(f"SummaryLoss/train", training_loss, current_epoch)

                # update every k epochs and val:
                print(f"cycle number: {i}, swa_epoch: {swa_epoch}, total_cycle_to_do {repeat}")
                if (swa_epoch + 1) % k == 0:
                    swa_model_optim.update(model)
                    if not args.full:
                        model.eval()
                        swa_model.eval()
                        with torch.no_grad():
                            validation_loss = step(val_loader, model, criterion, metric, args.deep_sup, optimizer,
                                                   current_epoch,
                                                   t_writer, save_folder=args.save_folder, no_fp16=args.no_fp16)
                            swa_model_loss = step(val_loader, swa_model, criterion, metric, args.deep_sup, optimizer,
                                                  current_epoch,
                                                  t_writer, swa=True, save_folder=args.save_folder,
                                                  no_fp16=args.no_fp16)

                        t_writer.add_scalar(f"SummaryLoss/val", validation_loss, current_epoch)
                        t_writer.add_scalar(f"SummaryLoss/swa", swa_model_loss, current_epoch)
                        t_writer.add_scalar(f"SummaryLoss/overfit", validation_loss - training_loss, current_epoch)
                        t_writer.add_scalar(f"SummaryLoss/overfit_swa", swa_model_loss - training_loss, current_epoch)
                scheduler.step()
        epochs_added = c * repeat
        # The final checkpoint stores the averaged (SWA) weights, not the raw model.
        save_checkpoint(
            dict(
                epoch=args.epochs + epochs_added, arch=args.arch,
                state_dict=swa_model.state_dict(),
                optimizer=optimizer.state_dict()
            ),
            save_folder=args.save_folder, )
    else:
        save_checkpoint(
            dict(
                epoch=args.epochs, arch=args.arch,
                state_dict=model.state_dict(),
                optimizer=optimizer.state_dict()
            ),
            save_folder=args.save_folder, )

    try:
        df_individual_perf = pd.DataFrame.from_records(patients_perf)
        print(df_individual_perf)
        df_individual_perf.to_csv(f'{str(args.save_folder)}/patients_indiv_perf.csv')
        # NOTE(review): with args.full no validation runs in the main loop, so
        # whether model_best.pth.tar exists here depends on save_checkpoint — confirm.
        reload_ckpt_bis(f'{str(args.save_folder)}/model_best.pth.tar', model)
        generate_segmentations(bench_loader, model, t_writer, args)
    except KeyboardInterrupt:
        print("Stopping right now!")
Ejemplo n.º 2
0
def train(args):
    """
    Train an IRNet model and evaluate it on the validation split after every epoch.

    The model is checkpointed after every epoch, the best-accuracy model is kept
    as ``best_model.model`` and the final weights are written to
    ``end_model.model`` whether training finished or raised.

    :param args: parsed arguments; the attributes read here are ``dataset``,
        ``toy``, ``cuda``, ``optimizer``, ``lr``, ``lr_scheduler``,
        ``lr_scheduler_gammar``, ``loss_epoch_threshold``,
        ``sketch_loss_coefficient``, ``load_model``, ``glove_embed_path``,
        ``epoch``, ``batch_size`` and ``beam_size`` (``args`` is also forwarded
        to ``IRNet`` and the ``utils`` helpers).
    :return: None
    """
    grammar = semQL.Grammar()
    sql_data, table_data, val_sql_data, val_table_data = utils.load_dataset(
        args.dataset, use_small=args.toy)

    model = IRNet(args, grammar)
    if args.cuda:
        model.cuda()

    # now get the optimizer
    # Look the class up by name instead of eval()-ing a string built from a
    # command-line argument; an unknown name still raises AttributeError.
    optimizer_cls = getattr(torch.optim, args.optimizer)
    optimizer = optimizer_cls(model.parameters(), lr=args.lr)
    print('Enable Learning Rate Scheduler: ', args.lr_scheduler)
    if args.lr_scheduler:
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[21, 41], gamma=args.lr_scheduler_gammar)
    else:
        scheduler = None

    print('Loss epoch threshold: %d' % args.loss_epoch_threshold)
    print('Sketch loss coefficient: %f' % args.sketch_loss_coefficient)

    if args.load_model:
        print('load pretrained model from %s' % (args.load_model))
        # map_location keeps every tensor on the storage it is deserialised to (CPU).
        pretrained_model = torch.load(
            args.load_model, map_location=lambda storage, loc: storage)
        pretrained_modeled = copy.deepcopy(pretrained_model)
        # Drop the entries the current model does not know about. The key set is
        # computed once: state_dict() rebuilds the whole mapping on every call.
        model_keys = set(model.state_dict().keys())
        for k in pretrained_model.keys():
            if k not in model_keys:
                del pretrained_modeled[k]

        model.load_state_dict(pretrained_modeled)

    model.word_emb = utils.load_word_emb(args.glove_embed_path)
    # begin train

    model_save_path = utils.init_log_checkpoint_path(args)
    utils.save_args(args, os.path.join(model_save_path, 'config.json'))
    best_dev_acc = .0

    try:
        with open(os.path.join(model_save_path, 'epoch.log'), 'w') as epoch_fd:
            for epoch in tqdm.tqdm(range(args.epoch)):
                if args.lr_scheduler:
                    # NOTE(review): the scheduler is stepped before the epoch's
                    # optimizer steps; recent PyTorch expects the opposite order,
                    # which shifts the milestones by one epoch — confirm intent.
                    scheduler.step()
                epoch_begin = time.time()
                loss = utils.epoch_train(
                    model,
                    optimizer,
                    args.batch_size,
                    sql_data,
                    table_data,
                    args,
                    loss_epoch_threshold=args.loss_epoch_threshold,
                    sketch_loss_coefficient=args.sketch_loss_coefficient)
                epoch_end = time.time()
                json_datas, sketch_acc, acc, counts, corrects = utils.epoch_acc(
                    model,
                    args.batch_size,
                    val_sql_data,
                    val_table_data,
                    beam_size=args.beam_size)

                if acc > best_dev_acc:
                    utils.save_checkpoint(
                        model, os.path.join(model_save_path,
                                            'best_model.model'))
                    best_dev_acc = acc
                # Format only the file name: applying % to the joined path would
                # break as soon as the directory itself contains a '%'.
                epoch_model_name = '{%s}_{%s}.model' % (epoch, acc)
                utils.save_checkpoint(
                    model, os.path.join(model_save_path, epoch_model_name))

                log_str = 'Epoch: %d, Loss: %f, Sketch Acc: %f, Acc: %f, time: %f\n' % (
                    epoch + 1, loss, sketch_acc, acc, epoch_end - epoch_begin)
                tqdm.tqdm.write(log_str)
                epoch_fd.write(log_str)
                epoch_fd.flush()
    except Exception as e:
        # Best effort: keep whatever was learnt so far, then report the failure
        # without re-raising.
        utils.save_checkpoint(model,
                              os.path.join(model_save_path, 'end_model.model'))
        print(e)
        tb = traceback.format_exc()
        print(tb)
    else:
        utils.save_checkpoint(model,
                              os.path.join(model_save_path, 'end_model.model'))
        json_datas, sketch_acc, acc, counts, corrects = utils.epoch_acc(
            model,
            args.batch_size,
            val_sql_data,
            val_table_data,
            beam_size=args.beam_size)

        # NOTE(review): "Beam Acc" is filled with the same value as "Acc".
        print("Sketch Acc: %f, Acc: %f, Beam Acc: %f" % (
            sketch_acc,
            acc,
            acc,
        ))
Ejemplo n.º 3
0
def train(args, model, optimizer, bert_optimizer, data):
    '''
    Train ``model`` at interaction level, evaluating and checkpointing every epoch.

    After each epoch the model is scored on a sample of the training
    interactions and on the validation interactions, a checkpoint is written
    through ``utils.save_ckpt`` and, on the epochs listed in
    ``args.stage_epoch``, the best weights are reloaded and the learning rate
    is divided by ``args.lr_decay``.

    :param args: parsed arguments (also forwarded to the ``utils`` helpers)
    :param model: model to train; its weights are modified in place
    :param optimizer: optimizer whose learning rate is driven by the stage logic
    :param bert_optimizer: forwarded unchanged to
        ``utils.epoch_train_with_interaction``
    :param data: dataset object providing the interaction batches and the
        validation data
    :return: None
    :raises NotImplementedError: if ``args.interaction_level`` is falsy
    '''
    # Utterance-level training never defined the sampled training batches nor the
    # epoch loss, so it used to die with an UnboundLocalError; say so explicitly
    # and before any expensive work is done.
    if not args.interaction_level:
        raise NotImplementedError(
            'utterance-level training is not implemented; '
            'run with args.interaction_level enabled')

    print('Loss epoch threshold: %d' % args.loss_epoch_threshold)
    print('Sketch loss coefficient: %f' % args.sketch_loss_coefficient)

    if args.load_model and not args.resume:
        print('load pretrained model from %s' % (args.load_model))
        pretrained_model = torch.load(
            args.load_model, map_location=lambda storage, loc: storage)
        pretrained_modeled = copy.deepcopy(pretrained_model)
        # Drop the entries the current model does not know about. The key set is
        # computed once: state_dict() rebuilds the whole mapping on every call.
        model_keys = set(model.state_dict().keys())
        for k in pretrained_model.keys():
            if k not in model_keys:
                del pretrained_modeled[k]

        model.load_state_dict(pretrained_modeled)

    # ==============data==============
    batch_size = 1
    train_batchs, train_sample_batchs = data.get_interaction_batches(
        batch_size,
        sample_num=args.train_evaluation_size,
        use_small=args.toy)
    valid_batchs = data.get_all_interactions(data.val_sql_data,
                                             data.val_table_data,
                                             _type='test',
                                             use_small=args.toy)
    print(len(train_batchs), len(train_sample_batchs), len(valid_batchs))

    start_epoch = 1
    best_question_match = .0
    lr = args.initial_lr
    stage = 1

    # Resolved once and shared by the resume logic and the training loop.
    model_save_path = utils.init_log_checkpoint_path(args)

    if args.resume:
        current_w = torch.load(
            os.path.join(model_save_path, args.current_model_name))
        best_w = torch.load(os.path.join(model_save_path,
                                         args.best_model_name))
        best_question_match = best_w['question_match']
        start_epoch = current_w['epoch'] + 1
        lr = current_w['lr']
        utils.adjust_learning_rate(optimizer, lr)
        stage = current_w['stage']
        model.load_state_dict(current_w['state_dict'])
        # If the run was interrupted exactly on a stage-transition epoch, replay
        # the transition: the checkpoint is written before it is applied.
        if start_epoch - 1 in args.stage_epoch:
            stage += 1
            lr /= args.lr_decay
            utils.adjust_learning_rate(optimizer, lr)
            model.load_state_dict(best_w['state_dict'])
        print("=> Loading resume model from epoch {} ...".format(start_epoch -
                                                                 1))

    # begin train
    utils.save_args(args, os.path.join(model_save_path, 'config.json'))
    # Append to both logs when resuming so the history of the interrupted run is kept.
    file_mode = 'a' if args.resume else 'w'
    log = Logger(os.path.join(args.save, args.logfile), file_mode)

    best_w_name = os.path.join(model_save_path, args.best_model_name)
    current_w_name2 = os.path.join(model_save_path, args.current_model_name)

    try:
        with open(os.path.join(model_save_path, 'epoch.log'),
                  file_mode) as epoch_fd:

            for epoch in range(start_epoch, args.epoch + 1):
                epoch_begin = time.time()

                # Dropout is only active while training; it is reset to 0 for
                # the evaluations below.
                model.dropout_ratio = args.dropout_amount
                loss = utils.epoch_train_with_interaction(
                    epoch, log, model, optimizer, bert_optimizer, train_batchs,
                    args)
                model.dropout_ratio = 0.
                epoch_end = time.time()

                # Accuracy on the sampled training interactions.
                s = time.time()
                sample_sketch_acc, sample_lf_acc, sample_interaction_lf_acc = utils.epoch_acc_with_interaction(
                    epoch,
                    model,
                    train_sample_batchs,
                    args,
                    beam_size=args.beam_size,
                    use_predicted_queries=True)
                log_str = '[Epoch: %d(sample predicted),Sample ratio[%d]: %f], Sketch Acc: %f, Acc: %f, Interaction Acc: %f, Train time: %f, Sample predict time: %f\n' % (
                    epoch, len(train_sample_batchs),
                    len(train_sample_batchs) / len(train_batchs),
                    sample_sketch_acc, sample_lf_acc, sample_interaction_lf_acc,
                    epoch_end - epoch_begin, time.time() - s)
                print(log_str)
                log.put(log_str)
                epoch_fd.write(log_str)
                epoch_fd.flush()

                # Accuracy on the validation interactions, plus the question /
                # interaction match computed from the saved predictions.
                s = time.time()
                valid_jsonf, pred_sketch_acc, pred_lf_acc, pred_interaction_lf_acc = utils.epoch_acc_with_interaction_save_json(
                    epoch,
                    model,
                    valid_batchs,
                    args,
                    beam_size=args.beam_size,
                    use_predicted_queries=True)

                question_match, interaction_match = utils.semQL2SQL_question_and_interaction_match(
                    valid_jsonf, args)

                log_str = '[Epoch: %d(predicted)], Loss: %f, lr: %.3ef, Sketch Acc: %f, Acc: %f, Interaction Acc: %f, Question Match : %f, Interaction Macth : %f, Predicted predict time: %f\n\n' % (
                    epoch, loss, optimizer.param_groups[0]["lr"], pred_sketch_acc,
                    pred_lf_acc, pred_interaction_lf_acc, question_match,
                    interaction_match, time.time() - s)
                print(log_str)
                log.put(log_str)
                epoch_fd.write(log_str)
                epoch_fd.flush()

                state = {
                    "state_dict": model.state_dict(),
                    "epoch": epoch,
                    "question_match": question_match,
                    "interaction_match": interaction_match,
                    "lr": lr,
                    'stage': stage
                }

                current_w_name = os.path.join(
                    model_save_path,
                    '{}_{:.3f}.pth'.format(epoch, question_match))

                # The second argument tells save_ckpt whether this epoch beats
                # the best question match seen so far.
                utils.save_ckpt(state, best_question_match < question_match,
                                current_w_name, best_w_name, current_w_name2)
                best_question_match = max(best_question_match, question_match)
                if epoch in args.stage_epoch:
                    # Stage transition: restart from the best weights with a
                    # smaller learning rate.
                    stage += 1
                    lr /= args.lr_decay
                    model.load_state_dict(torch.load(best_w_name)['state_dict'])
                    print("*" * 10, "step into stage%02d lr %.3ef" % (stage, lr))
                    utils.adjust_learning_rate(optimizer, lr)

        log.put("Finished training!")
    finally:
        # Release the log file even when training is interrupted by an error.
        log.close()