def main(args):
    """Meta-train a Unet-style MAML model on RCWA data.

    Builds the layer-config list consumed by ``Meta``, restores an optional
    checkpoint, then runs the meta-training loop with periodic summaries,
    meta-validation passes, and per-epoch checkpointing.

    Args:
        args: parsed CLI namespace; must provide ``arch``, ``NUM_DOWN_CONV``,
            ``HIDDEN_DIM``, ``imgc``, ``outc``, ``model_name``, ``k_spt``,
            ``k_qry``, ``task_num``, ``n_way``, ``imgsz``, ``data_folder``,
            ``meta_lr``, ``update_lr``, ``update_step``, ``epoch``,
            ``continue_train`` and ``continue_epoch``.

    Raises:
        NotImplementedError: if ``args.arch`` is not ``"Unet"``.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    print(args)
    device = torch.device('cuda')

    config = []
    if args.arch == "Unet":
        # Encoder: NUM_DOWN_CONV blocks of three (conv -> leakyrelu -> bn)
        # stacks followed by a max-pool; channel width doubles per block.
        for block in range(args.NUM_DOWN_CONV):
            out_channels = (2 ** block) * args.HIDDEN_DIM
            if block == 0:
                # out_c, in_c, k_h, k_w, stride, padding; conv only, no bias
                config += [('conv2d', [out_channels, args.imgc, 3, 3, 1, 1])]
            else:
                # out_c, in_c, k_h, k_w, stride, padding
                config += [('conv2d', [out_channels, out_channels // 2, 3, 3, 1, 1])]
            # alpha; if True then executes relu in place
            config += [('leakyrelu', [0.2, False]), ('bn', [out_channels])]
            for _ in range(2):  # two more same-width conv stacks
                config += [('conv2d', [out_channels, out_channels, 3, 3, 1, 1]),
                           ('leakyrelu', [0.2, False]),
                           ('bn', [out_channels])]
            config += [('max_pool2d', [2, 2, 0])]  # kernel_size, stride, padding

        # Decoder: upsample then three conv stacks per block; the first conv
        # consumes out_channels * 3 inputs (presumably upsampled features
        # concatenated with a skip connection — confirm against Meta).
        for block in range(args.NUM_DOWN_CONV - 1):
            out_channels = (2 ** (args.NUM_DOWN_CONV - block - 2)) * args.HIDDEN_DIM
            in_channels = out_channels * 3
            config += [('upsample', [2])]
            config += [('conv2d', [out_channels, in_channels, 3, 3, 1, 1]),
                       ('leakyrelu', [0.2, False]),
                       ('bn', [out_channels])]
            for _ in range(2):
                config += [('conv2d', [out_channels, out_channels, 3, 3, 1, 1]),
                           ('leakyrelu', [0.2, False]),
                           ('bn', [out_channels])]

        # All the conv2d layers above are bias-free; this final conv2d_b has a bias.
        config += [('conv2d_b', [args.outc, args.HIDDEN_DIM, 3, 3, 1, 1])]
    else:
        # BUG FIX: the original raised a plain string, which is itself a
        # TypeError under Python 3; raise a proper exception instead.
        raise NotImplementedError("architectures other than Unet hasn't been added!!")

    maml = Meta(args, config).to(device)

    trainable = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), trainable))
    for name, param in maml.named_parameters():
        print(name, param.size())
    print('Total trainable tensors:', num)

    SUMMARY_INTERVAL = 5
    TEST_PRINT_INTERVAL = SUMMARY_INTERVAL * 5
    EPOCH_SAVE_INTERVAL = 5

    model_path = ("/scratch/users/chenkaim/pytorch-models/pytorch_"
                  + args.model_name
                  + "_k_shot_" + str(args.k_spt)
                  + "_task_num_" + str(args.task_num)
                  + "_meta_lr_" + str(args.meta_lr)
                  + "_inner_lr_" + str(args.update_lr)
                  + "_num_inner_updates_" + str(args.update_step))
    # BUG FIX: os.mkdir fails when a parent directory is missing; makedirs
    # with exist_ok also removes the isdir/mkdir race of the original.
    os.makedirs(model_path, exist_ok=True)

    start_epoch = 0
    if args.continue_train:
        ckpt_file = model_path + "/epoch_" + str(args.continue_epoch) + ".pt"
        print("Restoring weights from ", ckpt_file)
        checkpoint = torch.load(ckpt_file)
        # The checkpoint stores whole objects, not state_dicts.
        maml = checkpoint['model']
        maml.lr_scheduler = checkpoint['lr_scheduler']
        maml.meta_optim = checkpoint['optimizer']
        start_epoch = args.continue_epoch

    db = RCWA_data_loader(batchsz=args.task_num,
                          n_way=args.n_way,
                          k_shot=args.k_spt,
                          k_query=args.k_qry,
                          imgsz=args.imgsz,
                          data_folder=args.data_folder)

    for step in range(start_epoch, args.epoch):
        print("epoch: ", step)
        if step % EPOCH_SAVE_INTERVAL == 0:
            checkpoint = {
                'epoch': step,
                'model': maml,
                'optimizer': maml.meta_optim,
                'lr_scheduler': maml.lr_scheduler
            }
            torch.save(checkpoint, model_path + "/epoch_" + str(step) + ".pt")

        # 70% of the samples are consumed for meta-training each epoch.
        iters_per_epoch = int(0.7 * db.total_data_samples
                              / ((args.k_spt + args.k_qry) * args.task_num))
        for itr in range(iters_per_epoch):
            x_spt, y_spt, x_qry, y_qry = db.next()
            x_spt, y_spt, x_qry, y_qry = (torch.from_numpy(x_spt).to(device),
                                          torch.from_numpy(y_spt).to(device),
                                          torch.from_numpy(x_qry).to(device),
                                          torch.from_numpy(y_qry).to(device))

            # set training=True to update running_mean, running_variance, bn_weights, bn_bias
            accs, loss_q = maml(x_spt, y_spt, x_qry, y_qry)

            if itr % SUMMARY_INTERVAL == 0:
                print_str = "Iteration %d: pre-inner-loop train accuracy: %.5f, post-iner-loop test accuracy: %.5f, train_loss: %.5f" % (
                    itr, accs[0], accs[-1], loss_q)
                print(print_str)

            if itr % TEST_PRINT_INTERVAL == 0:
                # Meta-validation: fine-tune on each task independently.
                accs = []
                for _ in range(10):  # test
                    x_spt, y_spt, x_qry, y_qry = db.next('test')
                    x_spt, y_spt, x_qry, y_qry = (torch.from_numpy(x_spt).to(device),
                                                  torch.from_numpy(y_spt).to(device),
                                                  torch.from_numpy(x_qry).to(device),
                                                  torch.from_numpy(y_qry).to(device))
                    # split to single task each time
                    for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                            x_spt, y_spt, x_qry, y_qry):
                        test_acc = maml.finetunning(x_spt_one, y_spt_one,
                                                    x_qry_one, y_qry_one)
                        accs.append(test_acc)
                # [b, update_step+1]
                accs = np.array(accs).mean(axis=0).astype(np.float16)
                print('Meta-validation pre-inner-loop train accuracy: %.5f, meta-validation post-inner-loop test accuracy: %.5f'
                      % (accs[0], accs[-1]))
        maml.lr_scheduler.step()
def main():
    """CLI entry point: meta-train a BERT-based few-shot classifier.

    Parses command-line options, sets up logging/seeds/data loaders, then
    runs the MAML training loop with periodic validation; the best model
    (by mean validation accuracy) is saved to ``<file_name>best.ckpt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', default='train', help='train file')
    parser.add_argument('--val', default='val', help='val file')
    parser.add_argument('--n_way', type=int, help='n way', default=5)
    parser.add_argument('--k_spt', type=int, help='k shot for support set', default=5)
    parser.add_argument('--k_qry', type=int, help='k shot for query set', default=5)
    parser.add_argument('--task_num', type=int, help='meta batch size, namely task num', default=32)
    parser.add_argument('--meta_lr', type=float, help='meta-level outer learning rate', default=0.001)
    parser.add_argument('--update_lr', type=float, help='task-level inner update learning rate', default=0.01)
    parser.add_argument('--update_step', type=int, help='task-level inner update steps', default=5)
    parser.add_argument('--update_step_test', type=int, help='update steps for finetunning', default=10)
    parser.add_argument('--max_length', default=64, type=int, help='max length')
    parser.add_argument('--epoch', type=int, help='epoch number', default=4)
    parser.add_argument('--na_rate', default=0, type=int, help='NA rate (NA = Q * na_rate)')
    parser.add_argument('--embedding', default='bert', type=str, help='"glove" or "bert".')
    parser.add_argument('--gpu', default="1,2", type=str, help='gpu use.')
    parser.add_argument('--type', default="cnnLinear", type=str,
                        help="type of the net, 'cnnLinear' 'concatLinear' or 'clsLinear'.")
    # BUG FIX: help text was copy-pasted from --type; this option names the log file.
    parser.add_argument('--filename', default=None, type=str,
                        help="log file name (auto-generated when omitted).")
    parser.add_argument('--fp16', action='store_true', help='use nvidia apex fp16')
    args = parser.parse_args()
    logging.info(str(args))

    # Resolve the log-file path: timestamped auto name unless one was given.
    if args.filename is None:  # BUG FIX: identity check, not '== None'
        file_name = 'log/{}way{}shot-{}-{}'.format(args.n_way, args.k_spt,
                                                   args.embedding, args.type)
        dt = datetime.now()
        file_name += dt.strftime('%Y-%m-%d-%H:%M:%S-%f')
        file_name += ".log"
    else:
        file_name = os.path.join('log', args.filename)
    with open(file_name, 'w') as f:
        f.writelines(str(args).split(','))

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    random.seed(2020)
    np.random.seed(2020)
    torch.manual_seed(2020)
    n_gpu = torch.cuda.device_count()
    if n_gpu > 0:
        torch.cuda.manual_seed_all(2020)

    tokenizer = Berttokenizer(max_length=args.max_length)
    train_data_loader = bert_getloader(args.train, tokenizer,
                                       N=args.n_way, K=args.k_spt,
                                       Q=args.k_qry, na_rate=args.na_rate,
                                       batch_size=args.task_num)
    val_data_loader = bert_getloader(args.val, tokenizer,
                                     N=args.n_way, K=args.k_spt,
                                     Q=1, batch_size=20)

    maml = Meta(args, device, n_gpu)
    for name, parm in maml.named_parameters():
        if parm.requires_grad:
            print(name, parm.shape)
    logging.info(n_gpu)

    accses_train = []
    accses_test = []
    losses = []
    best_result = 0
    start = time.time()
    maml.to(device)

    for epoch in range(args.epoch):
        for step, batch in enumerate(train_data_loader):
            if n_gpu >= 1:
                # multi-gpu does scattering it-self
                batch = tuple(t.to(device) for t in batch)
            x_spt, y_spt, x_qry, y_qry = batch
            accs, loss = maml(x_spt, y_spt, x_qry, y_qry)
            losses.append(loss)
            accses_train.append(accs)

            if step % 10 == 0:
                elapsed_min = (time.time() - start) // 60
                logging.info("step: %s training acc:%s loss:%s cost%smin",
                             step, accs, loss, elapsed_min)
                with open(file_name, 'a') as f:
                    f.write("\nstep: {}\ttraining acc:{}\tloss:{}\tcost:{}min".format(
                        step, accs, loss, elapsed_min))

            if step % 100 == 0 and step != 0:
                # Validation: 10 batches, each split into single tasks.
                l = []
                for _ in range(10):
                    accs = []
                    x_spt, y_spt, x_qry, y_qry = next(val_data_loader)
                    x_spt = x_spt.to(device)
                    x_qry = x_qry.to(device)
                    y_spt = y_spt.to(device)
                    # y_qry stays on CPU; predictions are moved to CPU below.
                    for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                            x_spt, y_spt, x_qry, y_qry):  # [N,K,MAXLEN]
                        pred = maml.evaluate(x_spt_one, y_spt_one,
                                             x_qry_one).cpu().numpy()
                        acc = (y_qry_one.numpy() == pred).mean()
                        accs.append(acc)
                    accs = np.array(accs).mean(axis=0).astype(np.float16)
                    l.append(accs)

                mean_acc = np.array(l).mean()
                elapsed_min = (time.time() - start) // 60
                with open(file_name, 'a') as f:
                    f.write("\nTest acc:{}\tmean:{}\tcost:{}min".format(
                        l, str(mean_acc), elapsed_min))
                # BUG FIX: the cost slot was logged with the mean accuracy
                # instead of the elapsed minutes.
                logging.info("Test acc:%s mean:%s cost:%smin",
                             l, mean_acc, elapsed_min)
                if best_result <= mean_acc:
                    # BUG FIX: best_result was never updated, so every
                    # validation pass overwrote best.ckpt unconditionally.
                    best_result = mean_acc
                    torch.save(maml, "{}best.ckpt".format(file_name))
                accses_test.append([step, accs])