Example #1
def main():
    timer = Timer()
    timer.start('Load word2vec models...')
    vocab = load_vocab(config.VOCAB_DATA)
    embeddings = get_trimmed_w2v_vectors(config.W2V_DATA)
    timer.stop()

    timer.start('Load data...')
    train = process_data(opt.train, vocab)
    if opt.val is not None:
        if opt.val != '1vs9':
            validation = process_data(opt.val, vocab)
        else:
            validation, train = train.one_vs_nine()
    else:
        validation = None

    if opt.test is not None:
        test = process_data(opt.test, vocab)
    else:
        test = None
    timer.stop()

    timer.start('Build model...')
    model = CnnModel(embeddings=embeddings)
    model.build()
    timer.stop()

    timer.start('Train model...')
    epochs = opt.e
    batch_size = opt.b
    early_stopping = opt.p != 0
    patience = opt.p
    pre_train = opt.pre if opt.pre != '' else None
    model_name = opt.name

    model.train(
        model_name,
        train=train,
        validation=validation,
        epochs=epochs,
        batch_size=batch_size,
        early_stopping=early_stopping,
        patience=patience,
        cont=pre_train,
    )
    timer.stop()

    if test is not None:
        timer.start('Test model...')
        preds = model.predict(test, model_name)
        labels = test.labels

        p, r, f1, _ = precision_recall_fscore_support(labels,
                                                      preds,
                                                      average='binary')
        print('Testing result:\tP={}\tR={}\tF1={}'.format(p, r, f1))
        timer.stop()
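
The `Timer` helper used above (and again in Example #7) is not part of the listing. A minimal sketch consistent with its start/stop calls, including the job-id variant, might look like this (hypothetical, not the original class):

import time

class Timer:
    """Minimal start/stop timer matching the calls in these examples (an assumption)."""

    def __init__(self):
        self._starts = {}
        self._next_id = 0

    def start(self, message):
        print(message)
        job_id = self._next_id
        self._next_id += 1
        self._starts[job_id] = time.time()
        self._last = job_id
        return job_id

    def stop(self, job_id=None):
        job_id = self._last if job_id is None else job_id
        print('Done in {:.2f}s'.format(time.time() - self._starts.pop(job_id)))
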
Example #2
def main():
    with timer('load data'):
        df = pd.read_csv(TRAIN_PATH)
        df = df[df.Image != "ID_6431af929"].reset_index(drop=True)
        df.loc[df.pre_SOPInstanceUID == "ID_6431af929", "pre1_SOPInstanceUID"] = df.loc[
            df.pre_SOPInstanceUID == "ID_6431af929", "Image"]
        df.loc[df.post_SOPInstanceUID == "ID_6431af929", "post1_SOPInstanceUID"] = df.loc[
            df.post_SOPInstanceUID == "ID_6431af929", "Image"]
        df.loc[df.prepre_SOPInstanceUID == "ID_6431af929", "pre2_SOPInstanceUID"] = df.loc[
            df.prepre_SOPInstanceUID == "ID_6431af929", "pre1_SOPInstanceUID"]
        df.loc[df.postpost_SOPInstanceUID == "ID_6431af929", "post2_SOPInstanceUID"] = df.loc[
            df.postpost_SOPInstanceUID == "ID_6431af929", "post1_SOPInstanceUID"]
        y = df[TARGET_COLUMNS].values
        df = df[["Image", "pre1_SOPInstanceUID", "post1_SOPInstanceUID", "pre2_SOPInstanceUID", "post2_SOPInstanceUID"]]
        gc.collect()

    with timer('preprocessing'):
        train_augmentation = Compose([
            CenterCrop(512 - 50, 512 - 50, p=1.0),
            HorizontalFlip(p=0.5),
            OneOf([
                ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
                GridDistortion(p=0.5),
                OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5)
            ], p=0.5),
            Rotate(limit=30, border_mode=0, p=0.7),
            Resize(img_size, img_size, p=1)
        ])

        train_dataset = RSNADataset(df, y, img_size, IMAGE_PATH, id_colname=ID_COLUMNS,
                                    transforms=train_augmentation, black_crop=False, subdural_window=True, user_window=2)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
        del df, train_dataset
        gc.collect()

    with timer('create model'):
        model = CnnModel(num_classes=N_CLASSES, encoder="se_resnext50_32x4d", pretrained="imagenet", pool_type="avg")
        if model_path is not None:
            model.load_state_dict(torch.load(model_path))
        model.to(device)

        criterion = torch.nn.BCEWithLogitsLoss(weight=torch.FloatTensor([2, 1, 1, 1, 1, 1]).cuda())
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, eps=1e-4)
        model = torch.nn.DataParallel(model)

    with timer('train'):
        for epoch in range(1, epochs + 1):
            if epoch == 5:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = param_group['lr'] * 0.1
            seed_torch(SEED + epoch)

            LOGGER.info("Starting {} epoch...".format(epoch))
            tr_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
            LOGGER.info('Mean train loss: {}'.format(round(tr_loss, 5)))

            torch.save(model.module.state_dict(), 'models/{}_ep{}.pth'.format(EXP_ID, epoch))
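
The `timer` context manager used in the `with timer('...')` blocks throughout these examples is not included in the listing. A minimal sketch matching its usage (an assumption, not the original helper):

import time
from contextlib import contextmanager

@contextmanager
def timer(name):
    # Print elapsed wall-clock time for the enclosed block.
    t0 = time.time()
    yield
    print('[{}] done in {:.1f} s'.format(name, time.time() - t0))
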
Example #3
def main():
    with timer('load data'):
        df = pd.read_csv(TEST_PATH)
        df[['ID', 'Image', 'Diagnosis']] = df['ID'].str.split('_', expand=True)
        df = df[['Image', 'Diagnosis', 'Label']]
        df.drop_duplicates(inplace=True)
        df = df.pivot(index='Image', columns='Diagnosis',
                      values='Label').reset_index()
        df['Image'] = 'ID_' + df['Image']
        df = df[["Image"]]
        ids = df["Image"].values
        gc.collect()

    with timer('preprocessing'):
        test_augmentation = Compose([
            CenterCrop(512 - 50, 512 - 50, p=1.0),
            Resize(img_size, img_size, p=1)
        ])

        test_dataset = RSNADatasetTest(df,
                                       img_size,
                                       IMAGE_PATH,
                                       id_colname=ID_COLUMNS,
                                       transforms=test_augmentation,
                                       black_crop=False,
                                       subdural_window=True,
                                       conc_type="concat_all",
                                       conc_type2="concat_prepost",
                                       n_tta=N_TTA)
        test_loader = DataLoader(test_dataset,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=16,
                                 pin_memory=True)
        del df, test_dataset
        gc.collect()

    with timer('create model'):
        model = CnnModel(num_classes=N_CLASSES,
                         encoder="se_resnext50_32x4d",
                         pretrained="imagenet",
                         pool_type="avg")
        model.load_state_dict(torch.load(model_path))
        model.to(device)
        model = torch.nn.DataParallel(model)

    with timer('predict'):
        pred = predict(model, test_loader, device, n_tta=N_TTA)
        pred = np.clip(pred, 1e-6, 1 - 1e-6)

    with timer('sub'):
        sub = pd.DataFrame(pred, columns=TARGET_COLUMNS)
        sub["ID"] = ids
        sub = sub.set_index("ID")
        sub = sub.unstack().reset_index()
        sub["ID"] = sub["ID"] + "_" + sub["level_0"]
        sub = sub.rename(columns={0: "Label"})
        sub = sub.drop("level_0", axis=1)
        LOGGER.info(sub.head())
        sub.to_csv("../output/{}_sub_st2.csv".format(EXP_ID), index=False)
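
The `unstack` block above reshapes the wide per-image prediction frame into the long `ID_<image>_<subtype>` rows that the RSNA submission format expects. A small self-contained demonstration of the same reshaping, with made-up values:

import pandas as pd

sub = pd.DataFrame({'any': [0.9], 'epidural': [0.1]})
sub['ID'] = ['ID_abc']
sub = sub.set_index('ID').unstack().reset_index()
sub['ID'] = sub['ID'] + '_' + sub['level_0']
sub = sub.rename(columns={0: 'Label'}).drop('level_0', axis=1)
print(sub)  # one row per (image, subtype): ID_abc_any 0.9, ID_abc_epidural 0.1
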
Example #4
def main():
    with timer('load data'):
        df = pd.read_csv(TRAIN_PATH)[:10]  # note: only the first 10 rows are read here
        df = df[df.Image != "ID_6431af929"].reset_index(drop=True)
        df.loc[df.pre_SOPInstanceUID == "ID_6431af929",
               "pre1_SOPInstanceUID"] = df.loc[df.pre_SOPInstanceUID ==
                                               "ID_6431af929", "Image"]
        df.loc[df.post_SOPInstanceUID == "ID_6431af929",
               "post1_SOPInstanceUID"] = df.loc[df.post_SOPInstanceUID ==
                                                "ID_6431af929", "Image"]
        df = df[["Image", "pre1_SOPInstanceUID", "post1_SOPInstanceUID"]]
        ids = df["Image"].values
        gc.collect()

    with timer('preprocessing'):
        test_augmentation = Compose([
            CenterCrop(512 - 50, 512 - 50, p=1.0),
            Resize(img_size, img_size, p=1)
        ])

        test_dataset = RSNADatasetTest(df,
                                       img_size,
                                       IMAGE_PATH,
                                       id_colname=ID_COLUMNS,
                                       transforms=test_augmentation,
                                       black_crop=False,
                                       three_window=True,
                                       rescaling=False,
                                       n_tta=N_TTA)
        test_loader = DataLoader(test_dataset,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=16,
                                 pin_memory=True)
        del df, test_dataset
        gc.collect()

    with timer('create model'):
        model = CnnModel(num_classes=N_CLASSES,
                         encoder="se_resnext50_32x4d",
                         pretrained="imagenet",
                         pool_type="avg")
        model.load_state_dict(torch.load(model_path))
        model.to(device)
        model = torch.nn.DataParallel(model)

    with timer('predict'):
        pred = predict(model, test_loader, device, n_tta=N_TTA)
        pred = np.clip(pred, 1e-6, 1 - 1e-6)

    with timer('sub'):
        sub = pd.DataFrame(pred, columns=TARGET_COLUMNS)
        sub["ID"] = ids
        sub = sub.set_index("ID")
        sub = sub.unstack().reset_index()
        sub["ID"] = sub["ID"] + "_" + sub["level_0"]
        sub = sub.rename(columns={0: "Label"})
        sub = sub.drop("level_0", axis=1)
        LOGGER.info(sub.head())
        sub.to_csv("../output/{}_train.csv".format(EXP_ID), index=False)
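
The `predict` helper is not shown in the listing. Given `n_tta`, a typical test-time-augmentation loop averages sigmoid outputs over several passes; a sketch under the assumption that the loader yields image batches and the dataset draws a fresh augmentation per pass (not the author's actual code):

import numpy as np
import torch

def predict(model, test_loader, device, n_tta=1):
    # Average per-class sigmoid probabilities over n_tta full passes.
    model.eval()
    runs = []
    with torch.no_grad():
        for _ in range(n_tta):
            preds = []
            for images in test_loader:
                logits = model(images.to(device))
                preds.append(torch.sigmoid(logits).cpu().numpy())
            runs.append(np.concatenate(preds))
    return np.mean(runs, axis=0)
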
Example #5
def main():
    with timer('load data'):
        df = pd.read_csv(TEST_PATH)
        df["post_SOPInstanceUID"] = df["post_SOPInstanceUID"].fillna(
            df["SOPInstanceUID"])
        df["pre_SOPInstanceUID"] = df["pre_SOPInstanceUID"].fillna(
            df["SOPInstanceUID"])
        df = df[["Image", "pre_SOPInstanceUID", "post_SOPInstanceUID"]]
        ids = df["Image"].values
        pre_ids = df["pre_SOPInstanceUID"].values
        pos_ids = df["post_SOPInstanceUID"].values
        gc.collect()

    with timer('preprocessing'):
        test_augmentation = Compose([
            CenterCrop(512 - 50, 512 - 50, p=1.0),
            Resize(img_size, img_size, p=1)
        ])

        test_dataset = RSNADatasetTest(df,
                                       img_size,
                                       IMAGE_PATH,
                                       id_colname=ID_COLUMNS,
                                       transforms=test_augmentation,
                                       black_crop=False,
                                       subdural_window=True,
                                       n_tta=N_TTA)
        test_loader = DataLoader(test_dataset,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=16,
                                 pin_memory=True)
        del df, test_dataset
        gc.collect()

    with timer('create model'):
        model = CnnModel(num_classes=N_CLASSES,
                         encoder="se_resnext50_32x4d",
                         pretrained="imagenet",
                         pool_type="avg")
        model.load_state_dict(torch.load(model_path))
        model.to(device)
        model = torch.nn.DataParallel(model)

    with timer('predict'):
        pred = predict(model, test_loader, device, n_tta=N_TTA)
        pred = np.clip(pred, 1e-6, 1 - 1e-6)

    with timer('sub'):
        sub = pd.DataFrame(pred, columns=TARGET_COLUMNS)
        sub["ID"] = ids
        sub["PRE_ID"] = pre_ids
        sub["POST_ID"] = pos_ids
        sub = postprocess_multitarget(sub)
        LOGGER.info(sub.head())
        sub.to_csv("../output/{}_sub_st2.csv".format(EXP_ID), index=False)
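
The clip to [1e-6, 1 - 1e-6] before writing predictions is a log-loss safeguard: a probability of exactly 0 or 1 would incur an unbounded penalty when wrong. Illustrative values:

import numpy as np

pred = np.array([0.0, 0.5, 1.0])
print(np.clip(pred, 1e-6, 1 - 1e-6))  # extremes pinned to 1e-06 and 0.999999
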
Example #6
def main():
    with timer('load data'):
        path = glob.glob("../input_ext/*/*/*/*.dcm")
        df = pd.DataFrame({"Image": path})
        df = df[["Image"]]
        ids = df["Image"].values
        gc.collect()

    with timer('preprocessing'):
        test_augmentation = Compose([
            CenterCrop(512 - 50, 512 - 50, p=1.0),
            Resize(img_size, img_size, p=1)
        ])

        test_dataset = RSNADatasetTest(df,
                                       img_size,
                                       IMAGE_PATH,
                                       id_colname=ID_COLUMNS,
                                       transforms=test_augmentation,
                                       black_crop=False,
                                       subdural_window=True,
                                       n_tta=N_TTA,
                                       img_type="",
                                       external=True)
        test_loader = DataLoader(test_dataset,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=0,
                                 pin_memory=True)
        del df, test_dataset
        gc.collect()

    with timer('create model'):
        model = CnnModel(num_classes=N_CLASSES,
                         encoder="se_resnext50_32x4d",
                         pretrained="imagenet",
                         pool_type="avg")
        model.load_state_dict(torch.load(model_path))
        model.to(device)
        model = torch.nn.DataParallel(model)

    with timer('predict'):
        pred, is_dicoms = predict_external(model,
                                           test_loader,
                                           device,
                                           n_tta=N_TTA)
        pred = np.clip(pred, 1e-6, 1 - 1e-6)

    with timer('sub'):
        sub = pd.DataFrame(pred, columns=TARGET_COLUMNS)
        sub["is_dicom"] = is_dicoms.reshape(-1)
        sub["Image"] = ids.reshape(-1)
        LOGGER.info(sub.head())
        sub.to_csv("../input_ext/{}_externalv2.csv".format(EXP_ID),
                   index=False)
Example #7
def main():
    timer = Timer()
    timer.start('Load word2vec models...')
    vocab = load_vocab(config.VOCAB_DATA)
    embeddings = get_trimmed_w2v_vectors(config.W2V_DATA)
    timer.stop()

    timer.start('Build model...')
    model = CnnModel(embeddings=embeddings)
    model.build()
    model.restore_session(opt.pre)
    timer.stop()

    # Define app
    app = Flask(__name__)
    CORS(app)

    @app.route('/process', methods=['POST'])
    def process():
        data = request.get_json()
        try:
            words = [[
                vocab[w] if w in vocab else vocab['$UNK$']
                for w in parse_raw_data(s)
            ] for s in data['input']]
            test = Dataset(words=words)
        except Exception:
            # malformed request body; abort() raises, so no fallback value is needed
            abort(400)

        job_id = timer.start('Process {} example(s)'.format(len(data['input'])))
        y_pred = model.predict(test, opt.pre, pred_class=False)

        ret = {'output': [i.tolist() for i in y_pred]}
        timer.stop(job_id)

        return jsonify(ret)

    app.run()
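
The service expects a JSON body with an `input` list of raw sentences and returns per-sentence predictions under `output`. An illustrative client call against the locally running app:

import requests

resp = requests.post('http://127.0.0.1:5000/process',
                     json={'input': ['an example sentence to classify']})
print(resp.json()['output'])
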
Example #8
def main():
    with timer('load data'):
        df = pd.read_csv(TRAIN_PATH)
        df["loc_x"] = df["loc_x"] / 100
        df["loc_y"] = df["loc_y"] / 100
        y = df[TARGET_COLUMNS].values
        df = df[[ID_COLUMNS]]
        gc.collect()

    with timer("split data"):
        folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=0).split(df, y)
        for n_fold, (train_index, val_index) in enumerate(folds):
            train_df = df.loc[train_index]
            val_df = df.loc[val_index]
            y_train = y[train_index]
            y_val = y[val_index]
            if n_fold == fold_id:
                break

    with timer('preprocessing'):
        train_augmentation = Compose([
            HorizontalFlip(p=0.5),
            OneOf([
                ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
                GridDistortion(p=0.5),
                OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5)
            ], p=0.5),
            RandomBrightnessContrast(p=0.5),
            ShiftScaleRotate(rotate_limit=20, p=0.5),
            Resize(img_size, img_size, p=1)
        ])
        val_augmentation = Compose([
            Resize(img_size, img_size, p=1)
        ])

        train_dataset = KDDataset(train_df, y_train, img_size, IMAGE_PATH, id_colname=ID_COLUMNS,
                                  transforms=train_augmentation)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

        val_dataset = KDDataset(val_df, y_val, img_size, IMAGE_PATH, id_colname=ID_COLUMNS,
                                transforms=val_augmentation)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
        del df, train_dataset, val_dataset
        gc.collect()

    with timer('create model'):
        model = CnnModel(num_classes=N_CLASSES, encoder="se_resnext50_32x4d",
                         pretrained="../input/pytorch-pretrained-models/se_resnext50_32x4d-a260b3a4.pth",
                         pool_type="avg")
        if model_path is not None:
            model.load_state_dict(torch.load(model_path))
        model.to(device)

        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, eps=1e-4)

        # model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)

    with timer('train'):
        best_score = 0
        for epoch in range(1, epochs + 1):
            seed_torch(SEED + epoch)

            if epoch == epochs - 3:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = param_group['lr'] * 0.1

            LOGGER.info("Starting {} epoch...".format(epoch))
            tr_loss = train_one_epoch(model, train_loader, criterion, optimizer, device, N_CLASSES)
            LOGGER.info('Mean train loss: {}'.format(round(tr_loss, 5)))

            y_pred, target, val_loss = validate(model, val_loader, criterion, device, N_CLASSES)
            score = roc_auc_score(target, y_pred)
            LOGGER.info('Mean val loss: {}'.format(round(val_loss, 5)))
            LOGGER.info('val score: {}'.format(round(score, 5)))

            if score > best_score:
                best_score = score
                np.save("y_pred.npy", y_pred)
                torch.save(model.state_dict(), save_path)

        np.save("target.npy", target)

    with timer('predict'):
        test_df = pd.read_csv(TEST_PATH)
        test_ids = test_df["id"].values

        test_augmentation = Compose([
            Resize(img_size, img_size, p=1)
        ])
        test_dataset = KDDatasetTest(test_df, img_size, TEST_IMAGE_PATH, id_colname=ID_COLUMNS,
                                     transforms=test_augmentation, n_tta=2)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

        model.load_state_dict(torch.load(save_path))

        pred = predict(model, test_loader, device, N_CLASSES, n_tta=2)
        print(pred.shape)
        results = pd.DataFrame({"id": test_ids,
                                "is_star": pred.reshape(-1)})

        results.to_csv("results.csv", index=False)
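
`seed_torch`, called at the top of each epoch here and in Example #2, is not defined in the listing. A common implementation of this helper (an assumption) seeds every relevant RNG:

import os
import random

import numpy as np
import torch

def seed_torch(seed=0):
    # Seed Python, NumPy and PyTorch (CPU and GPU) RNGs for reproducibility.
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
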
Example #9
def main():  # the def line is missing from the original listing; name assumed
    transform = ConstantPad(shape=(3, 700, 600), padding_mode='reflect')
    dataset = FridgeVoterDataset('dataset', 'data.json', transform=transform)
    train_len = int(len(dataset) * 0.8)
    train_set, test_set = random_split(
        dataset, [train_len, len(dataset) - train_len])

    trainloader = DataLoader(train_set,
                             batch_size=4,
                             shuffle=True,
                             num_workers=num_workers)
    testloader = DataLoader(test_set,
                            batch_size=4,
                            shuffle=True,
                            num_workers=num_workers)

    model = CnnModel()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    running_loss = 0.0
    for epoch in range(epochs):
        for i, (image, vote) in enumerate(trainloader, 0):
            # image, vote = data
            labels = FridgeVoterDataset.vote_to_class(vote)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(image)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # track the running loss (the original listing ends abruptly here)
            running_loss += loss.item()
Example #10
def train():
    #initial config
    config = Config()
    if config.init_from:
        ckpt = tf.train.get_checkpoint_state(config.init_from)
        assert ckpt, 'No checkpoint found'
        assert ckpt.model_checkpoint_path, 'No model path found in checkpoint'
        # read back in binary mode to match the 'wb' dump below
        with open(os.path.join(config.save_dir, 'config.pkl'), 'rb') as rf:
            config = cPickle.load(rf)
    else:
        if not os.path.isdir(config.save_dir):
            os.makedirs(config.save_dir)
        with open(os.path.join(config.save_dir, 'config.pkl'), 'wb') as wf:
            cPickle.dump(config, wf)
        if not os.path.isdir(config.log_dir):
            os.makedirs(config.log_dir)

    dataloader = dataLoader(config.batch_size, config.vocab_size, config.data_dir,
                            config.seq_length, config.vali_rate)

    gpu_option = tf.GPUOptions(allow_growth=True)
    sessconfig = tf.ConfigProto(gpu_options=gpu_option)
    with tf.Session(config=sessconfig) as sess:
        initializer = tf.random_uniform_initializer(-config.init_scale,
                                                    config.init_scale)
        #with tf.variable_scope('model',reuse=None,initializer=initializer):
        model = CnnModel(config)

        summaries = tf.summary.merge_all()
        writer = tf.summary.FileWriter(
            os.path.join(config.log_dir, time.strftime("%Y-%m-%d-%H-%M-%S")))
        writer.add_graph(sess.graph)
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        if config.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)
        #train
        num_batches = dataloader.train_num
        for e in range(config.num_epoch):
            lr_decay = config.lr_decay**max(e - config.max_decay_epoch, 0.0)
            model.assign_new_lr(sess, config.lr * lr_decay)
            for i, (x, y, mask) in enumerate(dataloader.get_batches('train')):
                start = time.time()
                feed = {model.input_data: x, model.targets: y}
                #state = sess.run(model._initial_state)
                #for j , (c,h) in enumerate(model._initial_state):
                #    feed[c]=state[j].c
                #    feed[h]=state[j].h
                loss, acc, summ, _ = sess.run(
                    [model.cost, model.accuracy, summaries, model.train_op],
                    feed)
                #print len(midoput)
                #print midoput[0].shape
                writer.add_summary(summ, e * config.batch_size + i)
                end = time.time()
                print("{}/{} (epoch {}), train_loss={:.5f},acc={:.5f},time/batch ={:.4f}"\
                        .format(e * num_batches + i,\
                        config.num_epoch*num_batches,e,loss, acc,end - start))
                # checkpoint by global step (matches the global_step passed to saver.save)
                if (e * num_batches + i) % config.check_point_every == 0 \
                        or (e == config.num_epoch - 1 and i == num_batches - 1):
                    checkpoint_path = os.path.join(config.save_dir,
                                                   'model.ckpt')
                    saver.save(sess,
                               checkpoint_path,
                               global_step=e * num_batches + i)
                    validation(dataloader, sess, model)
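
The `lr_decay ** max(e - config.max_decay_epoch, 0.0)` schedule keeps the learning rate constant for the first `max_decay_epoch` epochs and then decays it geometrically. With illustrative values:

lr, lr_decay, max_decay_epoch = 0.01, 0.5, 3
for e in range(6):
    print(e, lr * lr_decay ** max(e - max_decay_epoch, 0.0))
# epochs 0-3 stay at 0.01, then 0.005, 0.0025
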
Example #11
def main():
    # device
    device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

    # args
    parser = build_argparse()
    args = parser.parse_args()
    check_argparse(args)

    # data
    print('\n-------- Data Preparing --------\n')

    train_dataloader, val_dataloader, test_dataloader = build_train_val_test_dataset(
        args)

    print('\n-------- Data Preparing Done! --------\n')

    print('\n-------- Preparing Model --------\n')
    # model
    if args.task_type == 'DistortedMNIST':
        if args.model_name == 'ST-CNN':
            stn = BaseStn(model_name=args.model_name,
                          trans_type=args.trans_type,
                          input_ch=args.input_ch,
                          input_length=args.input_length)
            stn.load_state_dict(
                torch.load(
                    '/home/jarvis1121/AI/Rico_Repo/Spatial-Transformer-Network/model_save/stn_11_DistortedMNIST_RTS_ST-CNN.pth'
                ))

            base_cnn = BaseCnnModel(input_length=args.input_length,
                                    gap=args.gap)
            model = StModel(base_stn=stn, base_nn_model=base_cnn)

            # pass to CUDA device
            model = model.to(device)

            criterion = nn.CrossEntropyLoss()

            optimizer = optim.SGD(model.parameters(), lr=0.01)
            scheduler = build_scheduler(optimizer)

        elif args.model_name == 'ST-FCN':
            stn = BaseStn(model_name=args.model_name,
                          trans_type=args.trans_type,
                          input_ch=args.input_ch,
                          input_length=args.input_length)
            stn.load_state_dict(
                torch.load(
                    '/home/jarvis1121/AI/Rico_Repo/Spatial-Transformer-Network/model_save/stn_11_DistortedMNIST_RTS_ST-CNN.pth'
                ))
            base_fcn = BaseFcnModel(input_length=args.input_length)
            model = StModel(base_stn=stn, base_nn_model=base_fcn)

            # pass to CUDA device
            model = model.to(device)

            criterion = nn.CrossEntropyLoss()

            optimizer = optim.SGD(model.parameters(), lr=0.01)
            scheduler = build_scheduler(optimizer)

        elif args.model_name == 'CNN':
            model = CnnModel()

            # pass to CUDA device
            model = model.to(device)

            criterion = nn.CrossEntropyLoss()

            optimizer = optim.SGD(model.parameters(), lr=0.01)
            scheduler = build_scheduler(optimizer)

        else:
            model = FcnModel()

            # pass to CUDA device
            model = model.to(device)

            criterion = nn.CrossEntropyLoss()

            optimizer = optim.SGD(model.parameters(), lr=0.01)
            scheduler = build_scheduler(optimizer)

    elif args.task_type == 'MNISTAddition':
        #TODO
        pass

    else:
        #TODO
        pass

    print('\n-------- Preparing Model Done! --------\n')

    # train
    print('\n-------- Starting Training --------\n')
    # prepare the tensorboard
    writer = SummaryWriter(f'runs/trial_{args.exp}')

    # TODO: the paper uses 150*1000 iterations, roughly 769 epochs at batch_size=256
    for epoch in range(args.epoch):
        train_running_loss = 0.0
        print(f'\n---The {epoch+1}-th epoch---\n')
        print('[Epoch, Batch] : Loss')

        # TRAINING LOOP
        print('---Training Loop begins---')
        for i, data in enumerate(train_dataloader, start=0):
            # move CUDA device
            input, target = data[0].to(device), data[1].to(device)
            # print("input size: ", input.size())

            optimizer.zero_grad()
            output = model(input)
            loss = criterion(output, target)

            loss.backward()
            optimizer.step()
            scheduler.step()

            train_running_loss += loss.item()
            writer.add_scalar('Averaged loss', loss.item(), 196 * epoch + i)
            writer.add_scalar('ST Gradients Norm', model.norm, 196 * epoch + i)
            if i % 20 == 19:
                print(f"[{epoch+1}, {i+1}]: %.3f" % (train_running_loss / 20))
                train_running_loss = 0.0
            elif i == 195:
                print(f"[{epoch+1}, {i+1}]: %.3f" % (train_running_loss / 16))
        print('---Training Loop ends---')

        # capture the image transformed through the ST, after one epoch
        with torch.no_grad():
            # number of images to show
            n = 6
            origi_img = input[:n, ...].clone().detach()  # (n, C, H, W)
            trans_img = stn(origi_img)  # (n, C, H, W); note: stn exists only in the ST branches
            img = torch.cat((origi_img, trans_img), dim=0)  # (2n, C, H, W)
            img = make_grid(img, nrow=n)
            writer.add_image(f"Original-Up, ST-Down images in epoch_{epoch+1}",
                             img)

        # VALIDATION LOOP
        with torch.no_grad():
            val_run_loss = 0.0
            print('---Validation Loop begins---')
            batch_count = 0
            total_count = 0
            correct_count = 0
            for i, data in enumerate(val_dataloader, start=0):
                input, target = data[0].to(device), data[1].to(device)

                output = model(input)
                loss = criterion(output, target)

                _, predicted = torch.max(output, 1)

                val_run_loss += loss.item()
                batch_count += 1
                total_count += target.size(0)

                correct_count += (predicted == target).sum().item()

            accuracy = (100 * correct_count / total_count)
            val_run_loss = val_run_loss / batch_count

            writer.add_scalar('Validation accuracy', accuracy, epoch)
            writer.add_scalar('Validation loss', val_run_loss, epoch)

            print(f"Loss of {epoch+1} epoch is %.3f" % (val_run_loss))
            print(f"Accuracy is {accuracy} %")

            print('---Validation Loop ends---')
    writer.close()
    print('\n-------- End Training --------\n')

    print('\n-------- Saving Model --------\n')

    savepath = f'/home/jarvis1121/AI/Rico_Repo/Spatial-Transformer-Network/model_save/{str(args.exp)}_{str(args.task_type)}_{str(args.trans_type)}_{str(args.model_name)}.pth'
    torch.save(model.state_dict(), savepath)

    print('\n-------- Saved --------\n')
    print(f'\n== Trial {args.exp} finished ==\n')
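
The four `DistortedMNIST` branches above differ only in how the model is built; the criterion, optimizer and scheduler are identical. A compact sketch of that selection (omitting the hard-coded `stn` checkpoint load, and assuming the constructors shown above):

def build_model(args):
    # ST variants wrap a spatial transformer around a CNN or FCN backbone.
    if args.model_name in ('ST-CNN', 'ST-FCN'):
        stn = BaseStn(model_name=args.model_name, trans_type=args.trans_type,
                      input_ch=args.input_ch, input_length=args.input_length)
        base = (BaseCnnModel(input_length=args.input_length, gap=args.gap)
                if args.model_name == 'ST-CNN'
                else BaseFcnModel(input_length=args.input_length))
        return StModel(base_stn=stn, base_nn_model=base)
    return CnnModel() if args.model_name == 'CNN' else FcnModel()
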
Example #12
from model import CnnModel

CnnModel.train(
    vocab_dir='/Users/chengyiwu/GitHub/nlp/vocab.txt',
    categories=['正面', '负面', '中立'],
    save_dir='/Users/chengyiwu/GitHub/nlp/sentiment/textcnn',
    train_dir='/Users/chengyiwu/GitHub/nlp2/text-classification-cnn-rnn-sentiment/data/cnews/cnews.train.txt',
    val_dir='/Users/chengyiwu/GitHub/nlp2/text-classification-cnn-rnn-sentiment/data/cnews/cnews.val.txt',
    config=None,
    full=True,
    num_epochs=1)

CnnModel.train(
    vocab_dir='/Users/chengyiwu/GitHub/nlp/vocab.txt',
    categories=['正面', '负面', '中立'],
    save_dir='/Users/chengyiwu/GitHub/nlp/sentiment/textcnn',
    train_dir='/Users/chengyiwu/GitHub/nlp2/text-classification-cnn-rnn-sentiment/data/cnews/cnews.train.txt',
    val_dir='/Users/chengyiwu/GitHub/nlp2/text-classification-cnn-rnn-sentiment/data/cnews/cnews.val.txt',
    config=None,
    full=False,
    num_epochs=1)
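
The two calls are identical apart from `full=True` versus `full=False`; presumably the flag toggles between training on the full corpus and a reduced run, with all other settings held constant.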