def _image_folder_loader(root, normalize, batch_size, workers):
    """Return a non-shuffling DataLoader over the ImageFolder tree at ``root``.

    Every image goes through ``transforms.ToTensor()`` (HWC in [0, 255] ->
    CHW float in [0.0, 1.0]) followed by ``normalize``.
    """
    dataset = datasets.ImageFolder(
        root,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,  # per-channel RGB normalization
        ]),
    )
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=workers,
    )


def _split_into_chunks(batches, batch_labels, chunk_size, chunks_per_batch):
    """Cut each loaded batch into ``chunk_size``-image slices, in random order.

    Batch ``b`` contributes the slices ``[k*chunk_size:(k+1)*chunk_size]`` for
    ``k`` in ``range(chunks_per_batch)``; the (batch, offset) pairs are
    shuffled with ``random.shuffle``. Slices never cross a batch boundary, so
    a batch size that is not a multiple of ``chunk_size`` just leaves its
    remainder unused. Empty slices (possible when the last batch is shorter
    than the others) are dropped.

    Returns:
        (image_list, label_list): parallel lists of image / label tensors.
    """
    offsets = [
        (b, k * chunk_size)
        for b in range(len(batches))
        for k in range(chunks_per_batch)
    ]
    random.shuffle(offsets)
    image_list, label_list = [], []
    for b, start in offsets:
        images = batches[b][start:start + chunk_size]
        if len(images) == 0:
            continue
        image_list.append(images)
        label_list.append(batch_labels[b][start:start + chunk_size])
    return image_list, label_list


def _initial_pair_index(label_list, num_pairs):
    """Map ``0..num_pairs-1`` to positions of the first chunks whose first label is 0.

    Raises:
        ValueError: if fewer than ``num_pairs`` such chunks exist.
    """
    zero_label_positions = [
        idx for idx, labels in enumerate(label_list) if labels[0].item() == 0
    ]
    if len(zero_label_positions) < num_pairs:
        raise ValueError(
            'expected at least {} chunks with label 0, found {}'.format(
                num_pairs, len(zero_label_positions)))
    return dict(enumerate(zero_label_positions[:num_pairs]))


