def demo_from_best_model(resnet_layer, pretrained, num_classes, path):

    assert resnet_layer in (18, 50)

    net_best = ResNet(layer_num=resnet_layer, pretrained=pretrained, num_classes=num_classes)
    net_best = net_best.to(device)
    net_best.load_state_dict(torch.load(path))
    net_best.eval()
    best_acc = save_confusion_matrix(net_best, val_loader, 'backup_demo/cm_best.png')
    print('test_best_accuracy = %.2f' % best_acc)
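# A minimal usage sketch (assumption: `device`, `val_loader`, and the checkpoint
# path below are defined by the surrounding script; the argument values are
# illustrative only):
#
#     demo_from_best_model(resnet_layer=18, pretrained=False, num_classes=10,
#                          path='backup_demo/net_best.pth')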
def main():
    # ************************************ DATA ****************************************************
    padroniza_imagem = 300
    tamanho_da_entrada = (224, 224)
    arquivo = "./imagens/raposa.jpg"
    cor = (0, 255, 0)

    # Preprocessing and augmentation operations
    composicao_de_transformacao = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # ************************************* NETWORK ************************************************
    modelo = ResNet(1000, True)
    modelo.eval()

    # Open the image
    imagem_original = np.asarray(Image.open(arquivo))
    imagem = imagem_original.copy()

    # Get the image dimensions and resize it to the standard width
    (H, W) = imagem.shape[:2]
    r = padroniza_imagem / W
    dim_final = (padroniza_imagem, int(H * r))
    imagem = cv2.resize(imagem, dim_final, interpolation=cv2.INTER_AREA)

    # Region of interest (ROI) size
    ROI = (150, 150)  # (H, W)

    # Lists of regions of interest (rois) and their coordinates (coods)
    rois = []
    coods = []

    # Run the image pyramid
    for nova_imagem in util.image_pyramid(imagem, escala=1.2):
        # Scale factor between the original image and the new pyramid level
        fator_escalar = W / float(nova_imagem.shape[1])

        # Run the sliding-window operation
        for (x, y, regiao) in util.sliding_window(nova_imagem,
                                                  size=ROI,
                                                  stride=8):

            # Stop condition
            key = cv2.waitKey(1) & 0xFF
            if (key == ord("q")):
                break

            if regiao.shape[0] != ROI[0] or regiao.shape[1] != ROI[1]:
                continue

            # Get the ROI coordinates relative to the original image
            x_r = int(x * fator_escalar)
            w_r = int(fator_escalar * ROI[1])

            y_r = int(y * fator_escalar)
            h_r = int(fator_escalar * ROI[0])

            # Get the ROI and resize it to the network's input size
            roi = cv2.resize(regiao, tamanho_da_entrada)
            roi = np.asarray(roi)
            rois.append(roi)

            # Store the coordinates (x1, y1, x2, y2)
            coods.append((x_r, y_r, x_r + w_r, y_r + h_r))

            # Work on a copy of the current pyramid level
            copia = nova_imagem.copy()
            # Draw a rectangle at the current window position
            cv2.rectangle(copia, (x, y), (x + ROI[1], y + ROI[0]), cor, 2)
            # Show the result in the window
            cv2.imshow("Janela", copia[:, :, ::-1])

            # Loop delay so the window can refresh
            time.sleep(0.01)

    # Close all open windows
    cv2.destroyAllWindows()

    #rois = np.array(rois, dtype="float32") # transform to torch tensor
    dataset = DataSet(rois, coods, composicao_de_transformacao)
    size = len(dataset)
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               shuffle=True,
                                               batch_size=size)

    print("Cópias: ", size)
    with torch.no_grad():
        for _, (X, y) in enumerate(train_loader):
            # Classify all the window crops
            resultado = modelo(X)

            # Get the top prediction per crop and the most frequent class among them
            confs, indices_dos_melhores_resultados = torch.max(resultado, 1)
            classe, _ = torch.mode(indices_dos_melhores_resultados.flatten(),
                                   -1)

            # Mask of crops whose prediction matches the most frequent class
            mascara = [
                item == classe for item in indices_dos_melhores_resultados
            ]

            # Select the boxes of the crops assigned to that class
            boxes = []
            for i in range(size):
                if mascara[i]:
                    boxes.append(coods[i])

            # Apply non-max suppression to the selected boxes
            boxes = util.non_max_suppression(np.asarray(boxes),
                                             overlapThresh=0.3)

            copia = imagem_original.copy()
            for (x1, y1, x2, y2) in boxes:
                cv2.rectangle(copia, (x1, y1), (x2, y2), cor, 2)

            cv2.imshow("Final", copia[:, :, ::-1])
            cv2.waitKey(0)

    cv2.destroyAllWindows()
def train(working_dir, grid_size, learning_rate, batch_size, num_walks,
          model_type, fn):
    train_props, val_props, test_props = get_props(working_dir,
                                                   dtype=np.float32)
    means_stds = np.loadtxt(working_dir + "/means_stds.csv",
                            dtype=np.float32,
                            delimiter=',')

    # filter out redundant qm8 properties
    if train_props.shape[1] == 16:
        filtered_labels = list(range(0, 8)) + list(range(12, 16))
        train_props = train_props[:, filtered_labels]
        val_props = val_props[:, filtered_labels]
        test_props = test_props[:, filtered_labels]

        means_stds = means_stds[:, filtered_labels]
    if model_type == "resnet18":
        model = ResNet(BasicBlock, [2, 2, 2, 2],
                       grid_size,
                       "regression",
                       feat_nums,
                       e_sizes,
                       num_classes=train_props.shape[1])
    elif model_type == "resnet34":
        model = ResNet(BasicBlock, [3, 4, 6, 3],
                       grid_size,
                       "regression",
                       feat_nums,
                       e_sizes,
                       num_classes=train_props.shape[1])
    elif model_type == "resnet50":
        model = ResNet(Bottleneck, [3, 4, 6, 3],
                       grid_size,
                       "regression",
                       feat_nums,
                       e_sizes,
                       num_classes=train_props.shape[1])
    elif model_type == "densenet121":
        model = densenet121(grid_size,
                            "regression",
                            feat_nums,
                            e_sizes,
                            num_classes=train_props.shape[1])
    elif model_type == "densenet161":
        model = densenet161(grid_size,
                            "regression",
                            feat_nums,
                            e_sizes,
                            num_classes=train_props.shape[1])
    elif model_type == "densenet169":
        model = densenet169(grid_size,
                            "regression",
                            feat_nums,
                            e_sizes,
                            num_classes=train_props.shape[1])
    elif model_type == "densenet201":
        model = densenet201(grid_size,
                            "regression",
                            feat_nums,
                            e_sizes,
                            num_classes=train_props.shape[1])
    else:
        print("specify a valid model")
        return
    model.float()
    model.cuda()
    loss_function_train = nn.MSELoss(reduction='none')
    loss_function_val = nn.L1Loss(reduction='none')
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # if model_type[0] == "r":
    # 	batch_size = 128
    # 	optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
    # 					   momentum=0.9, weight_decay=5e-4, nesterov=True)
    # elif model_type[0] == "d":
    # 	batch_size = 512
    # 	optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
    # 					   momentum=0.9, weight_decay=1e-4, nesterov=True)
    # else:
    # 	print("specify a vlid model")
    # 	return

    stds = means_stds[1, :]
    tl_list = []
    vl_list = []

    log_file = open(fn + ".txt", "w")
    log_file.write("start\n")
    log_file.flush()

    for file_num in range(num_loads):
        if file_num % 20 == 0:
            model_file = open("../../scratch/" + fn + ".pkl", "wb")
            pickle.dump(model, model_file)
            model_file.close()

        log_file.write("load: " + str(file_num))
        print("load: " + str(file_num))
        # Get new random walks
        if file_num == 0:
            t = time.time()
            train_loader, val_loader, test_loader = get_loaders(
                working_dir, file_num, grid_size, batch_size, train_props,
                val_props=val_props, test_props=test_props)
            print("load time")
            print(time.time() - t)
        else:
            file_num = random.randint(0, num_walks - 1)
            t = time.time()
            train_loader, _, _ = get_loaders(
                working_dir, file_num, grid_size, batch_size, train_props)
            print("load time")
            print(time.time() - t)
        # Train on set of random walks, can do multiple epochs if desired
        for epoch in range(epochs_per_load):
            model.train()
            t = time.time()
            train_loss_list = []
            train_mae_loss_list = []
            for i, (walks_int, walks_float, props) in enumerate(train_loader):
                walks_int = walks_int.cuda()
                walks_int = walks_int.long()
                walks_float = walks_float.cuda()
                walks_float = walks_float.float()
                props = props.cuda()
                outputs = model(walks_int, walks_float)
                # Individual losses for each item
                loss_mae = torch.mean(loss_function_val(props, outputs), 0)
                train_mae_loss_list.append(loss_mae.cpu().detach().numpy())
                loss = torch.mean(loss_function_train(props, outputs), 0)
                train_loss_list.append(loss.cpu().detach().numpy())
                # Loss converted to single value for backpropagation
                loss = torch.sum(loss)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            model.eval()
            val_loss_list = []
            with torch.no_grad():
                for i, (walks_int, walks_float,
                        props) in enumerate(val_loader):
                    walks_int = walks_int.cuda()
                    walks_int = walks_int.long()
                    walks_float = walks_float.cuda()
                    walks_float = walks_float.float()
                    props = props.cuda()
                    outputs = model(walks_int, walks_float)
                    # Individual losses for each item
                    loss = loss_function_val(props, outputs)
                    val_loss_list.append(loss.cpu().detach().numpy())
            # ith row of this array is the losses for each label in batch i
            train_loss_arr = np.array(train_loss_list)
            train_mae_arr = np.array(train_mae_loss_list)
            log_file.write("training mse loss\n")
            log_file.write(str(np.mean(train_loss_arr)) + "\n")
            log_file.write("training mae loss\n")
            log_file.write(str(np.mean(train_mae_arr)) + "\n")
            print("training mse loss")
            print(str(np.mean(train_loss_arr)))
            print("training mae loss")
            print(str(np.mean(train_mae_arr)))
            val_loss_arr = np.concatenate(val_loss_list, 0)
            val_loss = np.mean(val_loss_arr, 0)
            log_file.write("val loss\n")
            log_file.write(str(np.mean(val_loss_arr)) + "\n")
            print("val loss")
            print(str(np.mean(val_loss_arr)))
            # Unnormalized loss is for comparison to papers
            tnl = np.mean(train_mae_arr, 0)
            log_file.write("train normalized losses\n")
            log_file.write(" ".join(list(map(str, tnl))) + "\n")
            print("train normalized losses")
            print(" ".join(list(map(str, tnl))))
            log_file.write("val normalized losses\n")
            log_file.write(" ".join(list(map(str, val_loss))) + "\n")
            print("val normalized losses")
            print(" ".join(list(map(str, val_loss))))
            tunl = stds * tnl
            log_file.write("train unnormalized losses\n")
            log_file.write(" ".join(list(map(str, tunl))) + "\n")
            print("train unnormalized losses")
            print(" ".join(list(map(str, tunl))))
            vunl = stds * val_loss
            log_file.write("val unnormalized losses\n")
            log_file.write(" ".join(list(map(str, vunl))) + "\n")
            log_file.write("\n")
            print("val unnormalized losses")
            print(" ".join(list(map(str, vunl))))
            print("\n")
            print("time")
            print(time.time() - t)
        file_num += 1
        log_file.flush()
    log_file.close()
    return model
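# A minimal usage sketch for the regression trainer above (assumption: the working
# directory holds the preprocessed random-walk files and means_stds.csv it expects;
# the directory name and hyperparameter values are illustrative only):
#
#     model = train(working_dir="qm8_data", grid_size=10, learning_rate=1e-3,
#                   batch_size=64, num_walks=50, model_type="resnet18", fn="run0")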
def main():
    ### fix seed
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    random.seed(SEED)
    np.random.seed(SEED)

    # table data load
    starttime = time.time()
    df_test = pd.read_csv("../input/sample_submission.csv")
    labels = df_test.columns[1:].tolist()
    df_test['path'] = "{}/".format(wav_dir) + df_test['fname']
    print("table data loading done. {:.1f}/{:.1f}".format(time.time() - starttime, time.time() - starttime0))

    # get data length
    starttime = time.time()
    p = Pool(2)
    len_list = p.map(get_len, df_test['path'].values)
    df_test['length'] = len_list
    print("getting data length done. {:.1f}/{:.1f}".format(time.time() - starttime, time.time() - starttime0))

    # data sort
    starttime = time.time()
    df_test_sort = df_test.copy()
    df_test_sort['index'] = np.arange(len(df_test_sort))
    df_test_sort = df_test_sort.sort_values(['length', 'index']).reset_index(drop=True)
    print("data sort done. {:.1f}/{:.1f}".format(time.time() - starttime, time.time() - starttime0))

    # batch splitting
    starttime = time.time()
    NUM_BATCH_LIMIT = 60 + int(len(df_test_sort) * NUMBATCH_PER_NUMDATA)
    print("num batch limit: {}".format(NUM_BATCH_LIMIT))
    patience_rate = 0
    patience_rate_tmp = 0
    num_batch, count = get_num_batch(df_test_sort, patience_rate)
    print("patience_rate_tmp: {:.2f}, patience_rate_tmp: {:.2f}, num_batch: {:3d}".format(
        patience_rate, patience_rate_tmp, num_batch))
    while num_batch > NUM_BATCH_LIMIT and patience_rate_tmp < MAX_PATIENCE:
        patience_rate_tmp += 0.01
        num_batch_tmp, count_tmp = get_num_batch(df_test_sort, patience_rate_tmp)
        if num_batch_tmp < num_batch:
            num_batch = num_batch_tmp
            count = count_tmp
            patience_rate = patience_rate_tmp
        print("patience_rate_tmp: {:.2f}, patience_rate_tmp: {:.2f}, num_batch_tmp: {:3d}".format(
            patience_rate, patience_rate_tmp, num_batch_tmp))
    num_batch, count = get_num_batch(df_test_sort, patience_rate)
    print("num batch: {}, rate of padding patience: {:.2f}".format(num_batch, patience_rate))
    print("batch splitting done. {:.1f}/{:.1f}".format(time.time() - starttime, time.time() - starttime0))

    # store batch id
    starttime = time.time()
    batch_list = []
    for i in range(num_batch):
        batch_list += [i] * count[i][1]
    df_test_sort['batch'] = batch_list
    print(df_test_sort[['path', 'length', 'batch']].head())
    print("save batch id done. {:.1f}/{:.1f}".format(time.time() - starttime, time.time() - starttime0))

    # split dataframe if too big
    starttime = time.time()
    df_mel_split = get_df_split(df_test_sort, LEN_DF_MEL_LIMIT)
    print("df_mel_split")
    for i in range(len(df_mel_split)):
        print("{}: num data: {}, total length: {}".format(i + 1, len(df_mel_split[i]), df_mel_split[i]['length'].sum()))
    print("dataframe splitting done. {:.1f}/{:.1f}".format(time.time() - starttime, time.time() - starttime0))

    # ### EnvNet part
    # build model
    model = EnvNetv2(NUM_CLASS).cuda()
    model.eval()

    # split df for EnvNet
    df_wav_split = get_df_split(df_test_sort, LEN_DF_WAV_LIMIT)
    print("df_wav_split")
    for i in range(len(df_wav_split)):
        print("{}: num data: {}, total length: {}".format(i + 1, len(df_wav_split[i]), df_wav_split[i]['length'].sum()))

    print("predict wav...")

    # parallel threading: precompute mel features for the first split (threadA)
    # while EnvNet predicts on the first wav split (threadB)
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)
    threadA = executor.submit(get_mel_batch, df_mel_split[0])
    threadB = executor.submit(predict_wav_split, model, df_wav_split[0], ENV_LIST)
    preds_wav_split = []
    preds_wav_split.append(threadB.result())
    executor.shutdown()
    print("parallel threading done.", time.time() - starttime, time.time() - starttime0)

    # do the remaining EnvNet predictions
    if len(df_wav_split) > 1:
        for split in range(1, len(df_wav_split)):
            preds_wav_split.append(predict_wav_split(model, df_wav_split[split], ENV_LIST))
            print("envnet prediction split {}/{}, done. {:.1f}/{:.1f}".format(
                split + 1, len(df_wav_split), time.time() - starttime, time.time() - starttime0))
    preds_test_wav = np.concatenate(preds_wav_split, axis=4)
    print("all envnet predict done.", time.time() - starttime, time.time() - starttime0)

    # build model
    starttime = time.time()
    model = ResNet(NUM_CLASS).cuda()
    model.eval()
    print("building ResNet model done. {:.1f}/{:.1f}".format(time.time() - starttime, time.time() - starttime0))

    # predict split #1
    preds_test_mel = []
    preds_test_mel.append(predict_mel_split(model, df_mel_split[0], RES_LIST))
    shutil.rmtree(BATCH_DIR)
    print("mel prediction of split {} done. {:.1f}/{:.1f}".format(1, time.time() - starttime, time.time() - starttime0))

    # process the remaining splits
    if len(df_mel_split) > 1:
        for split in range(1, len(df_mel_split)):
            # mel preprocessing
            starttime = time.time()
            df_test_sort_tmp = df_mel_split[split]
            get_mel_batch(df_test_sort_tmp)
            print("mel preprocessing of split {} done. {:.1f}/{:.1f}".format(
                split + 1, time.time() - starttime, time.time() - starttime0))
            preds_test_mel.append(predict_mel_split(model, df_test_sort_tmp, RES_LIST))
            shutil.rmtree(BATCH_DIR)
            print("mel prediction of split {} done. {:.1f}/{:.1f}".format(
                split + 1, time.time() - starttime, time.time() - starttime0))

    print("all prediction done. {:.1f}/{:.1f}".format(time.time() - starttime, time.time() - starttime0))

    # concat
    starttime = time.time()
    preds_test_mel = np.concatenate(preds_test_mel, axis=4)
    print("preds_test_mel.shape", preds_test_mel.shape)
    print("concat done.", time.time() - starttime, time.time() - starttime0)

    # make submission
    preds_test_avr = (
            + preds_test_mel[:, 0, :len(RES_LIST[0]['epoch'])].mean(axis=(0, 1, 2)) * 4 / 13
            + preds_test_mel[:, 1, :len(RES_LIST[1]['epoch'])].mean(axis=(0, 1, 2)) * 3 / 13
            + preds_test_mel[:, 2, :len(RES_LIST[2]['epoch'])].mean(axis=(0, 1, 2)) * 3 / 13
            + preds_test_wav[:, 0, :len(ENV_LIST[0]['epoch'])].mean(axis=(0, 1, 2)) * 1 / 13
            + preds_test_wav[:, 1, :len(ENV_LIST[1]['epoch'])].mean(axis=(0, 1, 2)) * 1 / 13
            + preds_test_wav[:, 2, :len(ENV_LIST[2]['epoch'])].mean(axis=(0, 1, 2)) * 1 / 13)
    print(preds_test_mel.shape, preds_test_wav.shape)
    print(preds_test_avr.shape)
    df_test_sort = df_test_sort.sort_values(['length', 'index']).reset_index(drop=True)
    df_test_sort[labels] = preds_test_avr
    df_test_sort = df_test_sort.sort_values('index').reset_index(drop=True)
    df_test_sort[['fname'] + labels].to_csv("../output/submission1.csv", index=False)
    print("save submission done. {:.1f}/{:.1f}".format(time.time() - starttime, time.time() - starttime0))
Example #5
    model = ResNet(depth=50,
                   pretrained=False,
                   cut_at_pooling=False,
                   num_features=num_features,
                   norm=False,
                   dropout=0.5,
                   num_classes=datareader.num_class)
    model.load_state_dict(
        torch.load(osp.join(args.model_dir,
                            'best_triplet_%d.pth' % (hash_bit))))
    model.cuda()
''' ------------------------------- Testing --------------------------------- '''
if args.test:
    batch_size = args.triplet_batch_size
    ''' ============================= Testing ================================ '''
    model.eval()
    n_feat = hash_bit
    ''' Testing Query Features '''
    if args.dataset == 'cuhk03':
        prbX, galX, prbY, galY = datareader.read_pair_images(
            'test', need_augmentation=False)
    elif args.dataset == 'market1501':
        prbX, prbY, _ = datareader.read_images('query',
                                               need_augmentation=False,
                                               need_shuffle=False)
        galX, galY, _ = datareader.read_images('test',
                                               need_augmentation=False,
                                               include_distractors=True,
                                               need_shuffle=False)
    elif args.dataset == 'dukemtmc-reid':
        prbX, prbY, _ = datareader.read_images('query',
Example #6
def main():
    if not sys.warnoptions:
        warnings.simplefilter("ignore")

    # --- hyper parameters --- #
    BATCH_SIZE = 256
    LR = 1e-3
    WEIGHT_DECAY = 1e-4
    N_layer = 18
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # --- data process --- #
    # info
    src_path = './data/'
    target_path = './saved/ResNet18/'
    model_path = target_path + 'pkls/'
    pred_path = target_path + 'preds/'

    if not os.path.exists(model_path):
        os.makedirs(model_path)
    if not os.path.exists(pred_path):
        os.makedirs(pred_path)

    # evaluation: num of classify labels & image size
    # output testing id csv
    label2num_dict, num2label_dict = data_evaluation(src_path)

    # load
    train_data = dataLoader(src_path, 'train', label2num_dict)
    train_len = len(train_data)
    test_data = dataLoader(src_path, 'test')

    train_loader = Data.DataLoader(
        dataset=train_data,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=12,
    )
    test_loader = Data.DataLoader(
        dataset=test_data,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=12,
    )

    # --- model training --- #
    # fp: for storing data
    fp_train_acc = open(target_path + 'train_acc.txt', 'w')
    fp_time = open(target_path + 'time.txt', 'w')

    # train
    highest_acc, train_acc_seq = 0, []
    loss_funct = nn.CrossEntropyLoss()
    net = ResNet(N_layer).to(device)
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=LR,
                                 weight_decay=WEIGHT_DECAY)
    print(net)

    # NOTE: count(1) iterates indefinitely, so training runs until interrupted
    for epoch_i in count(1):
        right_count = 0

        # print('\nTraining epoch {}...'.format(epoch_i))
        # for batch_x, batch_y in tqdm(train_loader):
        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            # clear gradient
            optimizer.zero_grad()

            # forward & backward
            output = net(batch_x.float())
            highest_out = torch.max(output, 1)[1]
            right_count += sum(batch_y == highest_out).item()

            loss = loss_funct(output, batch_y)
            loss.backward()

            # update parameters
            optimizer.step()

        # calculate accuracy
        train_acc = right_count / train_len
        train_acc_seq.append(train_acc * 100)

        if train_acc > highest_acc:
            highest_acc = train_acc

        # save model
        torch.save(
            net.state_dict(),
            '{}{}_{}_{}.pkl'.format(model_path,
                                    target_path.split('/')[2],
                                    round(train_acc * 1000), epoch_i))

        # write data
        fp_train_acc.write(str(train_acc * 100) + '\n')
        fp_time.write(
            str(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) + '\n')
        print('\n{} Epoch {}, Training accuracy: {}'.format(
            time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), epoch_i,
            train_acc))

        # test
        net.eval()
        test_df = pd.read_csv(src_path + 'testing_data/testing_labels.csv')
        with torch.no_grad():
            for i, (batch_x, _) in enumerate(test_loader):
                batch_x = batch_x.to(device)
                output = net(batch_x.float())
                highest_out = torch.max(output, 1)[1].cpu()
                labels = [
                    num2label_dict[out_j.item()] for out_j in highest_out
                ]
                test_df['label'].iloc[i * BATCH_SIZE:(i + 1) *
                                      BATCH_SIZE] = labels
        test_df.to_csv('{}{}_{}_{}.csv'.format(pred_path,
                                               target_path.split('/')[2],
                                               round(train_acc * 1000),
                                               epoch_i),
                       index=False)
        net.train()

        lr_decay(optimizer)

    fp_train_acc.close()
    fp_time.close()
Example #7
def initiate_cifar10(random_model=False):

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 2, 'pin_memory': True} if use_cuda else {}

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    trainset = datasets.CIFAR10(root='./data',
                                train=True,
                                download=True,
                                transform=transform)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=128,
                                               shuffle=True,
                                               **kwargs)

    testset = datasets.CIFAR10(root='./data',
                               train=False,
                               download=True,
                               transform=transform)
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=200,
                                              shuffle=False,
                                              **kwargs)

    classes = [
        'plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship',
        'truck'
    ]

    def show_images(images, labels):
        num_img = len(images)
        np_images = [img.numpy() for img in images]
        fig, axes = plt.subplots(nrows=1, ncols=num_img, figsize=(20, 45))

        for i, ax in enumerate(axes.flat):
            ax.set_axis_off()
            im = ax.imshow(np_images[i], vmin=0., vmax=1.)
            ax.set_title(f'{labels[i]}')
            plt.axis("off")

        fig.subplots_adjust(bottom=0.1,
                            top=0.9,
                            left=0.1,
                            right=0.8,
                            wspace=0.1,
                            hspace=0.25)

        plt.show()

    images, labels = next(iter(train_loader))
    num_img_to_plot = 9
    images = [images[i].permute(1, 2, 0) for i in range(num_img_to_plot)]
    labels = [classes[i] for i in labels[:num_img_to_plot]]
    # show_images(images, labels)

    model = ResNet().to(device)
    model_2 = ResNet().to(device)
    if not random_model:
        if not use_cuda:
            model.load_state_dict(
                torch.load("checkpoints/cifar/resnet_NT_ep_100.pt",
                           map_location='cpu'))
            model_2.load_state_dict(
                torch.load(
                    "checkpoints/cifar/resnet_RFGSM_eps_8_a_10_ep_100.pt",
                    map_location='cpu'))
        else:
            model.load_state_dict(
                torch.load("checkpoints/cifar/resnet_NT_ep_100.pt"))
            model_2.load_state_dict(
                torch.load(
                    "checkpoints/cifar/resnet_RFGSM_eps_8_a_10_ep_100.pt"))
    model.eval()
    model_2.eval()
    test_loss, test_acc = test(model, test_loader)
    print(f'Clean \t loss: {test_loss:.4f} \t acc: {test_acc:.4f}')

    return model, model_2, train_loader, test_loader
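# A minimal usage sketch (assumption: the CIFAR-10 checkpoints referenced above
# exist under checkpoints/cifar/; otherwise pass random_model=True to use
# randomly initialized weights):
#
#     model, model_2, train_loader, test_loader = initiate_cifar10(random_model=False)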
def train(working_dir, grid_size, learning_rate, batch_size, num_cores):
    process = psutil.Process(os.getpid())
    print(process.memory_info().rss / 1024 / 1024 / 1024)
    train_feat_dict = get_feat_dict(working_dir + "/train_smiles.csv")
    val_feat_dict = get_feat_dict(working_dir + "/val_smiles.csv")
    test_feat_dict = get_feat_dict(working_dir + "/test_smiles.csv")
    # the feature dicts take about 0.08 GB
    process = psutil.Process(os.getpid())
    print("pre model")
    print(process.memory_info().rss / 1024 / 1024 / 1024)

    torch.set_default_dtype(torch.float64)
    train_props, val_props, test_props = get_props(working_dir, dtype=int)
    print("pre model post props")
    print(process.memory_info().rss / 1024 / 1024 / 1024)
    model = ResNet(BasicBlock, [2, 2, 2, 2],
                   grid_size,
                   "classification",
                   feat_nums,
                   e_sizes,
                   num_classes=train_props.shape[1])
    model.float()
    model.cuda()
    print("model params")
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print(pytorch_total_params)
    model.cpu()
    print("model")
    print(process.memory_info().rss / 1024 / 1024 / 1024)
    loss_function = masked_cross_entropy
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    tl_list = []
    vl_list = []
    tmra_list = []
    vmra_list = []

    for file_num in range(num_loads):
        # Get new random walks
        if file_num == 0:
            print("before get_loaders")
            process = psutil.Process(os.getpid())
            print(process.memory_info().rss / 1024 / 1024 / 1024)
            train_loader, val_loader, test_loader = get_loaders(
                num_cores, working_dir, file_num, grid_size, batch_size,
                train_props, train_feat_dict, val_props=val_props,
                val_feat_dict=val_feat_dict, test_props=test_props,
                test_feat_dict=test_feat_dict)
        else:
            print("before get_loaders 2")
            process = psutil.Process(os.getpid())
            print(process.memory_info().rss / 1024 / 1024 / 1024)
            train_loader, _, _ = get_loaders(
                num_cores, working_dir, file_num, grid_size, batch_size,
                train_props, train_feat_dict)
        # Train on a single set of random walks, can do multiple epochs if desired
        for epoch in range(epochs_per_load):
            model.train()
            model.cuda()
            t = time.time()
            train_loss_list = []
            props_list = []
            outputs_list = []
            for i, (walks_int, walks_float, props) in enumerate(train_loader):
                walks_int = walks_int.cuda()
                walks_int = walks_int.long()
                walks_float = walks_float.cuda()
                walks_float = walks_float.float()
                props = props.cuda()
                props = props.long()
                props_list.append(props)
                outputs = model(walks_int, walks_float)
                outputs_list.append(outputs)
                loss = loss_function(props, outputs)
                train_loss_list.append(loss.item())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            props = torch.cat(props_list, 0)
            props = props.cpu().numpy()
            outputs = torch.cat(outputs_list, 0)
            outputs = outputs.detach().cpu().numpy()
            # Get train rocauc value
            train_rocaucs = []
            for i in range(props.shape[1]):
                mask = props[:, i] != 2
                train_rocauc = roc_auc_score(props[mask, i], outputs[mask, i])
                train_rocaucs.append(train_rocauc)
            model.eval()
            with torch.no_grad():
                ds = val_loader.dataset
                walks_int = ds.int_feat_tensor
                walks_float = ds.float_feat_tensor
                props = ds.prop_tensor
                walks_int = walks_int.cuda()
                walks_int = walks_int.long()
                walks_float = walks_float.cuda()
                walks_float = walks_float.float()
                props = props.cuda()
                outputs = model(walks_int, walks_float)
                loss = loss_function(props, outputs)
                props = props.cpu().numpy()
                outputs = outputs.cpu().numpy()
                val_rocaucs = []
                for i in range(props.shape[1]):
                    mask = props[:, i] != 2
                    val_rocauc = roc_auc_score(props[mask, i], outputs[mask, i])
                    val_rocaucs.append(val_rocauc)
            print("load: " + str(file_num) + ", epochs: " + str(epoch))
            print("training loss")
            # Slightly approximate since last batch can be smaller...
            tl = statistics.mean(train_loss_list)
            print(tl)
            print("val loss")
            vl = loss.item()
            print(vl)
            print("train mean roc auc")
            tmra = sum(train_rocaucs) / len(train_rocaucs)
            print(tmra)
            print("val mean roc auc")
            vmra = sum(val_rocaucs) / len(val_rocaucs)
            print(vmra)
            print("time")
            print(time.time() - t)
            tl_list.append(tl)
            vl_list.append(vl)
            tmra_list.append(tmra)
            vmra_list.append(vmra)
            model.cpu()
        file_num += 1
        del train_loader
    save_plot(tl_list, vl_list, tmra_list, vmra_list)
    return model
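# A minimal usage sketch for the classification trainer above (assumption: the
# working directory contains the *_smiles.csv files and property/walk data it
# expects; the directory name and hyperparameter values are illustrative only):
#
#     model = train(working_dir="tox21_data", grid_size=10, learning_rate=1e-3,
#                   batch_size=64, num_cores=4)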