def main():
    """Train ``Network`` with a cosine-embedding loss on shuffled image chunks.

    Reads its settings from the module-level ``args`` (``arch``, ``lr``,
    ``batch_size``, ``workers``, ``train_data``, ``test_data``,
    ``start_epoch``, ``epochs``) and saves ``net.state_dict()`` to
    ``modelpath + 'model_<epoch>.pth'`` at the start of every epoch > 19.

    Every training batch is cut into ``chunk_size``-image chunks; each chunk
    is one forward pass of ``net``. The losses of ``chunks_per_batch``
    consecutive chunks are averaged and back-propagated as one optimizer step.

    Raises:
        ValueError: if ``args.batch_size`` is smaller than one chunk, or the
            data holds too few label-0 chunks to seed the pairing index.
    """
    chunk_size = 40  # images per forward pass of ``net``
    if args.batch_size < chunk_size:
        raise ValueError('batch size must be at least {} (got {})'.format(
            chunk_size, args.batch_size))

    ## Network initialize ##
    net = Network(classes=2, arch=args.arch)  # default number of classes: 2
    # net.load_state_dict(torch.load('./model/cs_globalmean/model_10.pth'))
    # print('Load model successfully')

    ## Loss function (criterion) and optimizer ##
    criterion = nn.CosineEmbeddingLoss().cuda()
    optimizer = torch.optim.SGD(net.parameters(), lr=args.lr,
                                momentum=0.9, weight_decay=1e-4)

    ## Data loading ##
    traindir = os.path.join(args.train_data, 'train')  # expected sub-folders: tdata, fdata
    valpdir = os.path.join(args.test_data, 'pocket')
    valldir = os.path.join(args.test_data, 'ligand')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # NOTE(review): counts every entry in traindir minus one; confirm intent.
    num_classes = len(os.listdir(traindir)) - 1
    print("num_classes = '{}'".format(num_classes))
    train_loader = _image_folder_loader(traindir, normalize, args.batch_size, args.workers)
    # Validation loaders (batch size 20) are only consumed by the disabled
    # ``test(...)`` call in the training loop below.
    val_ploader = _image_folder_loader(valpdir, normalize, 20, args.workers)
    val_lloader = _image_folder_loader(valldir, normalize, 20, args.workers)

    ## Train ##
    print('Start training: lr %f, batch size %d, classes %d'
          % (args.lr, args.batch_size, num_classes))
    # Chunks cut from each batch; also the number of chunk losses averaged
    # into a single optimizer step.
    chunks_per_batch = args.batch_size // chunk_size
    imgs, lbls = [], []
    for images, labels in train_loader:
        imgs.append(images)
        lbls.append(labels)
    image_list, label_list = _split_into_chunks(imgs, lbls, chunk_size, chunks_per_batch)
    list_length = len(image_list)
    init_numdict = _initial_pair_index(label_list, list_length // 2)
    # shuffle_fpair runs in mode 1 every ``refresh_period`` epochs; a period
    # <= 0 (fewer than 4 chunks) would make the modulo below invalid.
    refresh_period = list_length // 2 - 1

    for epoch in range(args.start_epoch, args.epochs):
        if epoch > 19:
            path = modelpath + 'model_' + str(epoch) + '.pth'
            torch.save(net.state_dict(), path)
            print('>>>>>Save model successfully<<<<<')

        if refresh_period > 0 and (epoch + 1) % refresh_period == 0 and epoch > 0:
            image_list, init_numdict = shuffle_fpair(image_list, label_list, list_length, init_numdict, 1)
            print('Shuffle mode >>>> 1')
        else:
            image_list, init_numdict = shuffle_fpair(image_list, label_list, list_length, init_numdict, 0)
            print('Shuffle mode >>>> 0')
        image_list, label_list, init_numdict = shuffle_set(image_list, label_list, list_length, init_numdict)

        loss = 0
        group_size = 0
        sum_loss = 0.0
        num_groups = 0
        for i, (images, labels) in enumerate(zip(image_list, label_list)):
            # Forward; losses are accumulated until the group is complete.
            output_lig, output_poc = net(images, 'train', 'max', 'max')
            # CosineEmbeddingLoss expects targets +1 / -1: map the chunk's
            # first class index 1 -> +1 and 0 -> -1.
            target = torch.tensor([labels[0].item()]) * 2 - 1
            loss += criterion(output_lig, output_poc, target.type_as(output_lig))
            group_size += 1
            if group_size < chunks_per_batch and (i + 1) != list_length:
                continue

            # Backward + optimize once per group of chunks.
            loss /= group_size  # average over the group
            loss_value = loss.item()  # plain float, so the autograd graph is not retained
            sum_loss += loss_value
            num_groups += 1
            print('Epoch: %2d, iter: %2d, Loss: %.4f' % (epoch, i + 1, loss_value))
            if (i + 1) == list_length:
                print('>>>Epoch: %2d, Train_Loss: %.4f' % (epoch, sum_loss / num_groups))
                # test(net, val_ploader, val_lloader, epoch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss = 0
            group_size = 0
shuffle=False, num_workers=cfg.TRAIN.NUM_WORKERS, pin_memory=cfg.PIN_MEMORY, ) testLoader = DataLoader( test_set, batch_size=cfg.TEST.BATCH_SIZE, shuffle=False, num_workers=cfg.TEST.NUM_WORKERS, pin_memory=cfg.PIN_MEMORY, ) if cfg.METHOD == "tau_norm": model_state_dict = model.state_dict() # set bias as zero model_state_dict['module.classifier.bias'].copy_(torch.zeros( (num_classes))) weight_ori = model_state_dict['module.classifier.weight'] norm_weight = torch.norm(weight_ori, 2, 1) best_accuracy = 0 best_p = 0 for p in np.arange(0.0, 1.0, 0.1): ws = weight_ori.clone() for i in range(weight_ori.size(0)): ws[i] = ws[i] / torch.pow(norm_weight[i], p) model_state_dict['module.classifier.weight'].copy_(ws) print("\n___________________________", p, "__________________________________") acc, _ = valid_model(testLoader, model, num_classes, para_dict_train, para_dict_test,criterion, LOSS_RATIO=0)